Compare commits
9 Commits
b690900cfc
...
760f914d3b
| Author | SHA1 | Date | |
|---|---|---|---|
| 760f914d3b | |||
| 8eefe9d79d | |||
| 91de16675b | |||
| 44888d6ede | |||
| 9e5fecb369 | |||
| 28f9c8c3a6 | |||
| c5a88fab66 | |||
| 7056cf2fa6 | |||
| 4ac7da8670 |
@@ -7,3 +7,7 @@ pytest>=8.0
|
|||||||
pytest-asyncio>=0.23
|
pytest-asyncio>=0.23
|
||||||
respx>=0.21
|
respx>=0.21
|
||||||
websockets>=12
|
websockets>=12
|
||||||
|
# Phase 3b dependencies (Chronos-2 + ML)
|
||||||
|
transformers>=4.40
|
||||||
|
chronos-forecasting>=1.4
|
||||||
|
# torch: typically already installed via V1 venv; if not, install with CUDA support manually
|
||||||
|
|||||||
132
signal_v2/chronos_predictor.py
Normal file
132
signal_v2/chronos_predictor.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Chronos-2 zero-shot forecaster wrapper."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
KST = ZoneInfo("Asia/Seoul")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChronosPrediction:
|
||||||
|
median: float
|
||||||
|
q10: float
|
||||||
|
q90: float
|
||||||
|
conf: float
|
||||||
|
as_of: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChronosPredictor:
|
||||||
|
"""HuggingFace Chronos-2 zero-shot forecaster."""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
|
||||||
|
# BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2)
|
||||||
|
# and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5.
|
||||||
|
import torch
|
||||||
|
try:
|
||||||
|
from chronos import BaseChronosPipeline
|
||||||
|
pipeline_cls = BaseChronosPipeline
|
||||||
|
except ImportError:
|
||||||
|
from chronos import ChronosPipeline
|
||||||
|
pipeline_cls = ChronosPipeline
|
||||||
|
|
||||||
|
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
# Always use float32 — Korean stock prices (e.g. 280,000원) exceed FP16 max (~65,504)
|
||||||
|
# causing inf in quantile output. FP32 is safe for typical price magnitudes.
|
||||||
|
dtype = torch.float32
|
||||||
|
logger.info("Loading Chronos pipeline: %s on %s (cls=%s)",
|
||||||
|
model_name, self._device, pipeline_cls.__name__)
|
||||||
|
# Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
|
||||||
|
try:
|
||||||
|
self._pipeline = pipeline_cls.from_pretrained(
|
||||||
|
model_name, device_map=self._device, dtype=dtype,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
self._pipeline = pipeline_cls.from_pretrained(
|
||||||
|
model_name, device_map=self._device, torch_dtype=dtype,
|
||||||
|
)
|
||||||
|
logger.info("Chronos pipeline loaded.")
|
||||||
|
|
||||||
|
def predict_batch(
|
||||||
|
self,
|
||||||
|
daily_ohlcv_dict: dict[str, list[dict]],
|
||||||
|
prediction_length: int = 1,
|
||||||
|
num_samples: int = 100,
|
||||||
|
) -> dict[str, ChronosPrediction]:
|
||||||
|
"""종목별 1-day return 분포 예측.
|
||||||
|
|
||||||
|
ChronosBolt / Chronos-2 등 신모델은 predict_quantiles 사용 (deterministic).
|
||||||
|
Legacy ChronosPipeline (T5) 는 sample-based predict.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
tickers = list(daily_ohlcv_dict.keys())
|
||||||
|
if not tickers:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32)
|
||||||
|
for t in tickers
|
||||||
|
]
|
||||||
|
now_iso = datetime.now(KST).isoformat()
|
||||||
|
results: dict[str, ChronosPrediction] = {}
|
||||||
|
|
||||||
|
# Modern API: predict_quantiles (ChronosBolt / Chronos-2)
|
||||||
|
if hasattr(self._pipeline, "predict_quantiles"):
|
||||||
|
quantile_levels = [0.1, 0.5, 0.9]
|
||||||
|
# ChronosBolt API: positional `inputs` (first arg). Older variants use `context`.
|
||||||
|
try:
|
||||||
|
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
||||||
|
contexts,
|
||||||
|
prediction_length=prediction_length,
|
||||||
|
quantile_levels=quantile_levels,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
||||||
|
context=contexts,
|
||||||
|
prediction_length=prediction_length,
|
||||||
|
quantile_levels=quantile_levels,
|
||||||
|
)
|
||||||
|
quantiles_np = (
|
||||||
|
quantiles_tensor.cpu().numpy()
|
||||||
|
if hasattr(quantiles_tensor, "cpu")
|
||||||
|
else np.asarray(quantiles_tensor)
|
||||||
|
)
|
||||||
|
# shape: [num_series, prediction_length, 3]
|
||||||
|
for i, ticker in enumerate(tickers):
|
||||||
|
q10_price, q50_price, q90_price = quantiles_np[i, 0, :]
|
||||||
|
last_close = daily_ohlcv_dict[ticker][-1]["close"]
|
||||||
|
median = float((q50_price - last_close) / last_close)
|
||||||
|
q10 = float((q10_price - last_close) / last_close)
|
||||||
|
q90 = float((q90_price - last_close) / last_close)
|
||||||
|
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||||
|
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||||
|
results[ticker] = ChronosPrediction(
|
||||||
|
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Legacy API: sample-based predict (ChronosPipeline T5)
|
||||||
|
forecasts = self._pipeline.predict(
|
||||||
|
context=contexts,
|
||||||
|
prediction_length=prediction_length,
|
||||||
|
num_samples=num_samples,
|
||||||
|
)
|
||||||
|
forecasts_np = forecasts.numpy() if hasattr(forecasts, "numpy") else np.asarray(forecasts)
|
||||||
|
for i, ticker in enumerate(tickers):
|
||||||
|
samples = forecasts_np[i, :, 0]
|
||||||
|
last_close = daily_ohlcv_dict[ticker][-1]["close"]
|
||||||
|
returns = (samples - last_close) / last_close
|
||||||
|
median = float(np.quantile(returns, 0.5))
|
||||||
|
q10 = float(np.quantile(returns, 0.1))
|
||||||
|
q90 = float(np.quantile(returns, 0.9))
|
||||||
|
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||||
|
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||||
|
results[ticker] = ChronosPrediction(
|
||||||
|
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||||
|
)
|
||||||
|
return results
|
||||||
@@ -34,6 +34,7 @@ class Settings:
|
|||||||
str(Path(__file__).parent.parent / "signal_v1" / "data" / "kis_token.json"))
|
str(Path(__file__).parent.parent / "signal_v1" / "data" / "kis_token.json"))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
chronos_model: str = field(default_factory=lambda: os.getenv("CHRONOS_MODEL", "amazon/chronos-2"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kis_is_virtual(self) -> bool:
|
def kis_is_virtual(self) -> bool:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
@@ -153,3 +153,41 @@ class KISClient:
|
|||||||
"current_price": current_price,
|
"current_price": current_price,
|
||||||
"as_of": datetime.now(KST).isoformat(),
|
"as_of": datetime.now(KST).isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def get_daily_ohlcv(self, ticker: str, days: int = 60) -> list[dict]:
    """KRX daily OHLCV bars (TR_ID FHKST03010100).

    Returns up to *days* bars, ascending by date, each shaped
    {"datetime", "open", "high", "low", "close", "volume"}.
    Rows with missing or non-numeric fields are skipped silently.
    """
    path = "/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
    end = datetime.now(KST)
    # Request a 2x calendar window so weekends/holidays still leave
    # enough trading days.
    begin = end - timedelta(days=days * 2)
    params = {
        "FID_COND_MRKT_DIV_CODE": "J",
        "FID_INPUT_ISCD": ticker,
        "FID_INPUT_DATE_1": begin.strftime("%Y%m%d"),
        "FID_INPUT_DATE_2": end.strftime("%Y%m%d"),
        "FID_PERIOD_DIV_CODE": "D",
        "FID_ORG_ADJ_PRC": "1",
    }
    raw = await self._request_with_retry(
        "GET", path, tr_id="FHKST03010100", params=params,
    )

    parsed: list[dict] = []
    for row in raw.get("output2", []):
        try:
            ymd = row["stck_bsop_date"]
            parsed.append({
                "datetime": f"{ymd[:4]}-{ymd[4:6]}-{ymd[6:]}",
                "open": int(row["stck_oprc"]),
                "high": int(row["stck_hgpr"]),
                "low": int(row["stck_lwpr"]),
                "close": int(row["stck_clpr"]),
                "volume": int(row["acml_vol"]),
            })
        except (KeyError, ValueError):
            # Malformed row — skip and keep the rest.
            continue
    # KIS returns newest-first; flip to ascending and keep the last `days`.
    return parsed[::-1][-days:]
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from signal_v2 import state as state_mod
|
from signal_v2 import state as state_mod
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor
|
||||||
from signal_v2.config import get_settings
|
from signal_v2.config import get_settings
|
||||||
from signal_v2.kis_client import KISClient
|
from signal_v2.kis_client import KISClient
|
||||||
from signal_v2.kis_websocket import KISWebSocket
|
from signal_v2.kis_websocket import KISWebSocket
|
||||||
@@ -24,6 +25,7 @@ class AppContext:
|
|||||||
poll_task: asyncio.Task | None = None
|
poll_task: asyncio.Task | None = None
|
||||||
kis_client: KISClient | None = None
|
kis_client: KISClient | None = None
|
||||||
kis_ws: KISWebSocket | None = None
|
kis_ws: KISWebSocket | None = None
|
||||||
|
chronos: ChronosPredictor | None = None
|
||||||
|
|
||||||
|
|
||||||
_ctx = AppContext()
|
_ctx = AppContext()
|
||||||
@@ -69,10 +71,17 @@ async def lifespan(app: FastAPI):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("KIS WebSocket startup failed — continuing without realtime asking_price")
|
logger.exception("KIS WebSocket startup failed — continuing without realtime asking_price")
|
||||||
|
|
||||||
|
# Load Chronos (heavy: ~1GB model download first time)
|
||||||
|
try:
|
||||||
|
_ctx.chronos = ChronosPredictor(model_name=settings.chronos_model)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("ChronosPredictor load failed — continuing without chronos predictions")
|
||||||
|
|
||||||
_ctx.poll_task = asyncio.create_task(
|
_ctx.poll_task = asyncio.create_task(
|
||||||
poll_loop(
|
poll_loop(
|
||||||
_ctx.client, state_mod.state, _ctx.shutdown,
|
_ctx.client, state_mod.state, _ctx.shutdown,
|
||||||
kis_client=_ctx.kis_client,
|
kis_client=_ctx.kis_client,
|
||||||
|
chronos=_ctx.chronos,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
69
signal_v2/momentum_classifier.py
Normal file
69
signal_v2/momentum_classifier.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""분봉 OHLCV → 5-level 모멘텀 분류."""
|
||||||
|
from __future__ import annotations
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
# 분류 카테고리
|
||||||
|
STRONG_UP = "strong_up"
|
||||||
|
WEAK_UP = "weak_up"
|
||||||
|
NEUTRAL = "neutral"
|
||||||
|
WEAK_DOWN = "weak_down"
|
||||||
|
STRONG_DOWN = "strong_down"
|
||||||
|
|
||||||
|
_BARS_PER_5MIN = 5
|
||||||
|
_LOOKBACK_5MIN_BARS = 5
|
||||||
|
_VOLUME_AVG_WINDOW = 12 # 60분 = 5분봉 12개
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_1min_to_5min(minute_bars: list[dict]) -> list[dict]:
|
||||||
|
"""1분봉 N개 → 5분봉 floor(N/5) 개. 시간 오름차순.
|
||||||
|
|
||||||
|
각 5분봉: open=첫 1분봉 open, high=max, low=min, close=마지막 close, volume=sum.
|
||||||
|
"""
|
||||||
|
bars_5min = []
|
||||||
|
chunks = len(minute_bars) // _BARS_PER_5MIN
|
||||||
|
for i in range(chunks):
|
||||||
|
chunk = minute_bars[i * _BARS_PER_5MIN : (i + 1) * _BARS_PER_5MIN]
|
||||||
|
bars_5min.append({
|
||||||
|
"datetime": chunk[0]["datetime"],
|
||||||
|
"open": chunk[0]["open"],
|
||||||
|
"high": max(b["high"] for b in chunk),
|
||||||
|
"low": min(b["low"] for b in chunk),
|
||||||
|
"close": chunk[-1]["close"],
|
||||||
|
"volume": sum(b["volume"] for b in chunk),
|
||||||
|
})
|
||||||
|
return bars_5min
|
||||||
|
|
||||||
|
|
||||||
|
def classify_minute_momentum(minute_bars: deque) -> str:
|
||||||
|
"""1분봉 deque → 5-level 모멘텀 분류.
|
||||||
|
|
||||||
|
Returns: STRONG_UP / WEAK_UP / NEUTRAL / WEAK_DOWN / STRONG_DOWN
|
||||||
|
"""
|
||||||
|
minute_list = list(minute_bars)
|
||||||
|
if len(minute_list) < _BARS_PER_5MIN * _LOOKBACK_5MIN_BARS:
|
||||||
|
return NEUTRAL # 데이터 부족
|
||||||
|
|
||||||
|
bars_5min = aggregate_1min_to_5min(minute_list)
|
||||||
|
if len(bars_5min) < _LOOKBACK_5MIN_BARS:
|
||||||
|
return NEUTRAL
|
||||||
|
|
||||||
|
recent = bars_5min[-_LOOKBACK_5MIN_BARS:]
|
||||||
|
up_count = sum(1 for b in recent if b["close"] > b["open"])
|
||||||
|
|
||||||
|
# 거래량 multiplier: recent 5 avg vs 60분 avg
|
||||||
|
recent_vol_avg = sum(b["volume"] for b in recent) / len(recent)
|
||||||
|
long_window = bars_5min[-_VOLUME_AVG_WINDOW:]
|
||||||
|
long_vol_avg = sum(b["volume"] for b in long_window) / len(long_window)
|
||||||
|
vol_mult = recent_vol_avg / long_vol_avg if long_vol_avg > 0 else 1.0
|
||||||
|
|
||||||
|
# 5-level 분류
|
||||||
|
if up_count == 5 and vol_mult >= 1.5:
|
||||||
|
return STRONG_UP
|
||||||
|
elif up_count >= 3 and vol_mult >= 1.0:
|
||||||
|
return WEAK_UP
|
||||||
|
elif up_count == 0 and vol_mult >= 1.5:
|
||||||
|
return STRONG_DOWN
|
||||||
|
elif up_count <= 2 and vol_mult < 1.0:
|
||||||
|
return WEAK_DOWN
|
||||||
|
else:
|
||||||
|
return NEUTRAL
|
||||||
@@ -7,7 +7,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from signal_v2.kis_client import KISClient
|
from signal_v2.kis_client import KISClient
|
||||||
from signal_v2.scheduler import (
|
from signal_v2.scheduler import (
|
||||||
KST, _is_market_day, _is_polling_window, _next_interval,
|
KST, _is_market_day, _is_polling_window, _next_interval, _is_post_close_trigger,
|
||||||
)
|
)
|
||||||
from signal_v2.state import PollState
|
from signal_v2.state import PollState
|
||||||
from signal_v2.stock_client import StockClient
|
from signal_v2.stock_client import StockClient
|
||||||
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|||||||
async def poll_loop(
|
async def poll_loop(
|
||||||
client: StockClient, state: PollState, shutdown: asyncio.Event,
|
client: StockClient, state: PollState, shutdown: asyncio.Event,
|
||||||
kis_client: KISClient | None = None,
|
kis_client: KISClient | None = None,
|
||||||
|
chronos=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""FastAPI lifespan 에서 asyncio.create_task 로 시작."""
|
"""FastAPI lifespan 에서 asyncio.create_task 로 시작."""
|
||||||
logger.info("poll_loop started")
|
logger.info("poll_loop started")
|
||||||
@@ -28,6 +29,17 @@ async def poll_loop(
|
|||||||
await _run_polling_cycle(client, state, kis_client=kis_client)
|
await _run_polling_cycle(client, state, kis_client=kis_client)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("poll cycle failed")
|
logger.exception("poll cycle failed")
|
||||||
|
# Minute momentum 갱신 (매 cycle)
|
||||||
|
try:
|
||||||
|
update_minute_momentum_for_all(state)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("minute momentum update failed")
|
||||||
|
# Post-close trigger (16:00 KST)
|
||||||
|
if _is_post_close_trigger(now) and chronos is not None and kis_client is not None:
|
||||||
|
try:
|
||||||
|
await _run_post_close_cycle(kis_client, chronos, state)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("post-close cycle failed")
|
||||||
interval = _next_interval(now)
|
interval = _next_interval(now)
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(shutdown.wait(), timeout=interval)
|
await asyncio.wait_for(shutdown.wait(), timeout=interval)
|
||||||
@@ -125,3 +137,48 @@ def _screener_tickers(state: PollState) -> list[str]:
|
|||||||
if state.screener_preview is None:
|
if state.screener_preview is None:
|
||||||
return []
|
return []
|
||||||
return [i["ticker"] for i in state.screener_preview.get("items", []) if "ticker" in i]
|
return [i["ticker"] for i in state.screener_preview.get("items", []) if "ticker" in i]
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_post_close_cycle(kis_client, chronos, state) -> None:
    """Once after the 16:00 KST close: fetch daily bars + run Chronos.

    Fetches 60 daily bars for every portfolio/screener ticker, stores
    sufficiently long histories (>= 30 bars) in state.daily_ohlcv, then
    feeds them to the Chronos predictor and caches per-ticker predictions
    in state.chronos_predictions / state.last_updated. Fetch failures are
    counted in state.fetch_errors instead of aborting the cycle.
    """
    tickers = list(set(_portfolio_tickers(state)) | set(_screener_tickers(state)))
    if not tickers:
        return

    daily_results = await asyncio.gather(*[
        kis_client.get_daily_ohlcv(t, days=60) for t in tickers
    ], return_exceptions=True)
    daily_dict = {}
    for ticker, result in zip(tickers, daily_results):
        if isinstance(result, list) and len(result) >= 30:
            # Keep only histories long enough for a meaningful forecast.
            daily_dict[ticker] = result
            state.daily_ohlcv[ticker] = result
        elif isinstance(result, Exception):
            state.fetch_errors[f"daily_ohlcv/{ticker}"] = (
                state.fetch_errors.get(f"daily_ohlcv/{ticker}", 0) + 1
            )

    if daily_dict and chronos is not None:
        try:
            # BUGFIX: predict_batch is synchronous, CPU/GPU-heavy model
            # inference — run it in a worker thread so the event loop
            # (websocket, polling tasks) is not blocked for its duration.
            predictions = await asyncio.to_thread(chronos.predict_batch, daily_dict)
        except Exception:
            logger.exception("chronos predict_batch failed")
            return
        for ticker, pred in predictions.items():
            state.chronos_predictions[ticker] = {
                "median": pred.median,
                "q10": pred.q10,
                "q90": pred.q90,
                "conf": pred.conf,
                "as_of": pred.as_of,
            }
            state.last_updated[f"chronos/{ticker}"] = pred.as_of
|
||||||
|
|
||||||
|
|
||||||
|
def update_minute_momentum_for_all(state) -> None:
    """Refresh the momentum label of every tracked ticker.

    Called after each minute-bar polling cycle; writes the classification
    into state.minute_momentum and stamps state.last_updated.
    """
    # NOTE(review): lazy import — presumably avoids a module-level import
    # cycle with the worker; confirm before hoisting to the top of the file.
    from signal_v2.momentum_classifier import classify_minute_momentum

    stamp = datetime.now(KST).isoformat()
    for symbol, bars in state.minute_bars.items():
        state.minute_momentum[symbol] = classify_minute_momentum(bars)
        state.last_updated[f"momentum/{symbol}"] = stamp
|
||||||
|
|||||||
@@ -76,6 +76,14 @@ def _seconds_until_nxt_or_market_open(now: datetime) -> float:
|
|||||||
return 86400.0
|
return 86400.0
|
||||||
|
|
||||||
|
|
||||||
|
def _is_post_close_trigger(now: datetime) -> bool:
    """True during 16:00-16:01 KST on trading days (post-close cycle trigger)."""
    return _is_market_day(now) and time(16, 0) <= now.time() < time(16, 1)
|
||||||
|
|
||||||
|
|
||||||
def _seconds_until_next_market_open(now: datetime) -> float:
|
def _seconds_until_next_market_open(now: datetime) -> float:
|
||||||
"""다음 영업일의 07:00 KST 까지 초수 (휴장일/주말용)."""
|
"""다음 영업일의 07:00 KST 까지 초수 (휴장일/주말용)."""
|
||||||
candidate = now.replace(hour=7, minute=0, second=0, microsecond=0)
|
candidate = now.replace(hour=7, minute=0, second=0, microsecond=0)
|
||||||
|
|||||||
@@ -8,9 +8,12 @@ class PollState:
|
|||||||
portfolio: dict | None = None
|
portfolio: dict | None = None
|
||||||
news_sentiment: dict | None = None
|
news_sentiment: dict | None = None
|
||||||
screener_preview: dict | None = None
|
screener_preview: dict | None = None
|
||||||
# Phase 3a additions
|
|
||||||
minute_bars: dict[str, deque] = field(default_factory=dict)
|
minute_bars: dict[str, deque] = field(default_factory=dict)
|
||||||
asking_price: dict[str, dict] = field(default_factory=dict)
|
asking_price: dict[str, dict] = field(default_factory=dict)
|
||||||
|
# Phase 3b additions
|
||||||
|
daily_ohlcv: dict[str, list[dict]] = field(default_factory=dict)
|
||||||
|
chronos_predictions: dict[str, dict] = field(default_factory=dict)
|
||||||
|
minute_momentum: dict[str, str] = field(default_factory=dict)
|
||||||
last_updated: dict[str, str] = field(default_factory=dict)
|
last_updated: dict[str, str] = field(default_factory=dict)
|
||||||
fetch_errors: dict[str, int] = field(default_factory=dict)
|
fetch_errors: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|||||||
92
signal_v2/tests/test_chronos_predictor.py
Normal file
92
signal_v2/tests/test_chronos_predictor.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Tests for ChronosPredictor (model mock)."""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pipeline():
|
||||||
|
"""Mock BaseChronosPipeline.from_pretrained returning a mock pipeline object."""
|
||||||
|
with patch("chronos.BaseChronosPipeline") as cls:
|
||||||
|
cls.__name__ = "BaseChronosPipeline"
|
||||||
|
instance = MagicMock()
|
||||||
|
# ChronosBolt API: predict_quantiles returns (quantiles_tensor, mean_tensor)
|
||||||
|
# Modern (predict_quantiles) branch will be used since hasattr(MagicMock, "predict_quantiles") is True.
|
||||||
|
cls.from_pretrained.return_value = instance
|
||||||
|
yield instance
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_torch_cpu():
|
||||||
|
with patch("torch.cuda.is_available", return_value=False):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def _daily_ohlcv(close_seq):
|
||||||
|
return [{"datetime": f"2026-05-{i+1:02d}", "open": c, "high": c, "low": c,
|
||||||
|
"close": c, "volume": 1000} for i, c in enumerate(close_seq)]
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_quantiles_tensor(q10_price: float, q50_price: float, q90_price: float):
|
||||||
|
"""Helper: build predict_quantiles return tensor shape [1, 1, 3]."""
|
||||||
|
import torch
|
||||||
|
return torch.tensor([[[q10_price, q50_price, q90_price]]], dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
|
||||||
|
"""mock predict_quantiles → dict[ticker, ChronosPrediction]. last_close=100, q50=102 → median≈+2%."""
|
||||||
|
quantiles = _mk_quantiles_tensor(101.5, 102.0, 102.5) # narrow around 102
|
||||||
|
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||||
|
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor, ChronosPrediction
|
||||||
|
predictor = ChronosPredictor(model_name="mock-model")
|
||||||
|
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||||
|
result = predictor.predict_batch(daily)
|
||||||
|
assert "005930" in result
|
||||||
|
pred = result["005930"]
|
||||||
|
assert isinstance(pred, ChronosPrediction)
|
||||||
|
assert abs(pred.median - 0.02) < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
|
||||||
|
"""좁은 distribution (q90-q10 작음, median 0 아님) → conf ≈ 1."""
|
||||||
|
# last_close=100, q10=101.99, q50=102.00, q90=102.01
|
||||||
|
# returns: q10=0.0199, q50=0.02, q90=0.0201
|
||||||
|
# spread = (0.0201 - 0.0199) / max(0.02, 0.001) = 0.0002/0.02 = 0.01 → conf = 1 - 0.005 = 0.995
|
||||||
|
quantiles = _mk_quantiles_tensor(101.99, 102.0, 102.01)
|
||||||
|
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||||
|
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor
|
||||||
|
predictor = ChronosPredictor(model_name="mock-model")
|
||||||
|
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||||
|
result = predictor.predict_batch(daily)
|
||||||
|
assert result["005930"].conf > 0.8
|
||||||
|
|
||||||
|
|
||||||
|
def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
|
||||||
|
"""넓은 distribution → conf ≈ 0."""
|
||||||
|
# last_close=100, q10=70, q50=100, q90=130
|
||||||
|
# returns: q10=-0.3, q50=0.0, q90=0.3
|
||||||
|
# spread = (0.3 - (-0.3)) / max(0.0, 0.001) = 0.6 / 0.001 = 600 → conf = max(0, 1 - 300) = 0
|
||||||
|
quantiles = _mk_quantiles_tensor(70.0, 100.0, 130.0)
|
||||||
|
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||||
|
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor
|
||||||
|
predictor = ChronosPredictor(model_name="mock-model")
|
||||||
|
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||||
|
result = predictor.predict_batch(daily)
|
||||||
|
assert result["005930"].conf < 0.3
|
||||||
|
|
||||||
|
|
||||||
|
def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu):
|
||||||
|
"""price 예측 → last_close 대비 return 변환. last_close=100, q50=110 → return ≈ +10%."""
|
||||||
|
quantiles = _mk_quantiles_tensor(109.0, 110.0, 111.0)
|
||||||
|
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||||
|
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor
|
||||||
|
predictor = ChronosPredictor(model_name="mock-model")
|
||||||
|
# last close = 100
|
||||||
|
daily = {"005930": _daily_ohlcv(list(range(41, 101)))} # last = 100
|
||||||
|
result = predictor.predict_batch(daily)
|
||||||
|
assert abs(result["005930"].median - 0.10) < 0.001
|
||||||
@@ -126,3 +126,36 @@ async def test_get_asking_price_computes_bid_ratio(kis_client_factory):
|
|||||||
assert "as_of" in data
|
assert "as_of" in data
|
||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
async def test_get_daily_ohlcv_returns_60_bars(kis_client_factory):
    """KIS daily endpoint returns 60 ascending bars after parsing."""
    # Build 60 KIS-format daily bars (descending dates as KIS does)
    sample_output2 = []
    for i in range(60):
        # Generate a fake date 60 days ago, descending
        day = 60 - i
        sample_output2.append({
            # Synthetic YYYYMMDD: day 1..60 mapped onto 30-day "months"
            # 01/02 — not real calendar dates, but strictly descending,
            # which is all the parser relies on.
            "stck_bsop_date": f"2026{(((day-1)//30)+1):02d}{(((day-1)%30)+1):02d}",
            "stck_oprc": "78000", "stck_hgpr": "78500",
            "stck_lwpr": "77800", "stck_clpr": str(78000 + i),
            "acml_vol": "12345",
        })

    # Virtual-trading (VTS) host, matching the test client factory's base URL.
    respx.get(
        "https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
    ).mock(return_value=httpx.Response(200, json={"output2": sample_output2}))

    client = kis_client_factory()
    try:
        bars = await client.get_daily_ohlcv("005930", days=60)
        # KIS returns descending; client reverses to ascending
        assert len(bars) == 60
        # Ascending order: first item has smaller datetime than last
        assert bars[0]["datetime"] < bars[-1]["datetime"]
        assert isinstance(bars[0]["open"], int)
        assert isinstance(bars[0]["close"], int)
        assert "datetime" in bars[0]
    finally:
        await client.close()
|
||||||
|
|||||||
92
signal_v2/tests/test_momentum_classifier.py
Normal file
92
signal_v2/tests/test_momentum_classifier.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Tests for minute momentum classifier."""
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from signal_v2.momentum_classifier import (
|
||||||
|
aggregate_1min_to_5min, classify_minute_momentum,
|
||||||
|
STRONG_UP, WEAK_UP, NEUTRAL, WEAK_DOWN, STRONG_DOWN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _bar(open_, high, low, close, volume):
|
||||||
|
return {
|
||||||
|
"datetime": "2026-05-18T09:00:00+09:00",
|
||||||
|
"open": open_, "high": high, "low": low, "close": close, "volume": volume,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chunks(num_chunks_up: int, num_chunks_total: int, base_vol: int = 1000):
|
||||||
|
"""num_chunks_total 개의 5-bar 청크. num_chunks_up 청크는 양봉, 나머지는 음봉.
|
||||||
|
각 청크는 5개 1분봉. 거래량 = base_vol per bar.
|
||||||
|
"""
|
||||||
|
bars = []
|
||||||
|
for i in range(num_chunks_total):
|
||||||
|
is_up = i < num_chunks_up
|
||||||
|
o, c = (100, 110) if is_up else (110, 100)
|
||||||
|
for j in range(5):
|
||||||
|
bars.append(_bar(o, max(o, c) + 5, min(o, c) - 5, c, base_vol))
|
||||||
|
return bars
|
||||||
|
|
||||||
|
|
||||||
|
def test_strong_up_5_consecutive_green_with_high_volume():
|
||||||
|
"""직전 5개 5분봉 모두 양봉 + 거래량 1.5x → STRONG_UP."""
|
||||||
|
# 60분 (12 5분봉) 데이터: 7 normal + 5 high-vol up
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
recent = _make_chunks(num_chunks_up=5, num_chunks_total=5, base_vol=2500)
|
||||||
|
minute_bars = deque(older + recent, maxlen=60)
|
||||||
|
assert classify_minute_momentum(minute_bars) == STRONG_UP
|
||||||
|
|
||||||
|
|
||||||
|
def test_weak_up_3of5_green_normal_volume():
|
||||||
|
"""직전 5개 5분봉 중 3-4개 양봉 + 거래량 ≥ 1.0x → WEAK_UP."""
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
# 5 chunks: 3 up + 2 down, normal vol
|
||||||
|
recent_up = _make_chunks(num_chunks_up=3, num_chunks_total=3, base_vol=1000)
|
||||||
|
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=2, base_vol=1000)
|
||||||
|
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||||
|
assert classify_minute_momentum(minute_bars) == WEAK_UP
|
||||||
|
|
||||||
|
|
||||||
|
def test_neutral_mixed():
|
||||||
|
"""up_count=2, vol normal → NEUTRAL (rule 미해당)."""
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
recent_up = _make_chunks(num_chunks_up=2, num_chunks_total=2, base_vol=1000)
|
||||||
|
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=3, base_vol=1000)
|
||||||
|
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||||
|
# up_count=2, vol_mult=1.0 → 어느 분기 조건도 만족 안 함 → NEUTRAL
|
||||||
|
assert classify_minute_momentum(minute_bars) == NEUTRAL
|
||||||
|
|
||||||
|
|
||||||
|
def test_weak_down_low_green_low_volume():
|
||||||
|
"""up_count <= 2 + vol < 1.0 → WEAK_DOWN."""
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
recent_up = _make_chunks(num_chunks_up=1, num_chunks_total=1, base_vol=500)
|
||||||
|
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=4, base_vol=500)
|
||||||
|
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||||
|
# recent 5 chunks avg vol = 500, long 12 avg ≈ (7*1000 + 5*500) / 12 ≈ 791 → vol_mult ≈ 0.63
|
||||||
|
assert classify_minute_momentum(minute_bars) == WEAK_DOWN
|
||||||
|
|
||||||
|
|
||||||
|
def test_strong_down_5_consecutive_red_high_volume():
|
||||||
|
"""직전 5개 5분봉 모두 음봉 + 거래량 1.5x → STRONG_DOWN."""
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
recent = _make_chunks(num_chunks_up=0, num_chunks_total=5, base_vol=2500)
|
||||||
|
minute_bars = deque(older + recent, maxlen=60)
|
||||||
|
assert classify_minute_momentum(minute_bars) == STRONG_DOWN
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_1min_to_5min_correctness():
|
||||||
|
"""5 1분봉 → 1개 5분봉 — open/close/high/low/volume 정확."""
|
||||||
|
bars = [
|
||||||
|
_bar(100, 105, 99, 102, 1000),
|
||||||
|
_bar(102, 108, 101, 107, 1500),
|
||||||
|
_bar(107, 110, 105, 106, 800),
|
||||||
|
_bar(106, 109, 104, 108, 1200),
|
||||||
|
_bar(108, 112, 107, 111, 900),
|
||||||
|
]
|
||||||
|
result = aggregate_1min_to_5min(bars)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["open"] == 100 # 첫 bar
|
||||||
|
assert result[0]["close"] == 111 # 마지막 bar
|
||||||
|
assert result[0]["high"] == 112 # max
|
||||||
|
assert result[0]["low"] == 99 # min
|
||||||
|
assert result[0]["volume"] == 5400 # sum
|
||||||
@@ -53,3 +53,45 @@ def test_websocket_message_updates_state_asking_price():
|
|||||||
"current_price": 78500, "as_of": "2026-05-18T10:00:00+09:00"})
|
"current_price": 78500, "as_of": "2026-05-18T10:00:00+09:00"})
|
||||||
assert state.asking_price["005930"]["bid_total"] == 1000
|
assert state.asking_price["005930"]["bid_total"] == 1000
|
||||||
assert "asking_price/005930" in state.last_updated
|
assert "asking_price/005930" in state.last_updated
|
||||||
|
|
||||||
|
|
||||||
|
async def test_post_close_cycle_updates_chronos_predictions():
    """Mock KIS + mock Chronos -> state.chronos_predictions and state.daily_ohlcv are updated."""
    from unittest.mock import AsyncMock, MagicMock
    from signal_v2.pull_worker import _run_post_close_cycle
    from signal_v2.chronos_predictor import ChronosPrediction
    from signal_v2.state import PollState

    state = PollState()
    # One portfolio ticker + one screener ticker; the cycle unions both sets.
    state.portfolio = {"holdings": [{"ticker": "005930"}]}
    state.screener_preview = {"items": [{"ticker": "000660"}]}

    kis_mock = MagicMock()
    # 60 ascending daily bars per ticker (>= the 30-bar minimum the cycle keeps).
    daily_005930 = [{"datetime": f"2026-05-{i+1:02d}", "open": 100, "high": 105,
                     "low": 95, "close": 100 + i, "volume": 1000} for i in range(60)]
    daily_000660 = [{"datetime": f"2026-05-{i+1:02d}", "open": 200, "high": 210,
                     "low": 190, "close": 200 + i, "volume": 2000} for i in range(60)]
    # _run_post_close_cycle iterates tickers and calls get_daily_ohlcv per ticker.
    # Order depends on set() so use side_effect mapping if possible, otherwise list.
    async def fake_daily(ticker, days=60):
        if ticker == "005930":
            return daily_005930
        if ticker == "000660":
            return daily_000660
        return []
    kis_mock.get_daily_ohlcv = AsyncMock(side_effect=fake_daily)

    chronos_mock = MagicMock()
    # ChronosPrediction(median, q10, q90, conf, as_of) — positional order.
    chronos_mock.predict_batch = MagicMock(return_value={
        "005930": ChronosPrediction(0.02, -0.01, 0.04, 0.85, "2026-05-18T16:00:00+09:00"),
        "000660": ChronosPrediction(0.03, -0.02, 0.06, 0.75, "2026-05-18T16:00:00+09:00"),
    })

    await _run_post_close_cycle(kis_mock, chronos_mock, state)

    assert "005930" in state.chronos_predictions
    assert "000660" in state.chronos_predictions
    assert state.chronos_predictions["005930"]["median"] == 0.02
    assert state.chronos_predictions["005930"]["conf"] == 0.85
    assert "005930" in state.daily_ohlcv
    assert "chronos/005930" in state.last_updated
|
||||||
|
|||||||
Reference in New Issue
Block a user