Files
QuantEngineByItz/src/quant_engine/data_collection_store_v1.py
T

465 lines
16 KiB
Python

"""SQLite store for platform-transition data collection outputs.
This store is intentionally small and backend-agnostic enough to be upgraded to
PostgreSQL later without changing the row contract. The canonical payload is the
normalized factor row plus provenance metadata.
"""
from __future__ import annotations
import json
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
SCHEMA = """
PRAGMA journal_mode=WAL;
CREATE TABLE IF NOT EXISTS collection_runs (
run_id TEXT PRIMARY KEY,
collector_name TEXT NOT NULL,
started_at TEXT NOT NULL,
finished_at TEXT,
status TEXT NOT NULL,
input_source TEXT,
output_json_path TEXT,
output_db_path TEXT,
notes TEXT,
created_at TEXT DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS collection_snapshots (
run_id TEXT NOT NULL,
dataset_name TEXT NOT NULL,
ticker TEXT NOT NULL,
name TEXT,
sector TEXT,
as_of_date TEXT,
source_priority TEXT,
source_status TEXT,
payload_json TEXT NOT NULL,
provenance_json TEXT NOT NULL,
created_at TEXT DEFAULT (datetime('now')),
PRIMARY KEY (run_id, dataset_name, ticker)
);
CREATE TABLE IF NOT EXISTS collection_source_errors (
run_id TEXT NOT NULL,
ticker TEXT,
source_name TEXT NOT NULL,
error_kind TEXT NOT NULL,
error_message TEXT NOT NULL,
payload_json TEXT,
created_at TEXT DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_collection_snapshots_ticker_time
ON collection_snapshots(ticker, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_collection_source_errors_run
ON collection_source_errors(run_id, source_name);
"""
@dataclass(frozen=True)
class CollectionRun:
    """Immutable description of one collector run, persisted to ``collection_runs``.

    ``finished_at`` is not a field: it is passed separately to
    ``upsert_collection_run`` when the run row is written.
    """
    # Primary key of collection_runs; upserting the same id overwrites every column.
    run_id: str
    collector_name: str
    # Stored as TEXT; the timestamp format is chosen by the caller (TODO confirm convention).
    started_at: str
    # Free-form status text, stored as given.
    status: str
    input_source: str | None = None
    output_json_path: str | None = None
    output_db_path: str | None = None
    notes: str | None = None
# SQLite와 PostgreSQL 연결을 동적으로 감지하여 연결 인스턴스를 리턴하는 헬퍼
def _get_connection(db_target: Path | str) -> Any:
db_str = str(db_target)
if db_str.startswith("postgresql://") or db_str.startswith("postgres://"):
try:
import psycopg2
from psycopg2.extras import RealDictCursor
conn = psycopg2.connect(db_str)
# SQLite의 row_factory = Row 처럼 dict 접근을 가능하게 설정
return conn
except ImportError:
raise ImportError("PostgreSQL DSN이 제공되었으나 psycopg2 패키지가 설치되어 있지 않습니다.")
else:
return sqlite3.connect(Path(db_target))
def init_db(db_target: Path | str) -> None:
    """Make sure the store behind *db_target* is ready to accept writes.

    SQLite paths: create the parent directory when needed and apply ``SCHEMA``.
    The DDL is idempotent because of ``IF NOT EXISTS``.

    PostgreSQL DSNs: no DDL runs here, because the schema is managed on the
    server side. Skipping runtime DDL avoids lock contention during collection.
    A connection is opened and closed immediately only to prove the DSN works.
    """
    target = str(db_target)
    if target.startswith(("postgresql://", "postgres://")):
        _get_connection(db_target).close()
        return

    sqlite_file = Path(db_target)
    sqlite_file.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(sqlite_file)
    try:
        connection.executescript(SCHEMA)
        connection.commit()
    finally:
        connection.close()
def upsert_collection_run(db_target: Path | str, run: CollectionRun, finished_at: str | None = None) -> None:
    """Insert the ``collection_runs`` row for ``run.run_id``, or overwrite it.

    The store is initialised first (``init_db``). The write is a single
    ``INSERT ... ON CONFLICT(run_id) DO UPDATE`` that replaces every column
    with the new values. Calling this with ``finished_at=None`` therefore
    clears a previously stored finish time.
    """
    init_db(db_target)
    connection = _get_connection(db_target)
    try:
        is_postgres = str(db_target).startswith(("postgresql://", "postgres://"))
        # psycopg2 binds with %s, sqlite3 with qmark (?) placeholders.
        marker = "%s" if is_postgres else "?"
        row = (
            run.run_id,
            run.collector_name,
            run.started_at,
            finished_at,
            run.status,
            run.input_source,
            run.output_json_path,
            run.output_db_path,
            run.notes,
        )
        placeholders = ", ".join([marker] * len(row))
        sql = f"""
INSERT INTO collection_runs (
run_id, collector_name, started_at, finished_at, status,
input_source, output_json_path, output_db_path, notes
) VALUES ({placeholders})
ON CONFLICT(run_id) DO UPDATE SET
collector_name=EXCLUDED.collector_name,
started_at=EXCLUDED.started_at,
finished_at=EXCLUDED.finished_at,
status=EXCLUDED.status,
input_source=EXCLUDED.input_source,
output_json_path=EXCLUDED.output_json_path,
output_db_path=EXCLUDED.output_db_path,
notes=EXCLUDED.notes
"""
        connection.cursor().execute(sql, row)
        connection.commit()
    finally:
        connection.close()
def upsert_collection_snapshot(
    db_target: Path | str,
    *,
    run_id: str,
    dataset_name: str,
    ticker: str,
    name: str | None,
    sector: str | None,
    as_of_date: str | None,
    source_priority: str,
    source_status: str,
    payload: dict[str, Any],
    provenance: dict[str, Any],
) -> None:
    """Store the normalized row for ``(run_id, dataset_name, ticker)``.

    The store is initialised first (``init_db``). *payload* and *provenance*
    are serialized to JSON text: non-ASCII characters are kept, and objects
    json cannot encode are stringified with ``str``. If the key already
    exists, every non-key column is overwritten and ``created_at`` is left as it was.
    """
    init_db(db_target)
    connection = _get_connection(db_target)
    postgres = str(db_target).startswith(("postgresql://", "postgres://"))
    try:
        values = (
            run_id,
            dataset_name,
            ticker,
            name,
            sector,
            as_of_date,
            source_priority,
            source_status,
            json.dumps(payload, ensure_ascii=False, default=str),
            json.dumps(provenance, ensure_ascii=False, default=str),
        )
        marks = ", ".join(["%s" if postgres else "?"] * len(values))
        statement = f"""
INSERT INTO collection_snapshots (
run_id, dataset_name, ticker, name, sector, as_of_date,
source_priority, source_status, payload_json, provenance_json
) VALUES ({marks})
ON CONFLICT(run_id, dataset_name, ticker) DO UPDATE SET
name=EXCLUDED.name,
sector=EXCLUDED.sector,
as_of_date=EXCLUDED.as_of_date,
source_priority=EXCLUDED.source_priority,
source_status=EXCLUDED.source_status,
payload_json=EXCLUDED.payload_json,
provenance_json=EXCLUDED.provenance_json
"""
        connection.cursor().execute(statement, values)
        connection.commit()
    finally:
        connection.close()
def append_collection_error(
    db_target: Path | str,
    *,
    run_id: str,
    source_name: str,
    error_kind: str,
    error_message: str,
    ticker: str | None = None,
    payload: dict[str, Any] | None = None,
) -> None:
    """Append one source-failure record to ``collection_source_errors``.

    This is a plain INSERT, and the SQLite schema defines no key on the table,
    so repeated calls accumulate rows. A missing or empty *payload* is stored
    as the JSON text ``{}``.
    """
    init_db(db_target)
    connection = _get_connection(db_target)
    use_pyformat = str(db_target).startswith(("postgresql://", "postgres://"))
    try:
        record = (
            run_id,
            ticker,
            source_name,
            error_kind,
            error_message,
            json.dumps(payload or {}, ensure_ascii=False, default=str),
        )
        marks = ", ".join(["%s" if use_pyformat else "?"] * len(record))
        statement = f"""
INSERT INTO collection_source_errors (
run_id, ticker, source_name, error_kind, error_message, payload_json
) VALUES ({marks})
"""
        connection.cursor().execute(statement, record)
        connection.commit()
    finally:
        connection.close()
def fetch_latest_snapshots(db_target: Path | str, ticker: str, dataset_name: str | None = None) -> list[dict[str, Any]]:
    """Return every stored snapshot row for *ticker*, newest ``created_at`` first.

    Despite the name, all matching rows are returned (no LIMIT, no
    de-duplication). When *dataset_name* is truthy, rows are also filtered to
    that dataset. A SQLite path that does not exist yet yields ``[]`` without
    creating the file.
    """
    is_postgres = str(db_target).startswith(("postgresql://", "postgres://"))
    if not (is_postgres or Path(db_target).exists()):
        return []

    connection = _get_connection(db_target)
    if not is_postgres:
        connection.row_factory = sqlite3.Row
    try:
        ph = "%s" if is_postgres else "?"
        if dataset_name:
            sql = f"""
SELECT * FROM collection_snapshots
WHERE ticker = {ph} AND dataset_name = {ph}
ORDER BY created_at DESC
"""
            params: tuple[Any, ...] = (ticker, dataset_name)
        else:
            sql = f"""
SELECT * FROM collection_snapshots
WHERE ticker = {ph}
ORDER BY created_at DESC
"""
            params = (ticker,)
        cur = connection.cursor()
        cur.execute(sql, params)
        return list(map(dict, cur.fetchall()))
    finally:
        connection.close()
def iter_recent_snapshots(db_target: Path | str, limit: int = 50) -> Iterable[dict[str, Any]]:
    """Return the *limit* most recently created snapshot rows (all columns).

    The result is a fully materialized list, not a lazy iterator; the
    ``Iterable`` annotation is kept for backward compatibility. *limit* is
    coerced with ``int()``, matching ``load_collection_runs`` and
    ``load_collection_errors``. A SQLite path that does not exist yet
    yields ``[]`` without creating the file.
    """
    db_str = str(db_target)
    is_pg = db_str.startswith(("postgresql://", "postgres://"))
    if not is_pg and not Path(db_target).exists():
        return []
    conn = _get_connection(db_target)
    if not is_pg:
        conn.row_factory = sqlite3.Row
    try:
        param_char = "%s" if is_pg else "?"
        cursor = conn.cursor()
        cursor.execute(
            f"SELECT * FROM collection_snapshots ORDER BY created_at DESC LIMIT {param_char}",
            # Coerce like the sibling readers. SQLite rejects e.g. a float LIMIT
            # with "datatype mismatch", and drivers bind ints consistently.
            (int(limit),),
        )
        return [dict(row) for row in cursor.fetchall()]
    finally:
        conn.close()
def load_collection_runs(db_target: Path | str, limit: int = 20) -> list[dict[str, Any]]:
    """Return up to *limit* ``collection_runs`` rows, most recent first.

    Rows are ordered by ``started_at`` and then ``created_at``, both
    descending. A SQLite path that does not exist yet yields ``[]``.
    """
    pg_target = str(db_target).startswith(("postgresql://", "postgres://"))
    if not (pg_target or Path(db_target).exists()):
        return []

    conn = _get_connection(db_target)
    if not pg_target:
        conn.row_factory = sqlite3.Row
    try:
        placeholder = "%s" if pg_target else "?"
        cur = conn.cursor()
        cur.execute(
            f"""
SELECT run_id, collector_name, started_at, finished_at, status,
input_source, output_json_path, output_db_path, notes, created_at
FROM collection_runs
ORDER BY started_at DESC, created_at DESC
LIMIT {placeholder}
""",
            (int(limit),),
        )
        return list(map(dict, cur.fetchall()))
    finally:
        conn.close()
def load_collection_errors(db_target: Path | str, limit: int = 20) -> list[dict[str, Any]]:
    """Return up to *limit* ``collection_source_errors`` rows, newest first.

    Each row includes the stored ``payload_json`` text. A SQLite path that
    does not exist yet yields ``[]``.
    """
    pg_target = str(db_target).startswith(("postgresql://", "postgres://"))
    if not (pg_target or Path(db_target).exists()):
        return []

    conn = _get_connection(db_target)
    if not pg_target:
        conn.row_factory = sqlite3.Row
    try:
        placeholder = "%s" if pg_target else "?"
        cur = conn.cursor()
        cur.execute(
            f"""
SELECT run_id, ticker, source_name, error_kind, error_message, payload_json, created_at
FROM collection_source_errors
ORDER BY created_at DESC
LIMIT {placeholder}
""",
            (int(limit),),
        )
        return list(map(dict, cur.fetchall()))
    finally:
        conn.close()
def load_collection_dashboard_state(
    db_target: Path | str | None = None,
    output_json_path: Path | str | None = None,
    *,
    limit: int = 8,
) -> dict[str, Any]:
    """Assemble a read-only summary of the collection store for a dashboard.

    Args:
        db_target: SQLite file path or PostgreSQL DSN. ``None``, an empty
            value, or a SQLite file that does not exist skips all database reads.
        output_json_path: optional JSON report file. It is parsed best-effort
            into ``latest_report``.
        limit: maximum number of rows in ``runs``, ``recent_snapshots`` and
            ``recent_errors``.

    Returns:
        A dict with these keys:

        - ``db_path`` and ``output_json_path``: the inputs as strings, or ``""``.
        - ``runs``, ``recent_snapshots``, ``recent_errors``: lists of row dicts.
        - ``counts``: the row count of each table.
        - ``latest_run``: the newest run row, or ``{}``.
        - ``latest_report``: the parsed report.

        Inputs that are missing or unreadable leave the defaults in place
        (empty lists and dicts, zero counts).
    """
    # Fixed, module-controlled table names; never taken from callers.
    count_tables = ("collection_runs", "collection_snapshots", "collection_source_errors")
    db_str = str(db_target or "")
    is_pg = db_str.startswith(("postgresql://", "postgres://"))
    report = Path(output_json_path) if output_json_path else None
    state: dict[str, Any] = {
        "db_path": db_str,
        "output_json_path": str(report) if report is not None else "",
        "runs": [],
        "recent_snapshots": [],
        "recent_errors": [],
        "counts": dict.fromkeys(count_tables, 0),
        "latest_run": {},
        "latest_report": {},
    }

    # Check is_file() rather than exists(). A directory, or no path at all, is
    # then skipped up front instead of relying on read_text() to raise.
    if report is not None and report.is_file():
        try:
            state["latest_report"] = json.loads(report.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Best-effort: an unreadable file or malformed JSON leaves {} in place.
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            state["latest_report"] = {}

    if not is_pg and (not db_target or not Path(db_target).exists()):
        return state

    conn = _get_connection(db_target)
    if not is_pg:
        conn.row_factory = sqlite3.Row
    try:
        cursor = conn.cursor()
        param_char = "%s" if is_pg else "?"

        # One COUNT(*) per table, identical for both backends. sqlite3's
        # execute() returns the cursor, but psycopg2's returns None, so always
        # read the result with a separate fetchone().
        for table in count_tables:
            cursor.execute(f"SELECT COUNT(*) FROM {table}")
            state["counts"][table] = cursor.fetchone()[0]

        cursor.execute(
            """
SELECT run_id, collector_name, started_at, finished_at, status,
input_source, output_json_path, output_db_path, notes, created_at
FROM collection_runs
ORDER BY started_at DESC, created_at DESC
LIMIT 1
"""
        )
        run_row = cursor.fetchone()
        state["latest_run"] = dict(run_row) if run_row is not None else {}

        def _fetch_limited(sql: str) -> list[dict[str, Any]]:
            # Runs a list query that binds a single LIMIT parameter.
            # dict(row) needs mapping-like rows: sqlite3.Row here; for
            # PostgreSQL this depends on the cursor factory _get_connection uses.
            cursor.execute(sql, (int(limit),))
            return [dict(row) for row in cursor.fetchall()]

        state["runs"] = _fetch_limited(
            f"""
SELECT run_id, collector_name, started_at, finished_at, status,
input_source, output_json_path, output_db_path, notes, created_at
FROM collection_runs
ORDER BY started_at DESC, created_at DESC
LIMIT {param_char}
"""
        )
        state["recent_snapshots"] = _fetch_limited(
            f"""
SELECT run_id, dataset_name, ticker, name, sector, as_of_date,
source_priority, source_status, created_at
FROM collection_snapshots
ORDER BY created_at DESC
LIMIT {param_char}
"""
        )
        state["recent_errors"] = _fetch_limited(
            f"""
SELECT run_id, ticker, source_name, error_kind, error_message, created_at
FROM collection_source_errors
ORDER BY created_at DESC
LIMIT {param_char}
"""
        )
    finally:
        conn.close()
    return state