feat(kis): cache tokens in sqlite and add inspector
This commit is contained in:
@@ -33,6 +33,8 @@ if str(ROOT) not in sys.path:
|
||||
REAL_DOMAIN = "https://openapi.koreainvestment.com:9443"
|
||||
MOCK_DOMAIN = "https://openapivts.koreainvestment.com:29443"
|
||||
TOKEN_CACHE_DIR = ROOT / "Temp"
|
||||
TOKEN_CACHE_DB_NAME = "kis_tokens.db"
|
||||
TOKEN_REFRESH_SKEW_MINUTES = 10
|
||||
|
||||
|
||||
def _requests():
|
||||
@@ -118,9 +120,13 @@ class KisCredentials:
|
||||
import sqlite3
|
||||
|
||||
def _token_db_path() -> Path:
|
||||
db_dir = ROOT / "src" / "quant_engine"
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
return db_dir / "kis_data_collection.db"
|
||||
import os
|
||||
|
||||
override = os.environ.get("KIS_TOKEN_DB_PATH", "").strip()
|
||||
if override:
|
||||
return Path(override)
|
||||
TOKEN_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return TOKEN_CACHE_DIR / TOKEN_CACHE_DB_NAME
|
||||
|
||||
|
||||
def _init_token_db(conn: sqlite3.Connection) -> None:
|
||||
@@ -140,21 +146,22 @@ def _init_token_db(conn: sqlite3.Connection) -> None:
|
||||
def _issue_or_reuse_token(creds: KisCredentials) -> str:
|
||||
"""KIS는 토큰 발급 빈도를 제한한다 — 만료 전까지 DB 캐시 재사용 필수."""
|
||||
db_path = _token_db_path()
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1. DB에서 기존 토큰 및 만료 시각 조회
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
with sqlite3.connect(db_path, timeout=30) as conn:
|
||||
_init_token_db(conn)
|
||||
conn.execute("BEGIN IMMEDIATE")
|
||||
row = conn.execute(
|
||||
"SELECT access_token, expires_at FROM kis_tokens WHERE account = ?",
|
||||
(creds.account,)
|
||||
(creds.account,),
|
||||
).fetchone()
|
||||
|
||||
if row:
|
||||
token, expires_at_str = row
|
||||
try:
|
||||
expires_at = dt.datetime.fromisoformat(expires_at_str)
|
||||
# 만료 시간 10분 전까지 재사용 가능 여부 검사
|
||||
if dt.datetime.now(dt.timezone.utc) < expires_at - dt.timedelta(minutes=10):
|
||||
if dt.datetime.now(dt.timezone.utc) < expires_at - dt.timedelta(minutes=TOKEN_REFRESH_SKEW_MINUTES):
|
||||
conn.commit()
|
||||
return token
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -173,16 +180,19 @@ def _issue_or_reuse_token(creds: KisCredentials) -> str:
|
||||
expires_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=expires_in_sec)
|
||||
|
||||
# 3. 새로운 토큰 정보를 DB에 안전하게 업서트
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO kis_tokens (account, access_token, expires_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(creds.account, access_token, expires_at.isoformat(), dt.datetime.now(dt.timezone.utc).isoformat())
|
||||
(
|
||||
creds.account,
|
||||
access_token,
|
||||
expires_at.isoformat(),
|
||||
dt.datetime.now(dt.timezone.utc).isoformat(),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return access_token
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
@@ -154,6 +157,101 @@ class TestKisApiClientV1(unittest.TestCase):
|
||||
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
def test_issue_or_reuse_token_honors_token_db_override(self):
|
||||
import tempfile
|
||||
import shutil
|
||||
import sqlite3
|
||||
from src.quant_engine.kis_api_client_v1 import _issue_or_reuse_token, KisCredentials
|
||||
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
override_db = Path(tmp_dir) / "custom_kis_tokens.db"
|
||||
|
||||
with patch.dict("os.environ", {"KIS_TOKEN_DB_PATH": str(override_db)}), \
|
||||
patch("src.quant_engine.kis_api_client_v1._requests") as mock_requests:
|
||||
creds = KisCredentials(app_key="k", app_secret="s", account="mock")
|
||||
mock_resp = mock_requests.return_value.post.return_value
|
||||
mock_resp.raise_for_status.return_value = None
|
||||
mock_resp.json.return_value = {
|
||||
"access_token": "override-token-789",
|
||||
"expires_in": "3600",
|
||||
}
|
||||
|
||||
token = _issue_or_reuse_token(creds)
|
||||
self.assertEqual(token, "override-token-789")
|
||||
self.assertTrue(override_db.exists())
|
||||
|
||||
with sqlite3.connect(override_db) as conn:
|
||||
row = conn.execute(
|
||||
"SELECT access_token FROM kis_tokens WHERE account = 'mock'"
|
||||
).fetchone()
|
||||
self.assertIsNotNone(row)
|
||||
self.assertEqual(row[0], "override-token-789")
|
||||
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
def test_issue_or_reuse_token_serializes_concurrent_refresh(self):
|
||||
import tempfile
|
||||
import shutil
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from src.quant_engine.kis_api_client_v1 import _issue_or_reuse_token, KisCredentials
|
||||
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
override_db = Path(tmp_dir) / "kis_tokens.db"
|
||||
barrier = threading.Barrier(2)
|
||||
post_calls = []
|
||||
|
||||
class _Response:
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
def json(self):
|
||||
return {
|
||||
"access_token": "concurrent-token-001",
|
||||
"expires_in": "3600",
|
||||
}
|
||||
|
||||
def _post(*args, **kwargs):
|
||||
post_calls.append(time.time())
|
||||
time.sleep(0.2)
|
||||
return _Response()
|
||||
|
||||
with patch.dict("os.environ", {"KIS_TOKEN_DB_PATH": str(override_db)}), \
|
||||
patch("src.quant_engine.kis_api_client_v1._requests") as mock_requests:
|
||||
mock_requests.return_value.post.side_effect = _post
|
||||
creds = KisCredentials(app_key="k", app_secret="s", account="mock")
|
||||
|
||||
results: list[str] = []
|
||||
errors: list[BaseException] = []
|
||||
|
||||
def worker() -> None:
|
||||
try:
|
||||
barrier.wait(timeout=5)
|
||||
results.append(_issue_or_reuse_token(creds))
|
||||
except BaseException as exc: # pragma: no cover - test harness diagnostic
|
||||
errors.append(exc)
|
||||
|
||||
t1 = threading.Thread(target=worker)
|
||||
t2 = threading.Thread(target=worker)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join(timeout=10)
|
||||
t2.join(timeout=10)
|
||||
|
||||
self.assertEqual(errors, [])
|
||||
self.assertEqual(results, ["concurrent-token-001", "concurrent-token-001"])
|
||||
self.assertEqual(len(post_calls), 1)
|
||||
|
||||
with sqlite3.connect(override_db) as conn:
|
||||
row = conn.execute(
|
||||
"SELECT access_token FROM kis_tokens WHERE account = 'mock'"
|
||||
).fetchone()
|
||||
self.assertIsNotNone(row)
|
||||
self.assertEqual(row[0], "concurrent-token-001")
|
||||
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from src.quant_engine.kis_api_client_v1 import TOKEN_REFRESH_SKEW_MINUTES, _token_db_path
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Inspect the KIS Open API token cache.")
|
||||
parser.add_argument("--account", default=None, help="Optional account filter (real|mock)")
|
||||
parser.add_argument("--json", action="store_true", help="Print JSON output")
|
||||
args = parser.parse_args()
|
||||
|
||||
db_path = _token_db_path()
|
||||
if not db_path.exists():
|
||||
payload = {"gate": "DATA_MISSING", "db_path": str(db_path), "rows": []}
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2) if args.json else f"DATA_MISSING: {db_path}")
|
||||
return 0
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"SELECT account, access_token, expires_at, updated_at FROM kis_tokens ORDER BY account"
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
data = []
|
||||
for account, access_token, expires_at, updated_at in rows:
|
||||
if args.account and account != args.account:
|
||||
continue
|
||||
data.append(
|
||||
{
|
||||
"account": account,
|
||||
"access_token_prefix": f"{access_token[:6]}..." if access_token else "",
|
||||
"expires_at": expires_at,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
)
|
||||
|
||||
payload = {
|
||||
"gate": "PASS" if data else "DATA_MISSING",
|
||||
"db_path": str(db_path),
|
||||
"refresh_skew_minutes": TOKEN_REFRESH_SKEW_MINUTES,
|
||||
"row_count": len(data),
|
||||
"rows": data,
|
||||
}
|
||||
if args.json:
|
||||
print(json.dumps(payload, ensure_ascii=False, indent=2))
|
||||
else:
|
||||
print(f"db_path: {payload['db_path']}")
|
||||
print(f"refresh_skew_minutes: {payload['refresh_skew_minutes']}")
|
||||
print(f"row_count: {payload['row_count']}")
|
||||
for row in data:
|
||||
print(
|
||||
f"- account={row['account']} token={row['access_token_prefix']} "
|
||||
f"expires_at={row['expires_at']} updated_at={row['updated_at']}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user