feat(kis): cache tokens in sqlite and add inspector

This commit is contained in:
2026-06-23 18:00:34 +09:00
parent a343db5812
commit 357d2507da
3 changed files with 207 additions and 27 deletions
+37 -27
View File
@@ -33,6 +33,8 @@ if str(ROOT) not in sys.path:
REAL_DOMAIN = "https://openapi.koreainvestment.com:9443"
MOCK_DOMAIN = "https://openapivts.koreainvestment.com:29443"
TOKEN_CACHE_DIR = ROOT / "Temp"
TOKEN_CACHE_DB_NAME = "kis_tokens.db"
TOKEN_REFRESH_SKEW_MINUTES = 10
def _requests():
@@ -118,9 +120,13 @@ class KisCredentials:
import sqlite3
def _token_db_path() -> Path:
db_dir = ROOT / "src" / "quant_engine"
db_dir.mkdir(parents=True, exist_ok=True)
return db_dir / "kis_data_collection.db"
import os
override = os.environ.get("KIS_TOKEN_DB_PATH", "").strip()
if override:
return Path(override)
TOKEN_CACHE_DIR.mkdir(parents=True, exist_ok=True)
return TOKEN_CACHE_DIR / TOKEN_CACHE_DB_NAME
def _init_token_db(conn: sqlite3.Connection) -> None:
@@ -140,50 +146,54 @@ def _init_token_db(conn: sqlite3.Connection) -> None:
def _issue_or_reuse_token(creds: KisCredentials) -> str:
"""KIS는 토큰 발급 빈도를 제한한다 — 만료 전까지 DB 캐시 재사용 필수."""
db_path = _token_db_path()
# 1. DB에서 기존 토큰 및 만료 시각 조회
with sqlite3.connect(db_path) as conn:
db_path.parent.mkdir(parents=True, exist_ok=True)
with sqlite3.connect(db_path, timeout=30) as conn:
_init_token_db(conn)
conn.execute("BEGIN IMMEDIATE")
row = conn.execute(
"SELECT access_token, expires_at FROM kis_tokens WHERE account = ?",
(creds.account,)
(creds.account,),
).fetchone()
if row:
token, expires_at_str = row
try:
expires_at = dt.datetime.fromisoformat(expires_at_str)
# 만료 시간 10분 전까지 재사용 가능 여부 검사
if dt.datetime.now(dt.timezone.utc) < expires_at - dt.timedelta(minutes=10):
if dt.datetime.now(dt.timezone.utc) < expires_at - dt.timedelta(minutes=TOKEN_REFRESH_SKEW_MINUTES):
conn.commit()
return token
except ValueError:
pass
# 2. 토큰이 만료되었거나 없을 시 KIS API로 새로 발급 요청
requests = _requests()
resp = requests.post(
f"{creds.domain}/oauth2/tokenP",
json={"grant_type": "client_credentials", "appkey": creds.app_key, "appsecret": creds.app_secret},
timeout=15,
)
resp.raise_for_status()
body = resp.json()
access_token = body["access_token"]
expires_in_sec = int(body.get("expires_in", 86400))
expires_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=expires_in_sec)
# 3. 새로운 토큰 정보를 DB에 안전하게 업서트
with sqlite3.connect(db_path) as conn:
# 2. 토큰이 만료되었거나 없을 시 KIS API로 새로 발급 요청
requests = _requests()
resp = requests.post(
f"{creds.domain}/oauth2/tokenP",
json={"grant_type": "client_credentials", "appkey": creds.app_key, "appsecret": creds.app_secret},
timeout=15,
)
resp.raise_for_status()
body = resp.json()
access_token = body["access_token"]
expires_in_sec = int(body.get("expires_in", 86400))
expires_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=expires_in_sec)
# 3. 새로운 토큰 정보를 DB에 안전하게 업서트
conn.execute(
"""
INSERT OR REPLACE INTO kis_tokens (account, access_token, expires_at, updated_at)
VALUES (?, ?, ?, ?)
""",
(creds.account, access_token, expires_at.isoformat(), dt.datetime.now(dt.timezone.utc).isoformat())
(
creds.account,
access_token,
expires_at.isoformat(),
dt.datetime.now(dt.timezone.utc).isoformat(),
),
)
conn.commit()
return access_token
return access_token
def _send_request(creds: KisCredentials, path: str, tr_id: str, params: dict[str, Any]) -> dict[str, Any]:
+98
View File
@@ -4,6 +4,9 @@ import sys
import unittest
from pathlib import Path
from unittest.mock import patch
import warnings
warnings.simplefilter("ignore", ResourceWarning)
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
@@ -154,6 +157,101 @@ class TestKisApiClientV1(unittest.TestCase):
shutil.rmtree(tmp_dir, ignore_errors=True)
def test_issue_or_reuse_token_honors_token_db_override(self):
import tempfile
import shutil
import sqlite3
from src.quant_engine.kis_api_client_v1 import _issue_or_reuse_token, KisCredentials
tmp_dir = tempfile.mkdtemp()
override_db = Path(tmp_dir) / "custom_kis_tokens.db"
with patch.dict("os.environ", {"KIS_TOKEN_DB_PATH": str(override_db)}), \
patch("src.quant_engine.kis_api_client_v1._requests") as mock_requests:
creds = KisCredentials(app_key="k", app_secret="s", account="mock")
mock_resp = mock_requests.return_value.post.return_value
mock_resp.raise_for_status.return_value = None
mock_resp.json.return_value = {
"access_token": "override-token-789",
"expires_in": "3600",
}
token = _issue_or_reuse_token(creds)
self.assertEqual(token, "override-token-789")
self.assertTrue(override_db.exists())
with sqlite3.connect(override_db) as conn:
row = conn.execute(
"SELECT access_token FROM kis_tokens WHERE account = 'mock'"
).fetchone()
self.assertIsNotNone(row)
self.assertEqual(row[0], "override-token-789")
shutil.rmtree(tmp_dir, ignore_errors=True)
def test_issue_or_reuse_token_serializes_concurrent_refresh(self):
import tempfile
import shutil
import sqlite3
import threading
import time
from src.quant_engine.kis_api_client_v1 import _issue_or_reuse_token, KisCredentials
tmp_dir = tempfile.mkdtemp()
override_db = Path(tmp_dir) / "kis_tokens.db"
barrier = threading.Barrier(2)
post_calls = []
class _Response:
def raise_for_status(self):
return None
def json(self):
return {
"access_token": "concurrent-token-001",
"expires_in": "3600",
}
def _post(*args, **kwargs):
post_calls.append(time.time())
time.sleep(0.2)
return _Response()
with patch.dict("os.environ", {"KIS_TOKEN_DB_PATH": str(override_db)}), \
patch("src.quant_engine.kis_api_client_v1._requests") as mock_requests:
mock_requests.return_value.post.side_effect = _post
creds = KisCredentials(app_key="k", app_secret="s", account="mock")
results: list[str] = []
errors: list[BaseException] = []
def worker() -> None:
try:
barrier.wait(timeout=5)
results.append(_issue_or_reuse_token(creds))
except BaseException as exc: # pragma: no cover - test harness diagnostic
errors.append(exc)
t1 = threading.Thread(target=worker)
t2 = threading.Thread(target=worker)
t1.start()
t2.start()
t1.join(timeout=10)
t2.join(timeout=10)
self.assertEqual(errors, [])
self.assertEqual(results, ["concurrent-token-001", "concurrent-token-001"])
self.assertEqual(len(post_calls), 1)
with sqlite3.connect(override_db) as conn:
row = conn.execute(
"SELECT access_token FROM kis_tokens WHERE account = 'mock'"
).fetchone()
self.assertIsNotNone(row)
self.assertEqual(row[0], "concurrent-token-001")
shutil.rmtree(tmp_dir, ignore_errors=True)
if __name__ == "__main__":
unittest.main()
+72
View File
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sqlite3
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from src.quant_engine.kis_api_client_v1 import TOKEN_REFRESH_SKEW_MINUTES, _token_db_path
def main() -> int:
parser = argparse.ArgumentParser(description="Inspect the KIS Open API token cache.")
parser.add_argument("--account", default=None, help="Optional account filter (real|mock)")
parser.add_argument("--json", action="store_true", help="Print JSON output")
args = parser.parse_args()
db_path = _token_db_path()
if not db_path.exists():
payload = {"gate": "DATA_MISSING", "db_path": str(db_path), "rows": []}
print(json.dumps(payload, ensure_ascii=False, indent=2) if args.json else f"DATA_MISSING: {db_path}")
return 0
conn = sqlite3.connect(db_path)
try:
rows = conn.execute(
"SELECT account, access_token, expires_at, updated_at FROM kis_tokens ORDER BY account"
).fetchall()
finally:
conn.close()
data = []
for account, access_token, expires_at, updated_at in rows:
if args.account and account != args.account:
continue
data.append(
{
"account": account,
"access_token_prefix": f"{access_token[:6]}..." if access_token else "",
"expires_at": expires_at,
"updated_at": updated_at,
}
)
payload = {
"gate": "PASS" if data else "DATA_MISSING",
"db_path": str(db_path),
"refresh_skew_minutes": TOKEN_REFRESH_SKEW_MINUTES,
"row_count": len(data),
"rows": data,
}
if args.json:
print(json.dumps(payload, ensure_ascii=False, indent=2))
else:
print(f"db_path: {payload['db_path']}")
print(f"refresh_skew_minutes: {payload['refresh_skew_minutes']}")
print(f"row_count: {payload['row_count']}")
for row in data:
print(
f"- account={row['account']} token={row['access_token_prefix']} "
f"expires_at={row['expires_at']} updated_at={row['updated_at']}"
)
return 0
if __name__ == "__main__":
raise SystemExit(main())