diff --git a/signal_v2/pull_worker.py b/signal_v2/pull_worker.py index b369867..ed57636 100644 --- a/signal_v2/pull_worker.py +++ b/signal_v2/pull_worker.py @@ -2,8 +2,10 @@ from __future__ import annotations import asyncio import logging +from collections import deque from datetime import datetime +from signal_v2.kis_client import KISClient from signal_v2.scheduler import ( KST, _is_market_day, _is_polling_window, _next_interval, ) @@ -14,7 +16,8 @@ logger = logging.getLogger(__name__) async def poll_loop( - client: StockClient, state: PollState, shutdown: asyncio.Event + client: StockClient, state: PollState, shutdown: asyncio.Event, + kis_client: KISClient | None = None, ) -> None: """FastAPI lifespan 에서 asyncio.create_task 로 시작.""" logger.info("poll_loop started") @@ -22,7 +25,7 @@ async def poll_loop( now = datetime.now(KST) if _is_market_day(now) and _is_polling_window(now): try: - await _run_polling_cycle(client, state) + await _run_polling_cycle(client, state, kis_client=kis_client) except Exception: logger.exception("poll cycle failed") interval = _next_interval(now) @@ -34,8 +37,11 @@ async def poll_loop( logger.info("poll_loop ended") -async def _run_polling_cycle(client: StockClient, state: PollState) -> None: - """3 endpoint 병렬 fetch + state 갱신.""" +async def _run_polling_cycle( + client: StockClient, state: PollState, + kis_client: KISClient | None = None, +) -> None: + """기존 3 endpoint (stock) + KIS 분봉 fetch.""" portfolio, sentiment, screener = await asyncio.gather( client.get_portfolio(), client.get_news_sentiment(), @@ -56,3 +62,66 @@ async def _run_polling_cycle(client: StockClient, state: PollState) -> None: else: state.fetch_errors[name] = state.fetch_errors.get(name, 0) + 1 logger.warning("fetch %s failed: %r", name, result) + + # KIS 분봉 + 호가 (kis_client 주어졌을 때만) + if kis_client is not None: + try: + await _run_kis_minute_cycle(kis_client, state) + except Exception: + logger.exception("kis minute cycle failed") + + +async def _run_kis_minute_cycle(kis_client: KISClient, state: PollState) -> None: + """KIS 분봉 + 호가 fetch + state 갱신. + + - 분봉: portfolio + screener Top-N union 종목 모두 + - 호가 (REST): screener-only 종목 (portfolio 는 WebSocket 으로 들어옴) + """ + portfolio_tickers = _portfolio_tickers(state) + screener_tickers = _screener_tickers(state) + all_tickers = list(set(portfolio_tickers) | set(screener_tickers)) + + # 분봉 fetch (병렬) + minute_results = await asyncio.gather(*[ + kis_client.get_minute_ohlcv(t) for t in all_tickers + ], return_exceptions=True) + now_iso = datetime.now(KST).isoformat() + for ticker, result in zip(all_tickers, minute_results): + if isinstance(result, list): + buf = state.minute_bars.setdefault(ticker, deque(maxlen=60)) + buf.extend(result) + state.last_updated[f"minute_bars/{ticker}"] = now_iso + else: + state.fetch_errors[f"minute_bars/{ticker}"] = ( + state.fetch_errors.get(f"minute_bars/{ticker}", 0) + 1 + ) + + # 호가 fetch (REST) — screener-only + screener_only = list(set(screener_tickers) - set(portfolio_tickers)) + asking_results = await asyncio.gather(*[ + kis_client.get_asking_price(t) for t in screener_only + ], return_exceptions=True) + for ticker, result in zip(screener_only, asking_results): + if isinstance(result, dict): + state.asking_price[ticker] = result + state.last_updated[f"asking_price/{ticker}"] = now_iso + + +def make_asking_price_callback(state: PollState): + """KIS WebSocket on_asking_price callback factory.""" + def _cb(ticker: str, data: dict) -> None: + state.asking_price[ticker] = data + state.last_updated[f"asking_price/{ticker}"] = datetime.now(KST).isoformat() + return _cb + + +def _portfolio_tickers(state: PollState) -> list[str]: + if state.portfolio is None: + return [] + return [h["ticker"] for h in state.portfolio.get("holdings", []) if "ticker" in h] + + +def _screener_tickers(state: PollState) -> list[str]: + if state.screener_preview is None: + return [] + return [i["ticker"] for i in state.screener_preview.get("items", []) if "ticker" in i] diff --git a/signal_v2/tests/test_pull_worker.py b/signal_v2/tests/test_pull_worker.py new file mode 100644 index 0000000..81e8265 --- /dev/null +++ b/signal_v2/tests/test_pull_worker.py @@ -0,0 +1,55 @@ +"""Tests for pull_worker (Phase 3a additions).""" +from collections import deque +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from signal_v2.state import PollState + + +async def test_minute_polling_cycle_updates_state_minute_bars(): + """KIS REST mock 의 분봉 데이터가 state.minute_bars[ticker] deque 에 들어간다.""" + from signal_v2.pull_worker import _run_kis_minute_cycle + + state = PollState() + state.portfolio = {"holdings": [{"ticker": "005930"}, {"ticker": "000660"}]} + state.screener_preview = { + "items": [{"ticker": "005930"}, {"ticker": "035720"}] + } + + kis_client_mock = MagicMock() + kis_client_mock.get_minute_ohlcv = AsyncMock(side_effect=[ + [{"datetime": "2026-05-18T09:00:00+09:00", "open": 78000, + "high": 78500, "low": 77900, "close": 78300, "volume": 12345}], + [{"datetime": "2026-05-18T09:00:00+09:00", "open": 180000, + "high": 181000, "low": 179800, "close": 180500, "volume": 5000}], + [{"datetime": "2026-05-18T09:00:00+09:00", "open": 51000, + "high": 51200, "low": 50800, "close": 51100, "volume": 8000}], + ]) + kis_client_mock.get_asking_price = AsyncMock(return_value={ + "bid_total": 600, "ask_total": 400, "bid_ratio": 0.6, + "current_price": 51100, "as_of": "2026-05-18T09:00:30+09:00", + }) + + await _run_kis_minute_cycle(kis_client_mock, state) + + # 3 unique tickers (005930, 000660, 035720) + assert "005930" in state.minute_bars + assert "000660" in state.minute_bars + assert "035720" in state.minute_bars + assert len(state.minute_bars["005930"]) >= 1 + # asking_price 만 screener-only ticker (035720) 에 들어가야 함 + # (portfolio = 005930, 000660 는 WebSocket 으로 들어옴) + assert "035720" in state.asking_price + + +def test_websocket_message_updates_state_asking_price(): + """WebSocket callback factory → state.asking_price 갱신.""" + from signal_v2.pull_worker import make_asking_price_callback + + state = PollState() + cb = make_asking_price_callback(state) + cb("005930", {"bid_total": 1000, "ask_total": 800, "bid_ratio": 0.555, + "current_price": 78500, "as_of": "2026-05-18T10:00:00+09:00"}) + assert state.asking_price["005930"]["bid_total"] == 1000 + assert "asking_price/005930" in state.last_updated