diff --git a/stock-lab/app/db.py b/stock-lab/app/db.py index 189e54c..4183a10 100644 --- a/stock-lab/app/db.py +++ b/stock-lab/app/db.py @@ -5,11 +5,14 @@ from typing import List, Dict, Any, Optional from app.screener.schema import ensure_screener_schema -DB_PATH = "/app/data/stock.db" +DB_PATH = os.environ.get("STOCK_DB_PATH", "/app/data/stock.db") def _conn() -> sqlite3.Connection: - os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) - conn = sqlite3.connect(DB_PATH) + db_path = os.environ.get("STOCK_DB_PATH", DB_PATH) + parent = os.path.dirname(db_path) + if parent: + os.makedirs(parent, exist_ok=True) + conn = sqlite3.connect(db_path) conn.row_factory = sqlite3.Row return conn diff --git a/stock-lab/app/main.py b/stock-lab/app/main.py index 76cedcc..f66ffd6 100644 --- a/stock-lab/app/main.py +++ b/stock-lab/app/main.py @@ -27,6 +27,10 @@ from .ai_summarizer import summarize_news, OllamaError app = FastAPI() +# Screener 라우터 등록 +from app.screener.router import router as screener_router +app.include_router(screener_router) + # CORS 설정 (프론트엔드 접근 허용) _cors_origins = os.getenv("CORS_ALLOW_ORIGINS", "http://localhost:3007,http://localhost:8080").split(",") app.add_middleware( diff --git a/stock-lab/app/screener/router.py b/stock-lab/app/screener/router.py new file mode 100644 index 0000000..6618c25 --- /dev/null +++ b/stock-lab/app/screener/router.py @@ -0,0 +1,85 @@ +"""FastAPI router for /api/stock/screener/*""" + +from __future__ import annotations + +import datetime as dt +import json +import os +import sqlite3 + +from fastapi import APIRouter, HTTPException + +from . import schemas +from .registry import NODE_REGISTRY, GATE_REGISTRY + + +router = APIRouter(prefix="/api/stock/screener") + + +def _db_path() -> str: + return os.environ.get("STOCK_DB_PATH", "/app/data/stock.db") + + +def _conn() -> sqlite3.Connection: + return sqlite3.connect(_db_path()) + + +# ---------- /nodes ---------- + +@router.get("/nodes", response_model=schemas.NodesResponse) +def get_nodes(): + score_nodes = [ + schemas.NodeMeta( + name=cls.name, label=cls.label, + default_params=cls.default_params, param_schema=cls.param_schema, + ) + for cls in NODE_REGISTRY.values() + ] + gate_nodes = [ + schemas.NodeMeta( + name=cls.name, label=cls.label, + default_params=cls.default_params, param_schema=cls.param_schema, + ) + for cls in GATE_REGISTRY.values() + ] + return schemas.NodesResponse(score_nodes=score_nodes, gate_nodes=gate_nodes) + + +# ---------- /settings ---------- + +@router.get("/settings", response_model=schemas.SettingsResponse) +def get_settings(): + with _conn() as c: + row = c.execute( + "SELECT weights_json, node_params_json, gate_params_json, " + "top_n, rr_ratio, atr_window, atr_stop_mult, updated_at " + "FROM screener_settings WHERE id=1" + ).fetchone() + if row is None: + raise HTTPException(503, "settings not initialized") + return schemas.SettingsResponse( + weights=json.loads(row[0]), + node_params=json.loads(row[1]), + gate_params=json.loads(row[2]), + top_n=row[3], rr_ratio=row[4], atr_window=row[5], atr_stop_mult=row[6], + updated_at=row[7], + ) + + +@router.put("/settings", response_model=schemas.SettingsResponse) +def put_settings(body: schemas.SettingsBody): + now = dt.datetime.utcnow().isoformat() + with _conn() as c: + c.execute( + """UPDATE screener_settings SET + weights_json=?, node_params_json=?, gate_params_json=?, + top_n=?, rr_ratio=?, atr_window=?, atr_stop_mult=?, updated_at=? + WHERE id=1""", + ( + json.dumps(body.weights), json.dumps(body.node_params), + json.dumps(body.gate_params), + body.top_n, body.rr_ratio, body.atr_window, body.atr_stop_mult, now, + ), + ) + c.commit() + return schemas.SettingsResponse(**body.model_dump(), updated_at=now) diff --git a/stock-lab/app/test_screener_router.py b/stock-lab/app/test_screener_router.py new file mode 100644 index 0000000..301db2b --- /dev/null +++ b/stock-lab/app/test_screener_router.py @@ -0,0 +1,64 @@ +import os +import sqlite3 +import pytest +from fastapi.testclient import TestClient + +from app.screener.schema import ensure_screener_schema + + +@pytest.fixture(autouse=True) +def isolated_db(tmp_path, monkeypatch): + db_path = tmp_path / "screener_router.db" + c = sqlite3.connect(db_path) + ensure_screener_schema(c) + c.close() + monkeypatch.setenv("STOCK_DB_PATH", str(db_path)) + + +@pytest.fixture +def client(): + from app.main import app + return TestClient(app) + + +def test_get_nodes_lists_7_score_and_1_gate(client): + r = client.get("/api/stock/screener/nodes") + assert r.status_code == 200 + body = r.json() + assert len(body["score_nodes"]) == 7 + assert len(body["gate_nodes"]) == 1 + assert {n["name"] for n in body["score_nodes"]} == { + "foreign_buy", "volume_surge", "momentum", + "high52w", "rs_rating", "ma_alignment", "vcp_lite", + } + + +def test_settings_get_returns_defaults(client): + r = client.get("/api/stock/screener/settings") + assert r.status_code == 200 + body = r.json() + assert body["weights"]["foreign_buy"] == 1.0 + assert body["top_n"] == 20 + + +def test_settings_put_then_get_round_trip(client): + new_settings = { + "weights": {"foreign_buy": 2.5, "momentum": 1.0, "volume_surge": 1.0, + "high52w": 1.2, "rs_rating": 1.2, "ma_alignment": 1.0, "vcp_lite": 0.8}, + "node_params": {"foreign_buy": {"window_days": 7}}, + "gate_params": {"min_market_cap_won": 100_000_000_000, + "min_avg_value_won": 500_000_000, + "min_listed_days": 60, + "skip_managed": True, "skip_preferred": True, "skip_spac": True, + "skip_halted_days": 3}, + "top_n": 30, + "rr_ratio": 2.5, + "atr_window": 14, + "atr_stop_mult": 2.0, + } + r = client.put("/api/stock/screener/settings", json=new_settings) + assert r.status_code == 200 + r2 = client.get("/api/stock/screener/settings") + body = r2.json() + assert body["weights"]["foreign_buy"] == 2.5 + assert body["top_n"] == 30