Files
ai-trade/signal_v2/pull_worker.py
gahusb 9e5fecb369 feat(signal_v2-phase3b): post-close cycle + minute momentum update
scheduler._is_post_close_trigger: 16:00 KST ±1min detection (market day).
pull_worker:
- _run_post_close_cycle: daily fetch (60일) + chronos batch predict →
  state.chronos_predictions + state.daily_ohlcv.
- update_minute_momentum_for_all: 매 cycle 마다 state.minute_momentum 갱신.
- poll_loop signature 확장 (chronos optional).

45 tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-16 18:04:32 +09:00

185 lines
6.9 KiB
Python

"""Polling loop — async cron + state update."""
from __future__ import annotations
import asyncio
import logging
from collections import deque
from datetime import datetime
from signal_v2.kis_client import KISClient
from signal_v2.scheduler import (
KST, _is_market_day, _is_polling_window, _next_interval, _is_post_close_trigger,
)
from signal_v2.state import PollState
from signal_v2.stock_client import StockClient
logger = logging.getLogger(__name__)
async def poll_loop(
    client: StockClient, state: PollState, shutdown: asyncio.Event,
    kis_client: KISClient | None = None,
    chronos=None,
) -> None:
    """Run the polling loop until ``shutdown`` is set.

    Meant to be started with ``asyncio.create_task`` from the FastAPI
    lifespan handler.

    Each iteration:

    1. On a market day inside the polling window, runs one polling cycle
       and then refreshes minute momentum for every ticker in ``state``.
    2. When the post-close trigger fires and both ``chronos`` and
       ``kis_client`` were supplied, runs the post-close cycle, at most
       once per calendar date (of ``now`` in ``KST``).
    3. Waits for the interval returned by ``_next_interval``, or returns
       early if ``shutdown`` is set while waiting.

    Any exception raised by a stage is logged and the loop keeps running.

    Args:
        client: Stock API client handed to the polling cycle.
        state: Shared poll state, mutated in place by each stage.
        shutdown: Event that ends the loop once set.
        kis_client: Optional KIS client; forwarded to the polling cycle and
            required for the post-close cycle.
        chronos: Optional predictor object; required for the post-close
            cycle.
    """
    logger.info("poll_loop started")
    # Date of the most recent post-close run that finished without raising.
    # The trigger predicate may stay true for several consecutive iterations
    # (presumably it matches a short window around the close; confirm in
    # scheduler), so without this guard the daily fetch + batch prediction
    # could be repeated on the same day.
    post_close_done_on = None
    while not shutdown.is_set():
        now = datetime.now(KST)
        if _is_market_day(now) and _is_polling_window(now):
            try:
                await _run_polling_cycle(client, state, kis_client=kis_client)
            except Exception:
                logger.exception("poll cycle failed")
            # Refresh minute momentum after every polling cycle.
            try:
                update_minute_momentum_for_all(state)
            except Exception:
                logger.exception("minute momentum update failed")

        # Post-close trigger.
        today = now.date()
        if (
            _is_post_close_trigger(now)
            and chronos is not None
            and kis_client is not None
            and post_close_done_on != today
        ):
            try:
                await _run_post_close_cycle(kis_client, chronos, state)
            except Exception:
                # Leave the guard unset so a later tick that still matches
                # the trigger can retry.
                logger.exception("post-close cycle failed")
            else:
                post_close_done_on = today

        interval = _next_interval(now)
        try:
            # Sleep for ``interval`` seconds, but wake immediately on shutdown.
            await asyncio.wait_for(shutdown.wait(), timeout=interval)
            break
        except asyncio.TimeoutError:
            continue
    logger.info("poll_loop ended")
async def _run_polling_cycle(
    client: StockClient, state: PollState,
    kis_client: KISClient | None = None,
) -> None:
    """Fetch the three stock endpoints concurrently, then the KIS minute data.

    For each of ``portfolio``, ``news_sentiment`` and ``screener_preview``:
    a ``dict`` response is stored on ``state`` under that attribute name,
    its ``last_updated`` entry is stamped and its ``fetch_errors`` counter is
    reset to 0; anything else (including a returned exception) increments
    the counter and is logged as a warning.

    When ``kis_client`` is given, the KIS minute cycle runs afterwards; an
    exception from it is logged and not propagated.
    """
    endpoint_names = ("portfolio", "news_sentiment", "screener_preview")
    responses = await asyncio.gather(
        client.get_portfolio(),
        client.get_news_sentiment(),
        client.run_screener_preview(),
        return_exceptions=True,
    )

    fetched_at = datetime.now(KST).isoformat()
    for endpoint, payload in zip(endpoint_names, responses):
        if not isinstance(payload, dict):
            state.fetch_errors[endpoint] = state.fetch_errors.get(endpoint, 0) + 1
            logger.warning("fetch %s failed: %r", endpoint, payload)
            continue
        setattr(state, endpoint, payload)
        state.last_updated[endpoint] = fetched_at
        state.fetch_errors[endpoint] = 0

    # KIS minute bars + asking prices, only when a KIS client was supplied.
    if kis_client is None:
        return
    try:
        await _run_kis_minute_cycle(kis_client, state)
    except Exception:
        logger.exception("kis minute cycle failed")
async def _run_kis_minute_cycle(kis_client: KISClient, state: PollState) -> None:
    """Fetch KIS minute bars and asking prices and store them in ``state``.

    - Minute bars: fetched for the union of portfolio and screener tickers
      and appended to a per-ticker ``deque`` capped at 60 entries.
    - Asking price (REST): fetched only for tickers that are in the screener
      but not in the portfolio (portfolio tickers get theirs via WebSocket).

    Both fetch groups run concurrently with ``return_exceptions=True``. A
    result of the expected type is stored and stamped in
    ``state.last_updated``; any other result (including a returned
    exception) increments ``state.fetch_errors`` under
    ``"minute_bars/<ticker>"`` or ``"asking_price/<ticker>"`` and is logged,
    so a failing ticker is visible instead of being dropped silently.
    """
    portfolio_tickers = set(_portfolio_tickers(state))
    screener_tickers = set(_screener_tickers(state))
    all_tickers = list(portfolio_tickers | screener_tickers)

    # Minute bars (concurrent).
    minute_results = await asyncio.gather(*[
        kis_client.get_minute_ohlcv(t) for t in all_tickers
    ], return_exceptions=True)
    now_iso = datetime.now(KST).isoformat()
    for ticker, result in zip(all_tickers, minute_results):
        key = f"minute_bars/{ticker}"
        if isinstance(result, list):
            buf = state.minute_bars.get(ticker)
            if buf is None:
                buf = state.minute_bars.setdefault(ticker, deque(maxlen=60))
            # NOTE(review): every returned bar is appended as-is. If
            # get_minute_ohlcv returns overlapping windows on consecutive
            # cycles, bars will be duplicated in the buffer. Confirm
            # against KISClient.
            buf.extend(result)
            state.last_updated[key] = now_iso
        else:
            state.fetch_errors[key] = state.fetch_errors.get(key, 0) + 1
            logger.warning("fetch %s failed: %r", key, result)

    # Asking prices over REST for screener-only tickers.
    screener_only = list(screener_tickers - portfolio_tickers)
    asking_results = await asyncio.gather(*[
        kis_client.get_asking_price(t) for t in screener_only
    ], return_exceptions=True)
    for ticker, result in zip(screener_only, asking_results):
        key = f"asking_price/{ticker}"
        if isinstance(result, dict):
            state.asking_price[ticker] = result
            state.last_updated[key] = now_iso
        else:
            # Previously these failures left no trace; count and log them the
            # same way as the other fetches.
            state.fetch_errors[key] = state.fetch_errors.get(key, 0) + 1
            logger.warning("fetch %s failed: %r", key, result)
def make_asking_price_callback(state: PollState):
    """Build the ``on_asking_price`` callback for the KIS WebSocket.

    The returned function takes ``(ticker, data)``, stores ``data`` in
    ``state.asking_price[ticker]`` and records the current ``KST`` time under
    ``state.last_updated["asking_price/<ticker>"]``.
    """
    def _on_asking_price(ticker: str, data: dict) -> None:
        stamp_key = f"asking_price/{ticker}"
        state.asking_price[ticker] = data
        state.last_updated[stamp_key] = datetime.now(KST).isoformat()

    return _on_asking_price
def _portfolio_tickers(state: PollState) -> list[str]:
    """Return the ``"ticker"`` of each entry in ``state.portfolio["holdings"]``.

    Entries without a ``"ticker"`` key are skipped. Returns an empty list
    when ``state.portfolio`` is ``None`` or has no ``"holdings"`` key.
    """
    portfolio = state.portfolio
    if portfolio is None:
        return []

    tickers = []
    for holding in portfolio.get("holdings", []):
        if "ticker" in holding:
            tickers.append(holding["ticker"])
    return tickers
def _screener_tickers(state: PollState) -> list[str]:
    """Return the ``"ticker"`` of each entry in ``state.screener_preview["items"]``.

    Entries without a ``"ticker"`` key are skipped. Returns an empty list
    when ``state.screener_preview`` is ``None`` or has no ``"items"`` key.
    """
    preview = state.screener_preview
    entries = [] if preview is None else preview.get("items", [])
    return [entry["ticker"] for entry in entries if "ticker" in entry]
async def _run_post_close_cycle(kis_client, chronos, state) -> None:
    """Post-close job: refresh daily bars and batch predictions in ``state``.

    Steps:

    1. Collect the union of portfolio and screener tickers; return at once
       if there are none.
    2. Fetch daily OHLCV for each ticker concurrently. A list result with
       at least ``min_bars`` entries is stored in ``state.daily_ohlcv`` and
       queued for prediction; a returned exception increments
       ``state.fetch_errors["daily_ohlcv/<ticker>"]``. Other results
       (e.g. a shorter list) are skipped.
    3. Call ``chronos.predict_batch`` on the queued series and copy each
       prediction's ``median``/``q10``/``q90``/``conf``/``as_of`` into
       ``state.chronos_predictions``, stamping ``state.last_updated`` with
       ``as_of``. A failure of ``predict_batch`` is logged and ends the job.

    Args:
        kis_client: Object providing ``get_daily_ohlcv(ticker, days=...)``.
        chronos: Object providing a synchronous ``predict_batch(mapping)``;
            prediction is skipped when ``None``.
        state: Shared poll state, mutated in place.
    """
    history_days = 60  # number of days requested per ticker
    min_bars = 30      # series shorter than this are not used

    tickers = list(set(_portfolio_tickers(state)) | set(_screener_tickers(state)))
    if not tickers:
        return

    daily_results = await asyncio.gather(*[
        kis_client.get_daily_ohlcv(t, days=history_days) for t in tickers
    ], return_exceptions=True)

    daily_dict = {}
    for ticker, result in zip(tickers, daily_results):
        if isinstance(result, list) and len(result) >= min_bars:
            daily_dict[ticker] = result
            state.daily_ohlcv[ticker] = result
        elif isinstance(result, Exception):
            key = f"daily_ohlcv/{ticker}"
            state.fetch_errors[key] = state.fetch_errors.get(key, 0) + 1

    if not daily_dict or chronos is None:
        return

    try:
        # predict_batch is a plain synchronous call; running it on a worker
        # thread keeps the event loop (and every other task on it) responsive
        # while the batch is being computed.
        predictions = await asyncio.to_thread(chronos.predict_batch, daily_dict)
    except Exception:
        logger.exception("chronos predict_batch failed")
        return

    for ticker, pred in predictions.items():
        state.chronos_predictions[ticker] = {
            "median": pred.median,
            "q10": pred.q10,
            "q90": pred.q90,
            "conf": pred.conf,
            "as_of": pred.as_of,
        }
        state.last_updated[f"chronos/{ticker}"] = pred.as_of
def update_minute_momentum_for_all(state) -> None:
    """Recompute minute momentum for every ticker in ``state.minute_bars``.

    Called after each minute-bar cycle. For each ticker, the result of
    ``classify_minute_momentum`` on its bar buffer is stored in
    ``state.minute_momentum`` and ``state.last_updated["momentum/<ticker>"]``
    is set to a single timestamp taken at the start of the call.
    """
    # Imported at call time rather than at module import time.
    from signal_v2.momentum_classifier import classify_minute_momentum

    computed_at = datetime.now(KST).isoformat()
    momentum = state.minute_momentum
    stamps = state.last_updated
    for symbol, bar_window in state.minute_bars.items():
        momentum[symbol] = classify_minute_momentum(bar_window)
        stamps[f"momentum/{symbol}"] = computed_at