diff --git a/ai_trade/heartbeat.py b/ai_trade/heartbeat.py new file mode 100644 index 0000000..a2d7241 --- /dev/null +++ b/ai_trade/heartbeat.py @@ -0,0 +1,57 @@ +"""ai_trade heartbeat — NAS Redis로 worker:ai_trade:heartbeat SET. + +Global Constraints 계약 1: kind=trader, state=market_open|market_closed. +ai_trade는 Windows 호스트 실행이라 _shared import 경로가 달라 자체 미니 헬퍼로 둔다. +""" +from __future__ import annotations + +import asyncio +import datetime as dt +import json +import logging +import os + +import redis.asyncio as aioredis + +logger = logging.getLogger(__name__) + +REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379") +KEY = "worker:ai_trade:heartbeat" +INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", "15")) +TTL = int(os.getenv("HEARTBEAT_TTL", "45")) + + +def build_trader_payload(state: str, signals: int = 0) -> str: + """JSON 문자열 반환. state: 'market_open' | 'market_closed'.""" + return json.dumps({ + "name": "ai_trade", + "kind": "trader", + "state": state, + "ts": dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"), + "last_job_at": None, + "jobs_done": signals, + "jobs_failed": 0, + }) + + +async def heartbeat_loop(state_fn) -> None: + """Redis에 HEARTBEAT_INTERVAL마다 SET EX TTL. + + Args: + state_fn: () -> (state: str, signals: int). 호출자가 폴링 윈도우 판정 주입. + """ + redis = aioredis.from_url(REDIS_URL, decode_responses=False) + try: + while True: + try: + state, signals = state_fn() + payload = build_trader_payload(state, signals) + await redis.set(KEY, payload, ex=TTL) + logger.debug("ai_trade heartbeat sent: state=%s signals=%d", state, signals) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("ai_trade heartbeat 실패 — 다음 주기에 재시도") + await asyncio.sleep(INTERVAL) + finally: + await redis.aclose() diff --git a/ai_trade/main.py b/ai_trade/main.py index 0d78344..1ed7e5d 100644 --- a/ai_trade/main.py +++ b/ai_trade/main.py @@ -3,9 +3,12 @@ from __future__ import annotations import asyncio import logging from contextlib import asynccontextmanager +from datetime import datetime +from zoneinfo import ZoneInfo from fastapi import FastAPI +from ai_trade import heartbeat as _hb from ai_trade import state as state_mod from ai_trade.chronos_predictor import ChronosPredictor from ai_trade.config import get_settings @@ -13,8 +16,11 @@ from ai_trade.kis_client import KISClient from ai_trade.kis_websocket import KISWebSocket from ai_trade.pull_worker import poll_loop, make_asking_price_callback from ai_trade.rate_limit import SignalDedup +from ai_trade.scheduler import _is_polling_window, _is_market_day from ai_trade.stock_client import StockClient +_KST = ZoneInfo("Asia/Seoul") + logger = logging.getLogger(__name__) @@ -23,6 +29,7 @@ class AppContext: dedup: SignalDedup | None = None shutdown: asyncio.Event | None = None poll_task: asyncio.Task | None = None + hb_task: asyncio.Task | None = None kis_client: KISClient | None = None kis_ws: KISWebSocket | None = None chronos: ChronosPredictor | None = None @@ -87,9 +94,27 @@ async def lifespan(app: FastAPI): ) ) + def _trader_state() -> tuple[str, int]: + """scheduler의 실제 폴링 윈도우 판정으로 market_open/market_closed 결정.""" + now = datetime.now(_KST) + is_open = _is_market_day(now) and _is_polling_window(now) + state_str = "market_open" if is_open else "market_closed" + signals = len(state_mod.state.signals) + return state_str, signals + + _ctx.hb_task = asyncio.create_task(_hb.heartbeat_loop(_trader_state)) + yield - # Shutdown + # Shutdown heartbeat task + if _ctx.hb_task is not None: + _ctx.hb_task.cancel() + try: + await _ctx.hb_task + except asyncio.CancelledError: + pass + + # Shutdown poll task if _ctx.shutdown is not None: _ctx.shutdown.set() if _ctx.poll_task is not None: diff --git a/ai_trade/tests/test_heartbeat.py b/ai_trade/tests/test_heartbeat.py new file mode 100644 index 0000000..f1cd97e --- /dev/null +++ b/ai_trade/tests/test_heartbeat.py @@ -0,0 +1,38 @@ +"""Tests for ai_trade heartbeat payload builder.""" +import json + +import pytest + + +def test_trader_payload_market_open(): + from ai_trade.heartbeat import build_trader_payload + + p = json.loads(build_trader_payload("market_open", signals=2)) + assert p["name"] == "ai_trade" + assert p["kind"] == "trader" + assert p["state"] == "market_open" + assert p["ts"].endswith("Z") + assert p["jobs_done"] == 2 + + +def test_trader_payload_market_closed(): + from ai_trade.heartbeat import build_trader_payload + + p = json.loads(build_trader_payload("market_closed")) + assert p["name"] == "ai_trade" + assert p["kind"] == "trader" + assert p["state"] == "market_closed" + assert p["jobs_done"] == 0 + assert p["jobs_failed"] == 0 + assert p["last_job_at"] is None + + +def test_trader_payload_ts_format(): + """ts 필드가 ISO 8601 UTC 형식 (YYYY-MM-DDTHH:MM:SSZ)인지 확인.""" + from ai_trade.heartbeat import build_trader_payload + import re + + p = json.loads(build_trader_payload("market_open")) + assert re.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", p["ts"]), ( + f"ts={p['ts']!r} does not match expected UTC format" + ) diff --git a/requirements.txt b/requirements.txt index abadae1..5ca5336 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ pytest>=8.0 pytest-asyncio>=0.23 respx>=0.21 websockets>=12 +redis>=5.0 # Phase 3b dependencies (Chronos-2 + ML) transformers>=4.40 chronos-forecasting>=1.4