Compare commits
4 Commits
4fb3d12244
...
54fca07d43
| Author | SHA1 | Date | |
|---|---|---|---|
| 54fca07d43 | |||
| 574b5712c3 | |||
| 2ff31b2e76 | |||
| d1b9ff570d |
57
ai_trade/heartbeat.py
Normal file
57
ai_trade/heartbeat.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""ai_trade heartbeat — NAS Redis로 worker:ai_trade:heartbeat SET.
|
||||
|
||||
Global Constraints 계약 1: kind=trader, state=market_open|market_closed.
|
||||
ai_trade는 Windows 호스트 실행이라 _shared import 경로가 달라 자체 미니 헬퍼로 둔다.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
KEY = "worker:ai_trade:heartbeat"
|
||||
INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", "15"))
|
||||
TTL = int(os.getenv("HEARTBEAT_TTL", "45"))
|
||||
|
||||
|
||||
def build_trader_payload(state: str, signals: int = 0) -> str:
    """Serialize one ai_trade heartbeat record to a JSON string.

    Args:
        state: Value stored under ``"state"``; expected to be
            ``'market_open'`` or ``'market_closed'`` (not validated here).
        signals: Count reported under ``"jobs_done"``.

    Returns:
        JSON text with the fixed name/kind, the given state, a UTC ``ts``
        stamp taken at call time, and the job counters.
    """
    stamp = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    record = dict(
        name="ai_trade",
        kind="trader",
        state=state,
        ts=stamp,
        last_job_at=None,
        jobs_done=signals,
        jobs_failed=0,
    )
    return json.dumps(record)
|
||||
|
||||
|
||||
async def heartbeat_loop(state_fn) -> None:
    """Publish the ai_trade heartbeat to Redis, once every INTERVAL seconds.

    Each cycle calls ``state_fn``, builds the trader payload and writes it to
    KEY with an expiry of TTL seconds. Any failure other than cancellation is
    logged and the loop continues with the next cycle. The Redis client is
    closed when the loop exits.

    Args:
        state_fn: Zero-argument callable returning ``(state, signals)``; the
            caller injects the polling-window decision through it.
    """
    client = aioredis.from_url(REDIS_URL, decode_responses=False)
    try:
        while True:
            try:
                state, signals = state_fn()
                await client.set(KEY, build_trader_payload(state, signals), ex=TTL)
                logger.debug("ai_trade heartbeat sent: state=%s signals=%d", state, signals)
            except asyncio.CancelledError:
                # Cancellation must escape so the owning task can shut down.
                raise
            except Exception:
                logger.exception("ai_trade heartbeat 실패 — 다음 주기에 재시도")
            await asyncio.sleep(INTERVAL)
    finally:
        await client.aclose()
|
||||
@@ -3,9 +3,12 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from ai_trade import heartbeat as _hb
|
||||
from ai_trade import state as state_mod
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
from ai_trade.config import get_settings
|
||||
@@ -13,8 +16,11 @@ from ai_trade.kis_client import KISClient
|
||||
from ai_trade.kis_websocket import KISWebSocket
|
||||
from ai_trade.pull_worker import poll_loop, make_asking_price_callback
|
||||
from ai_trade.rate_limit import SignalDedup
|
||||
from ai_trade.scheduler import _is_polling_window, _is_market_day
|
||||
from ai_trade.stock_client import StockClient
|
||||
|
||||
_KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -23,6 +29,7 @@ class AppContext:
|
||||
dedup: SignalDedup | None = None
|
||||
shutdown: asyncio.Event | None = None
|
||||
poll_task: asyncio.Task | None = None
|
||||
hb_task: asyncio.Task | None = None
|
||||
kis_client: KISClient | None = None
|
||||
kis_ws: KISWebSocket | None = None
|
||||
chronos: ChronosPredictor | None = None
|
||||
@@ -87,9 +94,27 @@ async def lifespan(app: FastAPI):
|
||||
)
|
||||
)
|
||||
|
||||
def _trader_state() -> tuple[str, int]:
    """Return the heartbeat state label and signal count for the current KST time.

    The label is "market_open" only when the scheduler helpers report the
    current time as both a market day and inside the polling window;
    otherwise it is "market_closed".
    """
    kst_now = datetime.now(_KST)
    if _is_market_day(kst_now) and _is_polling_window(kst_now):
        label = "market_open"
    else:
        label = "market_closed"
    return label, len(state_mod.state.signals)
|
||||
|
||||
_ctx.hb_task = asyncio.create_task(_hb.heartbeat_loop(_trader_state))
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
# Shutdown heartbeat task
|
||||
if _ctx.hb_task is not None:
|
||||
_ctx.hb_task.cancel()
|
||||
try:
|
||||
await _ctx.hb_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Shutdown poll task
|
||||
if _ctx.shutdown is not None:
|
||||
_ctx.shutdown.set()
|
||||
if _ctx.poll_task is not None:
|
||||
|
||||
38
ai_trade/tests/test_heartbeat.py
Normal file
38
ai_trade/tests/test_heartbeat.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Tests for ai_trade heartbeat payload builder."""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_trader_payload_market_open():
    """An open-market payload carries the trader identity, the signal count and a Z-suffixed ts."""
    from ai_trade.heartbeat import build_trader_payload

    decoded = json.loads(build_trader_payload("market_open", signals=2))
    wanted = {
        "name": "ai_trade",
        "kind": "trader",
        "state": "market_open",
        "jobs_done": 2,
    }
    for field, value in wanted.items():
        assert decoded[field] == value
    assert decoded["ts"].endswith("Z")
|
||||
|
||||
|
||||
def test_trader_payload_market_closed():
    """With default arguments the closed-market payload reports zeroed counters and no last job."""
    from ai_trade.heartbeat import build_trader_payload

    decoded = json.loads(build_trader_payload("market_closed"))
    wanted = {
        "name": "ai_trade",
        "kind": "trader",
        "state": "market_closed",
        "jobs_done": 0,
        "jobs_failed": 0,
    }
    for field, value in wanted.items():
        assert decoded[field] == value
    assert decoded["last_job_at"] is None
|
||||
|
||||
|
||||
def test_trader_payload_ts_format():
    """Check that the ts field is ISO 8601 UTC text (YYYY-MM-DDTHH:MM:SSZ)."""
    from ai_trade.heartbeat import build_trader_payload
    import re

    decoded = json.loads(build_trader_payload("market_open"))
    stamp_ok = re.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", decoded["ts"])
    assert stamp_ok, (
        f"ts={decoded['ts']!r} does not match expected UTC format"
    )
|
||||
@@ -7,6 +7,7 @@ pytest>=8.0
|
||||
pytest-asyncio>=0.23
|
||||
respx>=0.21
|
||||
websockets>=12
|
||||
redis>=5.0
|
||||
# Phase 3b dependencies (Chronos-2 + ML)
|
||||
transformers>=4.40
|
||||
chronos-forecasting>=1.4
|
||||
|
||||
55
services/_shared/heartbeat.py
Normal file
55
services/_shared/heartbeat.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""분산 워커 heartbeat — worker:<name>:heartbeat SET (TTL). Global Constraints 계약 1."""
|
||||
from __future__ import annotations
|
||||
import asyncio, datetime as dt, json, logging, os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", "15"))
|
||||
DEFAULT_TTL = int(os.getenv("HEARTBEAT_TTL", "45"))
|
||||
|
||||
|
||||
class WorkerStats:
    """Mutable counters that worker_loop updates and heartbeat_loop reads."""

    def __init__(self) -> None:
        # Starts idle with zeroed counters.
        self.busy: bool = False
        self.jobs_done: int = 0
        self.jobs_failed: int = 0
        # ISO timestamp string, or None until one is recorded.
        self.last_job_at: str | None = None
|
||||
|
||||
|
||||
def utc_now_iso() -> str:
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
    current = dt.datetime.now(tz=dt.timezone.utc)
    return current.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def build_payload(name: str, kind: str, state: str, stats: WorkerStats, extra: dict | None = None) -> str:
    """Serialize a worker heartbeat record to a JSON string.

    Args:
        name: Worker name stored under ``"name"``.
        kind: Worker kind stored under ``"kind"``.
        state: State label stored under ``"state"``.
        stats: Source of ``last_job_at``, ``jobs_done`` and ``jobs_failed``.
        extra: Optional additional fields merged on top of the base record;
            a key that matches a base field replaces that field's value.

    Returns:
        JSON text of the record, with ``ts`` taken from ``utc_now_iso()``.
    """
    record = dict(
        name=name,
        kind=kind,
        state=state,
        ts=utc_now_iso(),
        last_job_at=stats.last_job_at,
        jobs_done=stats.jobs_done,
        jobs_failed=stats.jobs_failed,
    )
    if not extra:
        return json.dumps(record)
    record.update(extra)
    return json.dumps(record)
|
||||
|
||||
|
||||
async def render_state(redis, stats: WorkerStats, paused_key: str = "queue:paused") -> str:
    """Derive the worker state label from the pause flag and the busy marker.

    Args:
        redis: Async client exposing ``get(key)``.
        stats: Object whose ``busy`` attribute tells whether a job is running.
        paused_key: Redis key holding the pause flag.

    Returns:
        ``"paused"`` when the flag value is ``1`` (this wins over busy),
        otherwise ``"busy"`` or ``"idle"`` according to ``stats.busy``.
    """
    flag = await redis.get(paused_key)
    # A client created with decode_responses=True returns str instead of bytes;
    # normalise so the pause flag is recognised with either configuration.
    if isinstance(flag, bytes):
        flag = flag.decode("utf-8", errors="replace")
    if flag == "1":
        return "paused"
    return "busy" if stats.busy else "idle"
|
||||
|
||||
|
||||
async def heartbeat_loop(redis, name, kind, stats, *, interval=DEFAULT_INTERVAL,
                         ttl=DEFAULT_TTL, paused_key="queue:paused", state_fn=None):
    """Write ``worker:<name>:heartbeat`` to Redis forever, once per ``interval`` seconds.

    Args:
        redis: Async client used for the SET (and by the default state lookup).
        name: Worker name; part of the key and of the payload.
        kind: Worker kind placed in the payload.
        stats: Counter object passed to the payload builder and state lookup.
        interval: Seconds slept between cycles.
        ttl: Expiry in seconds applied to the heartbeat key.
        paused_key: Pause-flag key forwarded to ``render_state``.
        state_fn: Optional awaitable callable ``(redis, stats) -> (state, extra)``
            that replaces the default ``render_state`` lookup.

    Failures other than cancellation are logged and the loop keeps running.
    """
    hb_key = f"worker:{name}:heartbeat"
    logger.info("heartbeat 시작 name=%s ttl=%ds", name, ttl)
    while True:
        try:
            if state_fn is None:
                state = await render_state(redis, stats, paused_key)
                extra = None
            else:
                state, extra = await state_fn(redis, stats)
            body = build_payload(name, kind, state, stats, extra)
            await redis.set(hb_key, body, ex=ttl)
        except asyncio.CancelledError:
            # Let cancellation through so the owning task can stop.
            raise
        except Exception:
            logger.exception("heartbeat 발신 실패 name=%s", name)
        await asyncio.sleep(interval)
|
||||
46
services/_shared/tests/test_heartbeat.py
Normal file
46
services/_shared/tests/test_heartbeat.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Tests for _shared.heartbeat — Task A1."""
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Make `_shared` importable (same pattern as test_reliable_queue.py)
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
||||
|
||||
from _shared.heartbeat import WorkerStats, build_payload, render_state
|
||||
|
||||
|
||||
def test_build_payload_has_contract_fields():
    """The payload exposes name, kind, state, the counters from stats and a Z-suffixed ts."""
    stats = WorkerStats()
    stats.jobs_done = 3
    stats.last_job_at = "2026-06-29T00:00:00Z"

    decoded = json.loads(build_payload("image-render", "render", "idle", stats))

    wanted = {
        "name": "image-render",
        "kind": "render",
        "state": "idle",
        "jobs_done": 3,
        "last_job_at": "2026-06-29T00:00:00Z",
    }
    for field, value in wanted.items():
        assert decoded[field] == value
    assert decoded["ts"].endswith("Z")
|
||||
|
||||
|
||||
def test_build_payload_merges_extra():
    """Fields passed through ``extra`` appear in the serialized payload."""
    raw = build_payload(
        "task-watcher", "watcher", "free", WorkerStats(), extra={"mode": "free"},
    )
    decoded = json.loads(raw)
    assert decoded["mode"] == "free"
|
||||
|
||||
|
||||
class _FakeRedis:
    """Minimal Redis stand-in whose ``get`` reports a fixed pause flag."""

    def __init__(self, paused):
        self._paused = paused

    async def get(self, key):
        # Any key yields the same answer: b"1" when paused, else None.
        if self._paused:
            return b"1"
        return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_render_state_paused_overrides_busy():
    """A set pause flag yields "paused" even while the worker is busy."""
    stats = WorkerStats()
    stats.busy = True
    observed = await render_state(_FakeRedis(paused=True), stats)
    assert observed == "paused"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_render_state_busy_then_idle():
    """Without a pause flag the state follows ``stats.busy``."""
    stats = WorkerStats()
    unpaused = _FakeRedis(paused=False)

    stats.busy = True
    assert await render_state(unpaused, stats) == "busy"

    stats.busy = False
    assert await render_state(_FakeRedis(paused=False), stats) == "idle"
|
||||
@@ -3,11 +3,14 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from fastapi import FastAPI
|
||||
|
||||
import worker
|
||||
from _shared.heartbeat import heartbeat_loop
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -16,15 +19,19 @@ logger = logging.getLogger(__name__)
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the queue worker and the heartbeat publisher for the app's lifetime.

    Startup launches ``worker.worker_loop()`` and a ``heartbeat_loop`` task
    reporting as "image-render" / kind "render" through a dedicated Redis
    client. Shutdown cancels both tasks, waits for each, and always closes
    that client.
    """
    worker_task = asyncio.create_task(worker.worker_loop())
    hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
    hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "image-render", "render", worker.stats))
    logger.info("image-render lifespan 시작")
    try:
        yield
    finally:
        background = (worker_task, hb_task)
        # Request cancellation of every task before awaiting any of them, so
        # one task's failure cannot leave the other running.
        for task in background:
            task.cancel()
        try:
            for task in background:
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # A task that crashed earlier re-raises its error here;
                    # log it instead of aborting the rest of the shutdown.
                    logger.exception("image-render background task failed before shutdown")
        finally:
            # Close the heartbeat connection no matter how the tasks ended.
            await hb_redis.aclose()
            logger.info("image-render lifespan 종료")
|
||||
|
||||
|
||||
|
||||
@@ -67,3 +67,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
|
||||
assert handled is False
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
|
||||
# ----- heartbeat stats 카운터 -----
|
||||
|
||||
class _OneJobQueue:
    """Queue double that offers one flux_generation job until it has been acked."""

    def __init__(self):
        self.acked = False

    async def dequeue(self, timeout=5):
        if self.acked:
            return None
        job = {"job_type": "flux_generation", "task_id": "t1", "params": {}}
        return (job, b"raw")

    async def ack(self, raw):
        self.acked = True

    async def fail(self, raw, payload):
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
    """A successfully handled job bumps jobs_done, clears busy and stamps last_job_at."""
    # worker.stats is module-level state shared with other tests. Reset every
    # field asserted below (monkeypatch restores the old values afterwards) so
    # the busy / last_job_at checks cannot pass on values left by an earlier test.
    monkeypatch.setattr(worker.stats, "jobs_done", 0)
    monkeypatch.setattr(worker.stats, "busy", False)
    monkeypatch.setattr(worker.stats, "last_job_at", None)
    monkeypatch.setattr(worker, "run_flux_generation", lambda task_id, params: None)

    handled = await worker.poll_once(_OneJobQueue())

    assert handled is True
    assert worker.stats.jobs_done == 1
    assert worker.stats.busy is False
    assert worker.stats.last_job_at is not None
|
||||
|
||||
@@ -18,6 +18,7 @@ from providers.gpt_image import run_gpt_image_generation
|
||||
from providers.nano_banana import run_nano_banana_generation
|
||||
from providers.flux import run_flux_generation
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
from _shared.heartbeat import WorkerStats, utc_now_iso
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +26,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
QUEUE_KEY = "queue:image-render"
|
||||
PAUSED_KEY = "queue:paused"
|
||||
|
||||
stats = WorkerStats()
|
||||
|
||||
# string names so `unittest.mock.patch` / `monkeypatch.setattr` on `worker.<name>`
|
||||
# is correctly intercepted by getattr(sys.modules[__name__], ...)
|
||||
_DISPATCH_TABLE = {
|
||||
@@ -59,14 +62,21 @@ async def poll_once(queue: ReliableQueue) -> bool:
|
||||
if result is None:
|
||||
return False
|
||||
payload, raw = result
|
||||
stats.busy = True
|
||||
try:
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except Exception:
|
||||
logger.exception("dispatch unhandled exception task_id=%s",
|
||||
payload.get("task_id"))
|
||||
await queue.fail(raw, payload)
|
||||
stats.jobs_failed += 1
|
||||
stats.last_job_at = utc_now_iso()
|
||||
stats.busy = False
|
||||
return True
|
||||
await queue.ack(raw)
|
||||
stats.jobs_done += 1
|
||||
stats.last_job_at = utc_now_iso()
|
||||
stats.busy = False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -3,12 +3,15 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from fastapi import FastAPI
|
||||
|
||||
import card_renderer
|
||||
import worker
|
||||
from _shared.heartbeat import heartbeat_loop
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,15 +23,19 @@ async def lifespan(app: FastAPI):
|
||||
await card_renderer.init_browser()
|
||||
# 큐 워커 백그라운드 시작
|
||||
worker_task = asyncio.create_task(worker.worker_loop())
|
||||
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
|
||||
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "insta-render", "render", worker.stats))
|
||||
logger.info("insta-render lifespan 시작")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
worker_task.cancel()
|
||||
try:
|
||||
await worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
for t in (worker_task, hb_task):
|
||||
t.cancel()
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await hb_redis.aclose()
|
||||
await card_renderer.shutdown_browser()
|
||||
logger.info("insta-render lifespan 종료")
|
||||
|
||||
|
||||
@@ -230,3 +230,27 @@ def test_make_queue_redis_socket_timeout_exceeds_block():
|
||||
c = worker.make_queue_redis()
|
||||
st = c.connection_pool.connection_kwargs.get("socket_timeout")
|
||||
assert st is not None and st > 5 # blmove 블록(5s)보다 커야 안정
|
||||
|
||||
|
||||
# ----- heartbeat stats 카운터 -----
|
||||
|
||||
class _OneJobQueueInsta:
    """Queue double that offers one slate job until it has been acked."""

    def __init__(self):
        self.acked = False

    async def dequeue(self, timeout=5):
        if self.acked:
            return None
        job = {"task_id": "t1", "params": {"slate_id": 1, "theme": "default"}}
        return (job, b"raw")

    async def ack(self, raw):
        self.acked = True

    async def fail(self, raw, payload):
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
    """A successfully handled job bumps jobs_done, clears busy and stamps last_job_at."""
    # worker.stats is module-level state shared with other tests. Reset every
    # field asserted below (monkeypatch restores the old values afterwards) so
    # the busy / last_job_at checks cannot pass on values left by an earlier test.
    monkeypatch.setattr(worker.stats, "jobs_done", 0)
    monkeypatch.setattr(worker.stats, "busy", False)
    monkeypatch.setattr(worker.stats, "last_job_at", None)

    async def fake_process(client, payload):
        pass

    monkeypatch.setattr(worker, "_process_one", fake_process)

    async with httpx.AsyncClient() as client:
        handled = await worker.poll_once(_OneJobQueueInsta(), client)

    assert handled is True
    assert worker.stats.jobs_done == 1
    assert worker.stats.busy is False
    assert worker.stats.last_job_at is not None
|
||||
|
||||
@@ -14,9 +14,11 @@ import redis.asyncio as aioredis
|
||||
|
||||
from card_renderer import render_slate
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
from _shared.heartbeat import WorkerStats, utc_now_iso
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
stats = WorkerStats()
|
||||
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
NAS_BASE_URL = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18700")
|
||||
@@ -89,12 +91,19 @@ async def poll_once(queue: ReliableQueue, client: httpx.AsyncClient) -> bool:
|
||||
if result is None:
|
||||
return False
|
||||
payload, raw = result
|
||||
stats.busy = True
|
||||
try:
|
||||
await _process_one(client, payload)
|
||||
except Exception:
|
||||
await queue.fail(raw, payload)
|
||||
stats.jobs_failed += 1
|
||||
stats.last_job_at = utc_now_iso()
|
||||
stats.busy = False
|
||||
return True
|
||||
await queue.ack(raw)
|
||||
stats.jobs_done += 1
|
||||
stats.last_job_at = utc_now_iso()
|
||||
stats.busy = False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -7,12 +7,15 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
import worker
|
||||
from _shared.heartbeat import heartbeat_loop
|
||||
from providers.sync_ops import (
|
||||
generate_lyrics, get_credits,
|
||||
get_timestamped_lyrics, generate_style_boost,
|
||||
@@ -25,15 +28,19 @@ logger = logging.getLogger(__name__)
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the queue worker and the heartbeat publisher for the app's lifetime.

    Startup launches ``worker.worker_loop()`` and a ``heartbeat_loop`` task
    reporting as "music-render" / kind "render" through a dedicated Redis
    client. Shutdown cancels both tasks, waits for each, and always closes
    that client.
    """
    worker_task = asyncio.create_task(worker.worker_loop())
    hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
    hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "music-render", "render", worker.stats))
    logger.info("music-render lifespan 시작")
    try:
        yield
    finally:
        background = (worker_task, hb_task)
        # Request cancellation of every task before awaiting any of them, so
        # one task's failure cannot leave the other running.
        for task in background:
            task.cancel()
        try:
            for task in background:
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # A task that crashed earlier re-raises its error here;
                    # log it instead of aborting the rest of the shutdown.
                    logger.exception("music-render background task failed before shutdown")
        finally:
            # Close the heartbeat connection no matter how the tasks ended.
            await hb_redis.aclose()
            logger.info("music-render lifespan 종료")
|
||||
|
||||
|
||||
|
||||
@@ -167,3 +167,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
|
||||
dispatch_mock.assert_not_called()
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
|
||||
# ----- heartbeat stats 카운터 -----
|
||||
|
||||
class _OneJobQueue:
    """Queue double that offers one suno_generation job until it has been acked."""

    def __init__(self):
        self.acked = False

    async def dequeue(self, timeout=5):
        if self.acked:
            return None
        job = {"job_type": "suno_generation", "task_id": "t1", "params": {}}
        return (job, b"raw")

    async def ack(self, raw):
        self.acked = True

    async def fail(self, raw, payload):
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
    """A successfully handled job bumps jobs_done, clears busy and stamps last_job_at."""
    # worker.stats is module-level state shared with other tests. Reset every
    # field asserted below (monkeypatch restores the old values afterwards) so
    # the busy / last_job_at checks cannot pass on values left by an earlier test.
    monkeypatch.setattr(worker.stats, "jobs_done", 0)
    monkeypatch.setattr(worker.stats, "busy", False)
    monkeypatch.setattr(worker.stats, "last_job_at", None)
    monkeypatch.setattr(worker, "run_suno_generation", lambda task_id, params: None)

    handled = await worker.poll_once(_OneJobQueue())

    assert handled is True
    assert worker.stats.jobs_done == 1
    assert worker.stats.busy is False
    assert worker.stats.last_job_at is not None
|
||||
|
||||
@@ -21,6 +21,7 @@ from providers.suno import (
|
||||
)
|
||||
from providers.local import run_local_generation
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
from _shared.heartbeat import WorkerStats, utc_now_iso
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,6 +29,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
QUEUE_KEY = "queue:music-render"
|
||||
PAUSED_KEY = "queue:paused"
|
||||
|
||||
stats = WorkerStats()
|
||||
|
||||
# Maps job_type → module-level function name (string).
|
||||
# _dispatch resolves the name via globals() at call time so unittest.mock.patch
|
||||
# on "worker.<name>" is correctly intercepted.
|
||||
@@ -74,6 +77,7 @@ async def poll_once(queue: ReliableQueue) -> bool:
|
||||
if result is None:
|
||||
return False
|
||||
payload, raw = result
|
||||
stats.busy = True
|
||||
try:
|
||||
# sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
@@ -81,8 +85,14 @@ async def poll_once(queue: ReliableQueue) -> bool:
|
||||
logger.exception("dispatch unhandled exception task_id=%s",
|
||||
payload.get("task_id"))
|
||||
await queue.fail(raw, payload)
|
||||
stats.jobs_failed += 1
|
||||
stats.last_job_at = utc_now_iso()
|
||||
stats.busy = False
|
||||
return True
|
||||
await queue.ack(raw)
|
||||
stats.jobs_done += 1
|
||||
stats.last_job_at = utc_now_iso()
|
||||
stats.busy = False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
5
services/task-watcher/conftest.py
Normal file
5
services/task-watcher/conftest.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Make services/ root importable so `from _shared.heartbeat import ...` works during tests."""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
16
services/task-watcher/tests/test_watcher.py
Normal file
16
services/task-watcher/tests/test_watcher.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""task-watcher heartbeat payload — state=mode + mode 필드 검증."""
|
||||
import json
|
||||
|
||||
from _shared.heartbeat import build_payload, WorkerStats
|
||||
|
||||
|
||||
def test_watcher_heartbeat_payload_carries_mode():
    """The watcher payload reports the mode both as its state and as a separate mode field."""
    raw = build_payload(
        "task-watcher",
        "watcher",
        "trading",
        WorkerStats(),
        extra={"mode": "trading"},
    )
    decoded = json.loads(raw)
    wanted = {"kind": "watcher", "state": "trading", "mode": "trading"}
    for field, value in wanted.items():
        assert decoded[field] == value
|
||||
@@ -15,6 +15,7 @@ from zoneinfo import ZoneInfo
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from mode import current_mode, fetch_holidays, KST
|
||||
from _shared.heartbeat import build_payload, WorkerStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,6 +24,10 @@ PAUSED_KEY = "queue:paused"
|
||||
LOOP_INTERVAL = 30 # 초
|
||||
HOLIDAYS_REFRESH = 3600 # 1시간
|
||||
PAUSED_TTL = 600 # 10분 (watcher 죽어도 자동 해제)
|
||||
HEARTBEAT_KEY = "worker:task-watcher:heartbeat"
|
||||
HEARTBEAT_TTL = 45 # LOOP_INTERVAL 30s < TTL 45s → 만료 전 갱신
|
||||
|
||||
_HB_STATS = WorkerStats()
|
||||
|
||||
|
||||
async def watcher_loop():
|
||||
@@ -46,6 +51,13 @@ async def watcher_loop():
|
||||
else:
|
||||
await redis.delete(PAUSED_KEY)
|
||||
|
||||
# heartbeat (LOOP_INTERVAL=30s < TTL 45s → 만료 전 갱신)
|
||||
await redis.set(
|
||||
HEARTBEAT_KEY,
|
||||
build_payload("task-watcher", "watcher", mode, _HB_STATS, extra={"mode": mode}),
|
||||
ex=HEARTBEAT_TTL,
|
||||
)
|
||||
|
||||
if mode != last_mode:
|
||||
logger.info("mode 전환: %s → %s (paused=%s)", last_mode, mode, mode == "trading")
|
||||
last_mode = mode
|
||||
|
||||
@@ -3,11 +3,14 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from fastapi import FastAPI
|
||||
|
||||
import worker
|
||||
from _shared.heartbeat import heartbeat_loop
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -16,15 +19,19 @@ logger = logging.getLogger(__name__)
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the queue worker and the heartbeat publisher for the app's lifetime.

    Startup launches ``worker.worker_loop()`` and a ``heartbeat_loop`` task
    reporting as "video-render" / kind "render" through a dedicated Redis
    client. Shutdown cancels both tasks, waits for each, and always closes
    that client.
    """
    worker_task = asyncio.create_task(worker.worker_loop())
    hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
    hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "video-render", "render", worker.stats))
    logger.info("video-render lifespan 시작")
    try:
        yield
    finally:
        background = (worker_task, hb_task)
        # Request cancellation of every task before awaiting any of them, so
        # one task's failure cannot leave the other running.
        for task in background:
            task.cancel()
        try:
            for task in background:
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # A task that crashed earlier re-raises its error here;
                    # log it instead of aborting the rest of the shutdown.
                    logger.exception("video-render background task failed before shutdown")
        finally:
            # Close the heartbeat connection no matter how the tasks ended.
            await hb_redis.aclose()
            logger.info("video-render lifespan 종료")
|
||||
|
||||
|
||||
|
||||
@@ -94,3 +94,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
|
||||
assert handled is False
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
|
||||
# ----- heartbeat stats 카운터 -----
|
||||
|
||||
class _OneJobQueue:
    """Queue double that offers one sora_generation job until it has been acked."""

    def __init__(self):
        self.acked = False

    async def dequeue(self, timeout=5):
        if self.acked:
            return None
        job = {"job_type": "sora_generation", "task_id": "t1", "params": {}}
        return (job, b"raw")

    async def ack(self, raw):
        self.acked = True

    async def fail(self, raw, payload):
        pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
    """A successfully handled job bumps jobs_done, clears busy and stamps last_job_at."""
    # worker.stats is module-level state shared with other tests. Reset every
    # field asserted below (monkeypatch restores the old values afterwards) so
    # the busy / last_job_at checks cannot pass on values left by an earlier test.
    monkeypatch.setattr(worker.stats, "jobs_done", 0)
    monkeypatch.setattr(worker.stats, "busy", False)
    monkeypatch.setattr(worker.stats, "last_job_at", None)
    monkeypatch.setattr(worker, "run_sora_generation", lambda task_id, params: None)

    handled = await worker.poll_once(_OneJobQueue())

    assert handled is True
    assert worker.stats.jobs_done == 1
    assert worker.stats.busy is False
    assert worker.stats.last_job_at is not None
|
||||
|
||||
@@ -19,6 +19,7 @@ from providers.veo import run_veo_generation
|
||||
from providers.kling import run_kling_generation
|
||||
from providers.seedance import run_seedance_generation
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
from _shared.heartbeat import WorkerStats, utc_now_iso
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,6 +27,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
QUEUE_KEY = "queue:video-render"
|
||||
PAUSED_KEY = "queue:paused"
|
||||
|
||||
stats = WorkerStats()
|
||||
|
||||
# string names so `unittest.mock.patch` on `worker.<name>` is correctly intercepted
|
||||
_DISPATCH_TABLE = {
|
||||
"sora_generation": "run_sora_generation",
|
||||
@@ -60,14 +63,21 @@ async def poll_once(queue: ReliableQueue) -> bool:
|
||||
if result is None:
|
||||
return False
|
||||
payload, raw = result
|
||||
stats.busy = True
|
||||
try:
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except Exception:
|
||||
logger.exception("dispatch unhandled exception task_id=%s",
|
||||
payload.get("task_id"))
|
||||
await queue.fail(raw, payload)
|
||||
stats.jobs_failed += 1
|
||||
stats.last_job_at = utc_now_iso()
|
||||
stats.busy = False
|
||||
return True
|
||||
await queue.ack(raw)
|
||||
stats.jobs_done += 1
|
||||
stats.last_job_at = utc_now_iso()
|
||||
stats.busy = False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user