diff --git a/src/quant_engine/kis_api_client_v1.py b/src/quant_engine/kis_api_client_v1.py index 4772a37..55c8152 100644 --- a/src/quant_engine/kis_api_client_v1.py +++ b/src/quant_engine/kis_api_client_v1.py @@ -115,23 +115,51 @@ class KisCredentials: return cls(app_key=app_key, app_secret=app_secret, account=account) -def _token_cache_path(creds: KisCredentials) -> Path: - TOKEN_CACHE_DIR.mkdir(parents=True, exist_ok=True) - return TOKEN_CACHE_DIR / f"kis_token_cache_{creds.account}.json" +import sqlite3 + +def _token_db_path() -> Path: + db_dir = ROOT / "outputs" / "kis_data_collection" + db_dir.mkdir(parents=True, exist_ok=True) + return db_dir / "kis_data_collection.db" + + +def _init_token_db(conn: sqlite3.Connection) -> None: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS kis_tokens ( + account TEXT PRIMARY KEY, + access_token TEXT NOT NULL, + expires_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + """ + ) + conn.commit() def _issue_or_reuse_token(creds: KisCredentials) -> str: - """KIS는 토큰 발급 빈도를 제한한다 — 만료 전까지 캐시 재사용 필수.""" - cache_path = _token_cache_path(creds) - if cache_path.exists(): - try: - cached = json.loads(cache_path.read_text(encoding="utf-8")) - expires_at = dt.datetime.fromisoformat(cached["expires_at"]) - if dt.datetime.now(dt.timezone.utc) < expires_at - dt.timedelta(minutes=10): - return cached["access_token"] - except (json.JSONDecodeError, KeyError, ValueError): - pass + """KIS는 토큰 발급 빈도를 제한한다 — 만료 전까지 DB 캐시 재사용 필수.""" + db_path = _token_db_path() + + # 1. DB에서 기존 토큰 및 만료 시각 조회 + with sqlite3.connect(db_path) as conn: + _init_token_db(conn) + row = conn.execute( + "SELECT access_token, expires_at FROM kis_tokens WHERE account = ?", + (creds.account,) + ).fetchone() + + if row: + token, expires_at_str = row + try: + expires_at = dt.datetime.fromisoformat(expires_at_str) + # 만료 시간 10분 전까지 재사용 가능 여부 검사 + if dt.datetime.now(dt.timezone.utc) < expires_at - dt.timedelta(minutes=10): + return token + except ValueError: + pass + # 2. 토큰이 만료되었거나 없을 시 KIS API로 새로 발급 요청 requests = _requests() resp = requests.post( f"{creds.domain}/oauth2/tokenP", @@ -143,10 +171,18 @@ def _issue_or_reuse_token(creds: KisCredentials) -> str: access_token = body["access_token"] expires_in_sec = int(body.get("expires_in", 86400)) expires_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=expires_in_sec) - cache_path.write_text( - json.dumps({"access_token": access_token, "expires_at": expires_at.isoformat()}, ensure_ascii=False), - encoding="utf-8", - ) + + # 3. 새로운 토큰 정보를 DB에 안전하게 업서트 + with sqlite3.connect(db_path) as conn: + conn.execute( + """ + INSERT OR REPLACE INTO kis_tokens (account, access_token, expires_at, updated_at) + VALUES (?, ?, ?, ?) + """, + (creds.account, access_token, expires_at.isoformat(), dt.datetime.now(dt.timezone.utc).isoformat()) + ) + conn.commit() + return access_token