diff --git a/src/quant_engine/kis_api_client_v1.py b/src/quant_engine/kis_api_client_v1.py index 3a0e607..898b419 100644 --- a/src/quant_engine/kis_api_client_v1.py +++ b/src/quant_engine/kis_api_client_v1.py @@ -33,6 +33,8 @@ if str(ROOT) not in sys.path: REAL_DOMAIN = "https://openapi.koreainvestment.com:9443" MOCK_DOMAIN = "https://openapivts.koreainvestment.com:29443" TOKEN_CACHE_DIR = ROOT / "Temp" +TOKEN_CACHE_DB_NAME = "kis_tokens.db" +TOKEN_REFRESH_SKEW_MINUTES = 10 def _requests(): @@ -118,9 +120,13 @@ class KisCredentials: import sqlite3 def _token_db_path() -> Path: - db_dir = ROOT / "src" / "quant_engine" - db_dir.mkdir(parents=True, exist_ok=True) - return db_dir / "kis_data_collection.db" + import os + + override = os.environ.get("KIS_TOKEN_DB_PATH", "").strip() + if override: + return Path(override) + TOKEN_CACHE_DIR.mkdir(parents=True, exist_ok=True) + return TOKEN_CACHE_DIR / TOKEN_CACHE_DB_NAME def _init_token_db(conn: sqlite3.Connection) -> None: @@ -140,50 +146,54 @@ def _init_token_db(conn: sqlite3.Connection) -> None: def _issue_or_reuse_token(creds: KisCredentials) -> str: """KIS는 토큰 발급 빈도를 제한한다 — 만료 전까지 DB 캐시 재사용 필수.""" db_path = _token_db_path() - - # 1. DB에서 기존 토큰 및 만료 시각 조회 - with sqlite3.connect(db_path) as conn: + db_path.parent.mkdir(parents=True, exist_ok=True) + + with sqlite3.connect(db_path, timeout=30) as conn: _init_token_db(conn) + conn.execute("BEGIN IMMEDIATE") row = conn.execute( "SELECT access_token, expires_at FROM kis_tokens WHERE account = ?", - (creds.account,) + (creds.account,), ).fetchone() - if row: token, expires_at_str = row try: expires_at = dt.datetime.fromisoformat(expires_at_str) # 만료 시간 10분 전까지 재사용 가능 여부 검사 - if dt.datetime.now(dt.timezone.utc) < expires_at - dt.timedelta(minutes=10): + if dt.datetime.now(dt.timezone.utc) < expires_at - dt.timedelta(minutes=TOKEN_REFRESH_SKEW_MINUTES): + conn.commit() return token except ValueError: pass - # 2. 토큰이 만료되었거나 없을 시 KIS API로 새로 발급 요청 - requests = _requests() - resp = requests.post( - f"{creds.domain}/oauth2/tokenP", - json={"grant_type": "client_credentials", "appkey": creds.app_key, "appsecret": creds.app_secret}, - timeout=15, - ) - resp.raise_for_status() - body = resp.json() - access_token = body["access_token"] - expires_in_sec = int(body.get("expires_in", 86400)) - expires_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=expires_in_sec) - - # 3. 새로운 토큰 정보를 DB에 안전하게 업서트 - with sqlite3.connect(db_path) as conn: + # 2. 토큰이 만료되었거나 없을 시 KIS API로 새로 발급 요청 + requests = _requests() + resp = requests.post( + f"{creds.domain}/oauth2/tokenP", + json={"grant_type": "client_credentials", "appkey": creds.app_key, "appsecret": creds.app_secret}, + timeout=15, + ) + resp.raise_for_status() + body = resp.json() + access_token = body["access_token"] + expires_in_sec = int(body.get("expires_in", 86400)) + expires_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=expires_in_sec) + + # 3. 새로운 토큰 정보를 DB에 안전하게 업서트 conn.execute( """ INSERT OR REPLACE INTO kis_tokens (account, access_token, expires_at, updated_at) VALUES (?, ?, ?, ?) """, - (creds.account, access_token, expires_at.isoformat(), dt.datetime.now(dt.timezone.utc).isoformat()) + ( + creds.account, + access_token, + expires_at.isoformat(), + dt.datetime.now(dt.timezone.utc).isoformat(), + ), ) conn.commit() - - return access_token + return access_token def _send_request(creds: KisCredentials, path: str, tr_id: str, params: dict[str, Any]) -> dict[str, Any]: diff --git a/tests/unit/test_kis_api_client_v1.py b/tests/unit/test_kis_api_client_v1.py index 57fdc6d..c690526 100644 --- a/tests/unit/test_kis_api_client_v1.py +++ b/tests/unit/test_kis_api_client_v1.py @@ -4,6 +4,9 @@ import sys import unittest from pathlib import Path from unittest.mock import patch +import warnings + +warnings.simplefilter("ignore", ResourceWarning) ROOT = Path(__file__).resolve().parents[2] if str(ROOT) not in sys.path: @@ -154,6 +157,101 @@ class TestKisApiClientV1(unittest.TestCase): shutil.rmtree(tmp_dir, ignore_errors=True) + def test_issue_or_reuse_token_honors_token_db_override(self): + import tempfile + import shutil + import sqlite3 + from src.quant_engine.kis_api_client_v1 import _issue_or_reuse_token, KisCredentials + + tmp_dir = tempfile.mkdtemp() + override_db = Path(tmp_dir) / "custom_kis_tokens.db" + + with patch.dict("os.environ", {"KIS_TOKEN_DB_PATH": str(override_db)}), \ + patch("src.quant_engine.kis_api_client_v1._requests") as mock_requests: + creds = KisCredentials(app_key="k", app_secret="s", account="mock") + mock_resp = mock_requests.return_value.post.return_value + mock_resp.raise_for_status.return_value = None + mock_resp.json.return_value = { + "access_token": "override-token-789", + "expires_in": "3600", + } + + token = _issue_or_reuse_token(creds) + self.assertEqual(token, "override-token-789") + self.assertTrue(override_db.exists()) + + with sqlite3.connect(override_db) as conn: + row = conn.execute( + "SELECT access_token FROM kis_tokens WHERE account = 'mock'" + ).fetchone() + self.assertIsNotNone(row) + self.assertEqual(row[0], "override-token-789") + + shutil.rmtree(tmp_dir, ignore_errors=True) + + def test_issue_or_reuse_token_serializes_concurrent_refresh(self): + import tempfile + import shutil + import sqlite3 + import threading + import time + from src.quant_engine.kis_api_client_v1 import _issue_or_reuse_token, KisCredentials + + tmp_dir = tempfile.mkdtemp() + override_db = Path(tmp_dir) / "kis_tokens.db" + barrier = threading.Barrier(2) + post_calls = [] + + class _Response: + def raise_for_status(self): + return None + + def json(self): + return { + "access_token": "concurrent-token-001", + "expires_in": "3600", + } + + def _post(*args, **kwargs): + post_calls.append(time.time()) + time.sleep(0.2) + return _Response() + + with patch.dict("os.environ", {"KIS_TOKEN_DB_PATH": str(override_db)}), \ + patch("src.quant_engine.kis_api_client_v1._requests") as mock_requests: + mock_requests.return_value.post.side_effect = _post + creds = KisCredentials(app_key="k", app_secret="s", account="mock") + + results: list[str] = [] + errors: list[BaseException] = [] + + def worker() -> None: + try: + barrier.wait(timeout=5) + results.append(_issue_or_reuse_token(creds)) + except BaseException as exc: # pragma: no cover - test harness diagnostic + errors.append(exc) + + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=worker) + t1.start() + t2.start() + t1.join(timeout=10) + t2.join(timeout=10) + + self.assertEqual(errors, []) + self.assertEqual(results, ["concurrent-token-001", "concurrent-token-001"]) + self.assertEqual(len(post_calls), 1) + + with sqlite3.connect(override_db) as conn: + row = conn.execute( + "SELECT access_token FROM kis_tokens WHERE account = 'mock'" + ).fetchone() + self.assertIsNotNone(row) + self.assertEqual(row[0], "concurrent-token-001") + + shutil.rmtree(tmp_dir, ignore_errors=True) + if __name__ == "__main__": unittest.main() diff --git a/tools/inspect_kis_token_cache_v1.py b/tools/inspect_kis_token_cache_v1.py new file mode 100644 index 0000000..853d0ca --- /dev/null +++ b/tools/inspect_kis_token_cache_v1.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import sqlite3 +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from src.quant_engine.kis_api_client_v1 import TOKEN_REFRESH_SKEW_MINUTES, _token_db_path + + +def main() -> int: + parser = argparse.ArgumentParser(description="Inspect the KIS Open API token cache.") + parser.add_argument("--account", default=None, help="Optional account filter (real|mock)") + parser.add_argument("--json", action="store_true", help="Print JSON output") + args = parser.parse_args() + + db_path = _token_db_path() + if not db_path.exists(): + payload = {"gate": "DATA_MISSING", "db_path": str(db_path), "rows": []} + print(json.dumps(payload, ensure_ascii=False, indent=2) if args.json else f"DATA_MISSING: {db_path}") + return 0 + + conn = sqlite3.connect(db_path) + try: + rows = conn.execute( + "SELECT account, access_token, expires_at, updated_at FROM kis_tokens ORDER BY account" + ).fetchall() + finally: + conn.close() + + data = [] + for account, access_token, expires_at, updated_at in rows: + if args.account and account != args.account: + continue + data.append( + { + "account": account, + "access_token_prefix": f"{access_token[:6]}..." if access_token else "", + "expires_at": expires_at, + "updated_at": updated_at, + } + ) + + payload = { + "gate": "PASS" if data else "DATA_MISSING", + "db_path": str(db_path), + "refresh_skew_minutes": TOKEN_REFRESH_SKEW_MINUTES, + "row_count": len(data), + "rows": data, + } + if args.json: + print(json.dumps(payload, ensure_ascii=False, indent=2)) + else: + print(f"db_path: {payload['db_path']}") + print(f"refresh_skew_minutes: {payload['refresh_skew_minutes']}") + print(f"row_count: {payload['row_count']}") + for row in data: + print( + f"- account={row['account']} token={row['access_token_prefix']} " + f"expires_at={row['expires_at']} updated_at={row['updated_at']}" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())