feat(ai_trade): NAS Redis heartbeat (trader market_open/closed)
- ai_trade/heartbeat.py: build_trader_payload() + heartbeat_loop() 자체 미니 헬퍼 (Windows 호스트 실행이라 _shared import 경로 달라 독립 구현, 계약은 동일) - ai_trade/main.py: lifespan에 hb_task spawn + shutdown 시 cancel state_fn = scheduler._is_market_day & _is_polling_window(KST now) 조합 signals = len(state.signals) 실시간 주입 - requirements.txt: redis>=5.0 추가 - ai_trade/tests/test_heartbeat.py: build_trader_payload 3케이스 TDD 검증 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_019LV86jBozkNhSFXJA412fq
This commit is contained in:
57
ai_trade/heartbeat.py
Normal file
57
ai_trade/heartbeat.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""ai_trade heartbeat — NAS Redis로 worker:ai_trade:heartbeat SET.
|
||||||
|
|
||||||
|
Global Constraints 계약 1: kind=trader, state=market_open|market_closed.
|
||||||
|
ai_trade는 Windows 호스트 실행이라 _shared import 경로가 달라 자체 미니 헬퍼로 둔다.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import datetime as dt
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||||
|
KEY = "worker:ai_trade:heartbeat"
|
||||||
|
INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", "15"))
|
||||||
|
TTL = int(os.getenv("HEARTBEAT_TTL", "45"))
|
||||||
|
|
||||||
|
|
||||||
|
def build_trader_payload(state: str, signals: int = 0) -> str:
|
||||||
|
"""JSON 문자열 반환. state: 'market_open' | 'market_closed'."""
|
||||||
|
return json.dumps({
|
||||||
|
"name": "ai_trade",
|
||||||
|
"kind": "trader",
|
||||||
|
"state": state,
|
||||||
|
"ts": dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||||
|
"last_job_at": None,
|
||||||
|
"jobs_done": signals,
|
||||||
|
"jobs_failed": 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def heartbeat_loop(state_fn) -> None:
|
||||||
|
"""Redis에 HEARTBEAT_INTERVAL마다 SET EX TTL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_fn: () -> (state: str, signals: int). 호출자가 폴링 윈도우 판정 주입.
|
||||||
|
"""
|
||||||
|
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
state, signals = state_fn()
|
||||||
|
payload = build_trader_payload(state, signals)
|
||||||
|
await redis.set(KEY, payload, ex=TTL)
|
||||||
|
logger.debug("ai_trade heartbeat sent: state=%s signals=%d", state, signals)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("ai_trade heartbeat 실패 — 다음 주기에 재시도")
|
||||||
|
await asyncio.sleep(INTERVAL)
|
||||||
|
finally:
|
||||||
|
await redis.aclose()
|
||||||
@@ -3,9 +3,12 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from ai_trade import heartbeat as _hb
|
||||||
from ai_trade import state as state_mod
|
from ai_trade import state as state_mod
|
||||||
from ai_trade.chronos_predictor import ChronosPredictor
|
from ai_trade.chronos_predictor import ChronosPredictor
|
||||||
from ai_trade.config import get_settings
|
from ai_trade.config import get_settings
|
||||||
@@ -13,8 +16,11 @@ from ai_trade.kis_client import KISClient
|
|||||||
from ai_trade.kis_websocket import KISWebSocket
|
from ai_trade.kis_websocket import KISWebSocket
|
||||||
from ai_trade.pull_worker import poll_loop, make_asking_price_callback
|
from ai_trade.pull_worker import poll_loop, make_asking_price_callback
|
||||||
from ai_trade.rate_limit import SignalDedup
|
from ai_trade.rate_limit import SignalDedup
|
||||||
|
from ai_trade.scheduler import _is_polling_window, _is_market_day
|
||||||
from ai_trade.stock_client import StockClient
|
from ai_trade.stock_client import StockClient
|
||||||
|
|
||||||
|
_KST = ZoneInfo("Asia/Seoul")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -23,6 +29,7 @@ class AppContext:
|
|||||||
dedup: SignalDedup | None = None
|
dedup: SignalDedup | None = None
|
||||||
shutdown: asyncio.Event | None = None
|
shutdown: asyncio.Event | None = None
|
||||||
poll_task: asyncio.Task | None = None
|
poll_task: asyncio.Task | None = None
|
||||||
|
hb_task: asyncio.Task | None = None
|
||||||
kis_client: KISClient | None = None
|
kis_client: KISClient | None = None
|
||||||
kis_ws: KISWebSocket | None = None
|
kis_ws: KISWebSocket | None = None
|
||||||
chronos: ChronosPredictor | None = None
|
chronos: ChronosPredictor | None = None
|
||||||
@@ -87,9 +94,27 @@ async def lifespan(app: FastAPI):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _trader_state() -> tuple[str, int]:
|
||||||
|
"""scheduler의 실제 폴링 윈도우 판정으로 market_open/market_closed 결정."""
|
||||||
|
now = datetime.now(_KST)
|
||||||
|
is_open = _is_market_day(now) and _is_polling_window(now)
|
||||||
|
state_str = "market_open" if is_open else "market_closed"
|
||||||
|
signals = len(state_mod.state.signals)
|
||||||
|
return state_str, signals
|
||||||
|
|
||||||
|
_ctx.hb_task = asyncio.create_task(_hb.heartbeat_loop(_trader_state))
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown
|
# Shutdown heartbeat task
|
||||||
|
if _ctx.hb_task is not None:
|
||||||
|
_ctx.hb_task.cancel()
|
||||||
|
try:
|
||||||
|
await _ctx.hb_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Shutdown poll task
|
||||||
if _ctx.shutdown is not None:
|
if _ctx.shutdown is not None:
|
||||||
_ctx.shutdown.set()
|
_ctx.shutdown.set()
|
||||||
if _ctx.poll_task is not None:
|
if _ctx.poll_task is not None:
|
||||||
|
|||||||
38
ai_trade/tests/test_heartbeat.py
Normal file
38
ai_trade/tests/test_heartbeat.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Tests for ai_trade heartbeat payload builder."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_trader_payload_market_open():
|
||||||
|
from ai_trade.heartbeat import build_trader_payload
|
||||||
|
|
||||||
|
p = json.loads(build_trader_payload("market_open", signals=2))
|
||||||
|
assert p["name"] == "ai_trade"
|
||||||
|
assert p["kind"] == "trader"
|
||||||
|
assert p["state"] == "market_open"
|
||||||
|
assert p["ts"].endswith("Z")
|
||||||
|
assert p["jobs_done"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_trader_payload_market_closed():
|
||||||
|
from ai_trade.heartbeat import build_trader_payload
|
||||||
|
|
||||||
|
p = json.loads(build_trader_payload("market_closed"))
|
||||||
|
assert p["name"] == "ai_trade"
|
||||||
|
assert p["kind"] == "trader"
|
||||||
|
assert p["state"] == "market_closed"
|
||||||
|
assert p["jobs_done"] == 0
|
||||||
|
assert p["jobs_failed"] == 0
|
||||||
|
assert p["last_job_at"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_trader_payload_ts_format():
|
||||||
|
"""ts 필드가 ISO 8601 UTC 형식 (YYYY-MM-DDTHH:MM:SSZ)인지 확인."""
|
||||||
|
from ai_trade.heartbeat import build_trader_payload
|
||||||
|
import re
|
||||||
|
|
||||||
|
p = json.loads(build_trader_payload("market_open"))
|
||||||
|
assert re.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", p["ts"]), (
|
||||||
|
f"ts={p['ts']!r} does not match expected UTC format"
|
||||||
|
)
|
||||||
@@ -7,6 +7,7 @@ pytest>=8.0
|
|||||||
pytest-asyncio>=0.23
|
pytest-asyncio>=0.23
|
||||||
respx>=0.21
|
respx>=0.21
|
||||||
websockets>=12
|
websockets>=12
|
||||||
|
redis>=5.0
|
||||||
# Phase 3b dependencies (Chronos-2 + ML)
|
# Phase 3b dependencies (Chronos-2 + ML)
|
||||||
transformers>=4.40
|
transformers>=4.40
|
||||||
chronos-forecasting>=1.4
|
chronos-forecasting>=1.4
|
||||||
|
|||||||
Reference in New Issue
Block a user