Compare commits

...

4 Commits

Author SHA1 Message Date
54fca07d43 feat(ai_trade): NAS Redis heartbeat (trader market_open/closed)
- ai_trade/heartbeat.py: build_trader_payload() + heartbeat_loop() 자체 미니 헬퍼
  (Windows 호스트 실행이라 _shared import 경로 달라 독립 구현, 계약은 동일)
- ai_trade/main.py: lifespan에 hb_task spawn + shutdown 시 cancel
  state_fn = scheduler._is_market_day & _is_polling_window(KST now) 조합
  signals = len(state.signals) 실시간 주입
- requirements.txt: redis>=5.0 추가
- ai_trade/tests/test_heartbeat.py: build_trader_payload 3케이스 TDD 검증

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019LV86jBozkNhSFXJA412fq
2026-07-01 01:07:00 +09:00
574b5712c3 feat(task-watcher): heartbeat 발신 (state=mode, paused 이유 노출)
- watcher_loop 에서 mode 판정 직후 worker:task-watcher:heartbeat SET EX 45
- payload: build_payload(state=mode, extra={"mode": mode})
- LOOP_INTERVAL 30s < TTL 45s → 만료 전 주기적 갱신
- conftest.py 추가: services/ 를 sys.path에 주입해 _shared import 가능
- tests/test_watcher.py: payload kind/state/mode 필드 검증 (1 passed)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-07-01 00:59:28 +09:00
2ff31b2e76 feat(render-workers): 4 render 워커 heartbeat 배선 + poll_once 카운터
- services/_shared/heartbeat.py (A1) WorkerStats/utc_now_iso/heartbeat_loop 소비
- image-render / video-render / music-render / insta-render 각 worker.py:
  stats = WorkerStats() 모듈 레벨 추가, poll_once에서 dispatch 전 busy=True,
  ack 후 jobs_done+1 / fail 후 jobs_failed+1 + last_job_at + busy=False
- 각 main.py: lifespan에 aioredis(decode_responses=False) + heartbeat_loop 태스크 spawn,
  종료 시 cancel + aclose
- 각 tests/test_worker.py: test_poll_once_increments_jobs_done 추가
  (image:flux / video:sora / music:suno / insta:_process_one mock)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019LV86jBozkNhSFXJA412fq
2026-07-01 00:52:57 +09:00
d1b9ff570d feat(_shared): 워커 heartbeat 모듈 (worker:<name>:heartbeat TTL SET)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-07-01 00:43:01 +09:00
21 changed files with 433 additions and 21 deletions

57
ai_trade/heartbeat.py Normal file
View File

@@ -0,0 +1,57 @@
"""ai_trade heartbeat — NAS Redis로 worker:ai_trade:heartbeat SET.
Global Constraints 계약 1: kind=trader, state=market_open|market_closed.
ai_trade는 Windows 호스트 실행이라 _shared import 경로가 달라 자체 미니 헬퍼로 둔다.
"""
from __future__ import annotations
import asyncio
import datetime as dt
import json
import logging
import os
import redis.asyncio as aioredis
logger = logging.getLogger(__name__)
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
KEY = "worker:ai_trade:heartbeat"
INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", "15"))
TTL = int(os.getenv("HEARTBEAT_TTL", "45"))
def build_trader_payload(state: str, signals: int = 0) -> str:
    """Serialize the ai_trade heartbeat record as a JSON string.

    Args:
        state: Trader state to report, ``'market_open'`` or ``'market_closed'``.
        signals: Value published in the ``jobs_done`` field.

    Returns:
        JSON text carrying name, kind, state, a UTC ``ts`` timestamp,
        ``last_job_at`` (always null), ``jobs_done`` and ``jobs_failed`` (always 0).
    """
    stamp = f"{dt.datetime.now(dt.timezone.utc):%Y-%m-%dT%H:%M:%SZ}"
    record = {
        "name": "ai_trade",
        "kind": "trader",
        "state": state,
        "ts": stamp,
        "last_job_at": None,
        "jobs_done": signals,
        "jobs_failed": 0,
    }
    return json.dumps(record)
async def heartbeat_loop(state_fn) -> None:
    """Write the ai_trade heartbeat to Redis every ``INTERVAL`` seconds.

    Each cycle stores the payload under ``KEY`` with an expiry of ``TTL``
    seconds. A failing cycle is logged and the loop carries on; cancellation
    propagates, and the Redis client is closed on the way out.

    Args:
        state_fn: Zero-argument callable returning ``(state, signals)``. The
            caller injects the polling-window decision through it.
    """
    client = aioredis.from_url(REDIS_URL, decode_responses=False)
    try:
        while True:
            try:
                current_state, signal_count = state_fn()
                body = build_trader_payload(current_state, signal_count)
                await client.set(KEY, body, ex=TTL)
                logger.debug(
                    "ai_trade heartbeat sent: state=%s signals=%d",
                    current_state,
                    signal_count,
                )
            except asyncio.CancelledError:
                # Never swallow cancellation: the lifespan shutdown relies on it.
                raise
            except Exception:
                logger.exception("ai_trade heartbeat 실패 — 다음 주기에 재시도")
            await asyncio.sleep(INTERVAL)
    finally:
        await client.aclose()

View File

@@ -3,9 +3,12 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime
from zoneinfo import ZoneInfo
from fastapi import FastAPI from fastapi import FastAPI
from ai_trade import heartbeat as _hb
from ai_trade import state as state_mod from ai_trade import state as state_mod
from ai_trade.chronos_predictor import ChronosPredictor from ai_trade.chronos_predictor import ChronosPredictor
from ai_trade.config import get_settings from ai_trade.config import get_settings
@@ -13,8 +16,11 @@ from ai_trade.kis_client import KISClient
from ai_trade.kis_websocket import KISWebSocket from ai_trade.kis_websocket import KISWebSocket
from ai_trade.pull_worker import poll_loop, make_asking_price_callback from ai_trade.pull_worker import poll_loop, make_asking_price_callback
from ai_trade.rate_limit import SignalDedup from ai_trade.rate_limit import SignalDedup
from ai_trade.scheduler import _is_polling_window, _is_market_day
from ai_trade.stock_client import StockClient from ai_trade.stock_client import StockClient
_KST = ZoneInfo("Asia/Seoul")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -23,6 +29,7 @@ class AppContext:
dedup: SignalDedup | None = None dedup: SignalDedup | None = None
shutdown: asyncio.Event | None = None shutdown: asyncio.Event | None = None
poll_task: asyncio.Task | None = None poll_task: asyncio.Task | None = None
hb_task: asyncio.Task | None = None
kis_client: KISClient | None = None kis_client: KISClient | None = None
kis_ws: KISWebSocket | None = None kis_ws: KISWebSocket | None = None
chronos: ChronosPredictor | None = None chronos: ChronosPredictor | None = None
@@ -87,9 +94,27 @@ async def lifespan(app: FastAPI):
) )
) )
def _trader_state() -> tuple[str, int]:
    """Report market_open/market_closed using the scheduler's own checks.

    Returns:
        ``(state, signals)`` where state is ``"market_open"`` only when the
        current KST time passes both ``_is_market_day`` and
        ``_is_polling_window``, and signals is the current length of
        ``state_mod.state.signals``.
    """
    kst_now = datetime.now(_KST)
    if _is_market_day(kst_now) and _is_polling_window(kst_now):
        market = "market_open"
    else:
        market = "market_closed"
    return market, len(state_mod.state.signals)
_ctx.hb_task = asyncio.create_task(_hb.heartbeat_loop(_trader_state))
yield yield
# Shutdown # Shutdown heartbeat task
if _ctx.hb_task is not None:
_ctx.hb_task.cancel()
try:
await _ctx.hb_task
except asyncio.CancelledError:
pass
# Shutdown poll task
if _ctx.shutdown is not None: if _ctx.shutdown is not None:
_ctx.shutdown.set() _ctx.shutdown.set()
if _ctx.poll_task is not None: if _ctx.poll_task is not None:

View File

@@ -0,0 +1,38 @@
"""Tests for ai_trade heartbeat payload builder."""
import json
import pytest
def test_trader_payload_market_open():
    """An open-market payload carries the trader identity and the signal count."""
    from ai_trade.heartbeat import build_trader_payload

    decoded = json.loads(build_trader_payload("market_open", signals=2))
    expected = {
        "name": "ai_trade",
        "kind": "trader",
        "state": "market_open",
        "jobs_done": 2,
    }
    assert {key: decoded[key] for key in expected} == expected
    assert decoded["ts"].endswith("Z")
def test_trader_payload_market_closed():
    """With defaults, counters are zero and there is no last-job timestamp."""
    from ai_trade.heartbeat import build_trader_payload

    decoded = json.loads(build_trader_payload("market_closed"))
    expected = {
        "name": "ai_trade",
        "kind": "trader",
        "state": "market_closed",
        "jobs_done": 0,
        "jobs_failed": 0,
    }
    assert {key: decoded[key] for key in expected} == expected
    assert decoded["last_job_at"] is None
def test_trader_payload_ts_format():
    """The ts field is exactly ISO 8601 UTC (YYYY-MM-DDTHH:MM:SSZ), nothing more."""
    from ai_trade.heartbeat import build_trader_payload
    import re

    p = json.loads(build_trader_payload("market_open"))
    # fullmatch rather than match: match only anchors the start, so a value
    # with trailing characters after the "Z" would wrongly be accepted.
    assert re.fullmatch(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", p["ts"]), (
        f"ts={p['ts']!r} does not match expected UTC format"
    )

View File

@@ -7,6 +7,7 @@ pytest>=8.0
pytest-asyncio>=0.23 pytest-asyncio>=0.23
respx>=0.21 respx>=0.21
websockets>=12 websockets>=12
redis>=5.0
# Phase 3b dependencies (Chronos-2 + ML) # Phase 3b dependencies (Chronos-2 + ML)
transformers>=4.40 transformers>=4.40
chronos-forecasting>=1.4 chronos-forecasting>=1.4

View File

@@ -0,0 +1,55 @@
"""분산 워커 heartbeat — worker:<name>:heartbeat SET (TTL). Global Constraints 계약 1."""
from __future__ import annotations
import asyncio, datetime as dt, json, logging, os
logger = logging.getLogger(__name__)
DEFAULT_INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", "15"))
DEFAULT_TTL = int(os.getenv("HEARTBEAT_TTL", "45"))
class WorkerStats:
    """Mutable counters that the worker loop updates and heartbeat_loop reads."""

    def __init__(self):
        self.busy: bool = False
        self.jobs_done: int = 0
        self.jobs_failed: int = 0
        # ISO timestamp string of the latest finished job, or None if none yet.
        self.last_job_at: str | None = None
def utc_now_iso() -> str:
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
    now = dt.datetime.now(dt.timezone.utc)
    return f"{now:%Y-%m-%dT%H:%M:%SZ}"
def build_payload(name: str, kind: str, state: str, stats: WorkerStats, extra: dict | None = None) -> str:
    """Serialize one heartbeat record as a JSON string.

    Args:
        name: Worker name.
        kind: Worker kind.
        state: State string to report.
        stats: Counters copied into ``last_job_at``, ``jobs_done`` and ``jobs_failed``.
        extra: Optional fields merged last, so they override same-named keys.

    Returns:
        JSON text of the record, including a fresh ``ts`` from ``utc_now_iso``.
    """
    fields = dict(
        name=name,
        kind=kind,
        state=state,
        ts=utc_now_iso(),
        last_job_at=stats.last_job_at,
        jobs_done=stats.jobs_done,
        jobs_failed=stats.jobs_failed,
    )
    fields.update(extra or {})
    return json.dumps(fields)
async def render_state(redis, stats: WorkerStats, paused_key: str = "queue:paused") -> str:
    """Derive a render worker's state from the pause flag and its busy bit.

    Args:
        redis: Async Redis client; only ``get`` is used.
        stats: Object whose ``busy`` attribute says whether a job is running.
        paused_key: Key holding the pause flag.

    Returns:
        ``"paused"`` when the flag is set, otherwise ``"busy"`` or ``"idle"``.
    """
    flag = await redis.get(paused_key)
    # Accept bytes and str: a client built with decode_responses=True returns
    # "1" instead of b"1", which would otherwise never be seen as paused.
    if flag in (b"1", "1"):
        return "paused"
    return "busy" if stats.busy else "idle"
async def heartbeat_loop(redis, name, kind, stats, *, interval=DEFAULT_INTERVAL,
                         ttl=DEFAULT_TTL, paused_key="queue:paused", state_fn=None):
    """Refresh ``worker:<name>:heartbeat`` in Redis until cancelled.

    Args:
        redis: Async Redis client used for ``set`` (and ``get`` via render_state).
        name: Worker name; also forms the key.
        kind: Worker kind placed in the payload.
        stats: Counters serialized into every payload.
        interval: Seconds to sleep between writes.
        ttl: Expiry in seconds applied to the key.
        paused_key: Pause-flag key passed to ``render_state``.
        state_fn: Optional coroutine function ``(redis, stats) -> (state, extra)``
            that replaces the default ``render_state`` decision.

    A failing cycle is logged and retried after ``interval``; cancellation
    is re-raised.
    """
    hb_key = f"worker:{name}:heartbeat"
    logger.info("heartbeat 시작 name=%s ttl=%ds", name, ttl)
    while True:
        try:
            if state_fn is None:
                state = await render_state(redis, stats, paused_key)
                extra = None
            else:
                state, extra = await state_fn(redis, stats)
            body = build_payload(name, kind, state, stats, extra)
            await redis.set(hb_key, body, ex=ttl)
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("heartbeat 발신 실패 name=%s", name)
        await asyncio.sleep(interval)

View File

@@ -0,0 +1,46 @@
"""Tests for _shared.heartbeat — Task A1."""
import json
import sys
from pathlib import Path
import pytest
# Make `_shared` importable (same pattern as test_reliable_queue.py)
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
from _shared.heartbeat import WorkerStats, build_payload, render_state
def test_build_payload_has_contract_fields():
    """The payload exposes the contract fields with values taken from stats."""
    stats = WorkerStats()
    stats.jobs_done = 3
    stats.last_job_at = "2026-06-29T00:00:00Z"

    decoded = json.loads(build_payload("image-render", "render", "idle", stats))
    expected = {
        "name": "image-render",
        "kind": "render",
        "state": "idle",
        "jobs_done": 3,
        "last_job_at": "2026-06-29T00:00:00Z",
    }
    assert {key: decoded[key] for key in expected} == expected
    assert decoded["ts"].endswith("Z")
def test_build_payload_merges_extra():
    """Fields passed through ``extra`` appear in the serialized payload."""
    raw = build_payload(
        "task-watcher", "watcher", "free", WorkerStats(), extra={"mode": "free"}
    )
    decoded = json.loads(raw)
    assert decoded["mode"] == "free"
class _FakeRedis:
    """Minimal async stand-in for Redis whose ``get`` reflects a fixed pause flag."""

    def __init__(self, paused):
        self._paused = paused

    async def get(self, key):
        """Return ``b"1"`` when paused, else None; the key is ignored."""
        if self._paused:
            return b"1"
        return None
@pytest.mark.asyncio
async def test_render_state_paused_overrides_busy():
    """A set pause flag wins even while the worker is busy."""
    stats = WorkerStats()
    stats.busy = True
    observed = await render_state(_FakeRedis(paused=True), stats)
    assert observed == "paused"
@pytest.mark.asyncio
async def test_render_state_busy_then_idle():
    """Without a pause flag the state follows ``stats.busy``."""
    stats = WorkerStats()
    unpaused = _FakeRedis(paused=False)

    stats.busy = True
    assert await render_state(unpaused, stats) == "busy"

    stats.busy = False
    assert await render_state(_FakeRedis(paused=False), stats) == "idle"

View File

@@ -3,11 +3,14 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import redis.asyncio as aioredis
from fastapi import FastAPI from fastapi import FastAPI
import worker import worker
from _shared.heartbeat import heartbeat_loop
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,15 +19,19 @@ logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
worker_task = asyncio.create_task(worker.worker_loop()) worker_task = asyncio.create_task(worker.worker_loop())
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "image-render", "render", worker.stats))
logger.info("image-render lifespan 시작") logger.info("image-render lifespan 시작")
try: try:
yield yield
finally: finally:
worker_task.cancel() for t in (worker_task, hb_task):
t.cancel()
try: try:
await worker_task await t
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
await hb_redis.aclose()
logger.info("image-render lifespan 종료") logger.info("image-render lifespan 종료")

View File

@@ -67,3 +67,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
assert handled is False assert handled is False
fake_queue.ack.assert_not_awaited() fake_queue.ack.assert_not_awaited()
fake_queue.fail.assert_not_awaited() fake_queue.fail.assert_not_awaited()
# ----- heartbeat stats 카운터 -----
class _OneJobQueue:
    """Queue stub that serves one flux_generation job until it is acked."""

    def __init__(self):
        self.acked = False

    async def dequeue(self, timeout=5):
        """Return the ``(payload, raw)`` pair, or None once the job is acked."""
        if not self.acked:
            job = {"job_type": "flux_generation", "task_id": "t1", "params": {}}
            return (job, b"raw")
        return None

    async def ack(self, raw):
        """Mark the single job as acknowledged."""
        self.acked = True

    async def fail(self, raw, payload):
        """Accept a failure report and do nothing."""
        return None
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
    """A successfully dispatched job bumps jobs_done and leaves the worker idle."""
    # worker.stats is module-level state shared by every test in the process.
    # Swap in a fresh instance so busy/last_job_at cannot pass on values left
    # behind by an earlier test; monkeypatch restores the original afterwards.
    monkeypatch.setattr(worker, "stats", type(worker.stats)())
    monkeypatch.setattr(worker, "run_flux_generation", lambda task_id, params: None)

    handled = await worker.poll_once(_OneJobQueue())

    assert handled is True
    assert worker.stats.jobs_done == 1
    assert worker.stats.jobs_failed == 0
    assert worker.stats.busy is False
    assert worker.stats.last_job_at is not None

View File

@@ -18,6 +18,7 @@ from providers.gpt_image import run_gpt_image_generation
from providers.nano_banana import run_nano_banana_generation from providers.nano_banana import run_nano_banana_generation
from providers.flux import run_flux_generation from providers.flux import run_flux_generation
from _shared.reliable_queue import ReliableQueue from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,6 +26,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
QUEUE_KEY = "queue:image-render" QUEUE_KEY = "queue:image-render"
PAUSED_KEY = "queue:paused" PAUSED_KEY = "queue:paused"
stats = WorkerStats()
# string names so `unittest.mock.patch` / `monkeypatch.setattr` on `worker.<name>` # string names so `unittest.mock.patch` / `monkeypatch.setattr` on `worker.<name>`
# is correctly intercepted by getattr(sys.modules[__name__], ...) # is correctly intercepted by getattr(sys.modules[__name__], ...)
_DISPATCH_TABLE = { _DISPATCH_TABLE = {
@@ -59,14 +62,21 @@ async def poll_once(queue: ReliableQueue) -> bool:
if result is None: if result is None:
return False return False
payload, raw = result payload, raw = result
stats.busy = True
try: try:
await asyncio.to_thread(_dispatch, payload) await asyncio.to_thread(_dispatch, payload)
except Exception: except Exception:
logger.exception("dispatch unhandled exception task_id=%s", logger.exception("dispatch unhandled exception task_id=%s",
payload.get("task_id")) payload.get("task_id"))
await queue.fail(raw, payload) await queue.fail(raw, payload)
stats.jobs_failed += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True
await queue.ack(raw) await queue.ack(raw)
stats.jobs_done += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True

View File

@@ -3,12 +3,15 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import redis.asyncio as aioredis
from fastapi import FastAPI from fastapi import FastAPI
import card_renderer import card_renderer
import worker import worker
from _shared.heartbeat import heartbeat_loop
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,15 +23,19 @@ async def lifespan(app: FastAPI):
await card_renderer.init_browser() await card_renderer.init_browser()
# 큐 워커 백그라운드 시작 # 큐 워커 백그라운드 시작
worker_task = asyncio.create_task(worker.worker_loop()) worker_task = asyncio.create_task(worker.worker_loop())
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "insta-render", "render", worker.stats))
logger.info("insta-render lifespan 시작") logger.info("insta-render lifespan 시작")
try: try:
yield yield
finally: finally:
worker_task.cancel() for t in (worker_task, hb_task):
t.cancel()
try: try:
await worker_task await t
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
await hb_redis.aclose()
await card_renderer.shutdown_browser() await card_renderer.shutdown_browser()
logger.info("insta-render lifespan 종료") logger.info("insta-render lifespan 종료")

View File

@@ -230,3 +230,27 @@ def test_make_queue_redis_socket_timeout_exceeds_block():
c = worker.make_queue_redis() c = worker.make_queue_redis()
st = c.connection_pool.connection_kwargs.get("socket_timeout") st = c.connection_pool.connection_kwargs.get("socket_timeout")
assert st is not None and st > 5 # blmove 블록(5s)보다 커야 안정 assert st is not None and st > 5 # blmove 블록(5s)보다 커야 안정
# ----- heartbeat stats 카운터 -----
class _OneJobQueueInsta:
    """Queue stub that serves one insta-render job until it is acked."""

    def __init__(self):
        self.acked = False

    async def dequeue(self, timeout=5):
        """Return the ``(payload, raw)`` pair, or None once the job is acked."""
        if not self.acked:
            job = {"task_id": "t1", "params": {"slate_id": 1, "theme": "default"}}
            return (job, b"raw")
        return None

    async def ack(self, raw):
        """Mark the single job as acknowledged."""
        self.acked = True

    async def fail(self, raw, payload):
        """Accept a failure report and do nothing."""
        return None
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
    """A successfully processed job bumps jobs_done and leaves the worker idle."""
    # worker.stats is module-level state shared by every test in the process.
    # Swap in a fresh instance so busy/last_job_at cannot pass on values left
    # behind by an earlier test; monkeypatch restores the original afterwards.
    monkeypatch.setattr(worker, "stats", type(worker.stats)())

    async def fake_process(client, payload):
        pass

    monkeypatch.setattr(worker, "_process_one", fake_process)

    async with httpx.AsyncClient() as client:
        handled = await worker.poll_once(_OneJobQueueInsta(), client)

    assert handled is True
    assert worker.stats.jobs_done == 1
    assert worker.stats.jobs_failed == 0
    assert worker.stats.busy is False
    assert worker.stats.last_job_at is not None

View File

@@ -14,9 +14,11 @@ import redis.asyncio as aioredis
from card_renderer import render_slate from card_renderer import render_slate
from _shared.reliable_queue import ReliableQueue from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
stats = WorkerStats()
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379") REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
NAS_BASE_URL = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18700") NAS_BASE_URL = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18700")
@@ -89,12 +91,19 @@ async def poll_once(queue: ReliableQueue, client: httpx.AsyncClient) -> bool:
if result is None: if result is None:
return False return False
payload, raw = result payload, raw = result
stats.busy = True
try: try:
await _process_one(client, payload) await _process_one(client, payload)
except Exception: except Exception:
await queue.fail(raw, payload) await queue.fail(raw, payload)
stats.jobs_failed += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True
await queue.ack(raw) await queue.ack(raw)
stats.jobs_done += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True

View File

@@ -7,12 +7,15 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import redis.asyncio as aioredis
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import worker import worker
from _shared.heartbeat import heartbeat_loop
from providers.sync_ops import ( from providers.sync_ops import (
generate_lyrics, get_credits, generate_lyrics, get_credits,
get_timestamped_lyrics, generate_style_boost, get_timestamped_lyrics, generate_style_boost,
@@ -25,15 +28,19 @@ logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
worker_task = asyncio.create_task(worker.worker_loop()) worker_task = asyncio.create_task(worker.worker_loop())
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "music-render", "render", worker.stats))
logger.info("music-render lifespan 시작") logger.info("music-render lifespan 시작")
try: try:
yield yield
finally: finally:
worker_task.cancel() for t in (worker_task, hb_task):
t.cancel()
try: try:
await worker_task await t
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
await hb_redis.aclose()
logger.info("music-render lifespan 종료") logger.info("music-render lifespan 종료")

View File

@@ -167,3 +167,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
dispatch_mock.assert_not_called() dispatch_mock.assert_not_called()
fake_queue.ack.assert_not_awaited() fake_queue.ack.assert_not_awaited()
fake_queue.fail.assert_not_awaited() fake_queue.fail.assert_not_awaited()
# ----- heartbeat stats 카운터 -----
class _OneJobQueue:
    """Queue stub that serves one suno_generation job until it is acked."""

    def __init__(self):
        self.acked = False

    async def dequeue(self, timeout=5):
        """Return the ``(payload, raw)`` pair, or None once the job is acked."""
        if not self.acked:
            job = {"job_type": "suno_generation", "task_id": "t1", "params": {}}
            return (job, b"raw")
        return None

    async def ack(self, raw):
        """Mark the single job as acknowledged."""
        self.acked = True

    async def fail(self, raw, payload):
        """Accept a failure report and do nothing."""
        return None
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
    """A successfully dispatched job bumps jobs_done and leaves the worker idle."""
    # worker.stats is module-level state shared by every test in the process.
    # Swap in a fresh instance so busy/last_job_at cannot pass on values left
    # behind by an earlier test; monkeypatch restores the original afterwards.
    monkeypatch.setattr(worker, "stats", type(worker.stats)())
    monkeypatch.setattr(worker, "run_suno_generation", lambda task_id, params: None)

    handled = await worker.poll_once(_OneJobQueue())

    assert handled is True
    assert worker.stats.jobs_done == 1
    assert worker.stats.jobs_failed == 0
    assert worker.stats.busy is False
    assert worker.stats.last_job_at is not None

View File

@@ -21,6 +21,7 @@ from providers.suno import (
) )
from providers.local import run_local_generation from providers.local import run_local_generation
from _shared.reliable_queue import ReliableQueue from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -28,6 +29,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
QUEUE_KEY = "queue:music-render" QUEUE_KEY = "queue:music-render"
PAUSED_KEY = "queue:paused" PAUSED_KEY = "queue:paused"
stats = WorkerStats()
# Maps job_type → module-level function name (string). # Maps job_type → module-level function name (string).
# _dispatch resolves the name via globals() at call time so unittest.mock.patch # _dispatch resolves the name via globals() at call time so unittest.mock.patch
# on "worker.<name>" is correctly intercepted. # on "worker.<name>" is correctly intercepted.
@@ -74,6 +77,7 @@ async def poll_once(queue: ReliableQueue) -> bool:
if result is None: if result is None:
return False return False
payload, raw = result payload, raw = result
stats.busy = True
try: try:
# sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지 # sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지
await asyncio.to_thread(_dispatch, payload) await asyncio.to_thread(_dispatch, payload)
@@ -81,8 +85,14 @@ async def poll_once(queue: ReliableQueue) -> bool:
logger.exception("dispatch unhandled exception task_id=%s", logger.exception("dispatch unhandled exception task_id=%s",
payload.get("task_id")) payload.get("task_id"))
await queue.fail(raw, payload) await queue.fail(raw, payload)
stats.jobs_failed += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True
await queue.ack(raw) await queue.ack(raw)
stats.jobs_done += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True

View File

@@ -0,0 +1,5 @@
"""Make services/ root importable so `from _shared.heartbeat import ...` works during tests."""
import sys
from pathlib import Path
# Put the parent of this file's directory at the front of the import path.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

View File

@@ -0,0 +1,16 @@
"""task-watcher heartbeat payload — state=mode + mode 필드 검증."""
import json
from _shared.heartbeat import build_payload, WorkerStats
def test_watcher_heartbeat_payload_carries_mode():
    """The watcher payload reports the mode both as ``state`` and as ``mode``."""
    raw = build_payload(
        "task-watcher", "watcher", "trading", WorkerStats(), extra={"mode": "trading"}
    )
    decoded = json.loads(raw)
    for field, want in (("kind", "watcher"), ("state", "trading"), ("mode", "trading")):
        assert decoded[field] == want

View File

@@ -15,6 +15,7 @@ from zoneinfo import ZoneInfo
import redis.asyncio as aioredis import redis.asyncio as aioredis
from mode import current_mode, fetch_holidays, KST from mode import current_mode, fetch_holidays, KST
from _shared.heartbeat import build_payload, WorkerStats
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -23,6 +24,10 @@ PAUSED_KEY = "queue:paused"
LOOP_INTERVAL = 30 # 초 LOOP_INTERVAL = 30 # 초
HOLIDAYS_REFRESH = 3600 # 1시간 HOLIDAYS_REFRESH = 3600 # 1시간
PAUSED_TTL = 600 # 10분 (watcher 죽어도 자동 해제) PAUSED_TTL = 600 # 10분 (watcher 죽어도 자동 해제)
HEARTBEAT_KEY = "worker:task-watcher:heartbeat"
HEARTBEAT_TTL = 45 # LOOP_INTERVAL 30s < TTL 45s → 만료 전 갱신
_HB_STATS = WorkerStats()
async def watcher_loop(): async def watcher_loop():
@@ -46,6 +51,13 @@ async def watcher_loop():
else: else:
await redis.delete(PAUSED_KEY) await redis.delete(PAUSED_KEY)
# heartbeat (LOOP_INTERVAL=30s < TTL 45s → 만료 전 갱신)
await redis.set(
HEARTBEAT_KEY,
build_payload("task-watcher", "watcher", mode, _HB_STATS, extra={"mode": mode}),
ex=HEARTBEAT_TTL,
)
if mode != last_mode: if mode != last_mode:
logger.info("mode 전환: %s%s (paused=%s)", last_mode, mode, mode == "trading") logger.info("mode 전환: %s%s (paused=%s)", last_mode, mode, mode == "trading")
last_mode = mode last_mode = mode

View File

@@ -3,11 +3,14 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import redis.asyncio as aioredis
from fastapi import FastAPI from fastapi import FastAPI
import worker import worker
from _shared.heartbeat import heartbeat_loop
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,15 +19,19 @@ logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
worker_task = asyncio.create_task(worker.worker_loop()) worker_task = asyncio.create_task(worker.worker_loop())
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "video-render", "render", worker.stats))
logger.info("video-render lifespan 시작") logger.info("video-render lifespan 시작")
try: try:
yield yield
finally: finally:
worker_task.cancel() for t in (worker_task, hb_task):
t.cancel()
try: try:
await worker_task await t
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
await hb_redis.aclose()
logger.info("video-render lifespan 종료") logger.info("video-render lifespan 종료")

View File

@@ -94,3 +94,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
assert handled is False assert handled is False
fake_queue.ack.assert_not_awaited() fake_queue.ack.assert_not_awaited()
fake_queue.fail.assert_not_awaited() fake_queue.fail.assert_not_awaited()
# ----- heartbeat stats 카운터 -----
class _OneJobQueue:
    """Queue stub that serves one sora_generation job until it is acked."""

    def __init__(self):
        self.acked = False

    async def dequeue(self, timeout=5):
        """Return the ``(payload, raw)`` pair, or None once the job is acked."""
        if not self.acked:
            job = {"job_type": "sora_generation", "task_id": "t1", "params": {}}
            return (job, b"raw")
        return None

    async def ack(self, raw):
        """Mark the single job as acknowledged."""
        self.acked = True

    async def fail(self, raw, payload):
        """Accept a failure report and do nothing."""
        return None
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
    """A successfully dispatched job bumps jobs_done and leaves the worker idle."""
    # worker.stats is module-level state shared by every test in the process.
    # Swap in a fresh instance so busy/last_job_at cannot pass on values left
    # behind by an earlier test; monkeypatch restores the original afterwards.
    monkeypatch.setattr(worker, "stats", type(worker.stats)())
    monkeypatch.setattr(worker, "run_sora_generation", lambda task_id, params: None)

    handled = await worker.poll_once(_OneJobQueue())

    assert handled is True
    assert worker.stats.jobs_done == 1
    assert worker.stats.jobs_failed == 0
    assert worker.stats.busy is False
    assert worker.stats.last_job_at is not None

View File

@@ -19,6 +19,7 @@ from providers.veo import run_veo_generation
from providers.kling import run_kling_generation from providers.kling import run_kling_generation
from providers.seedance import run_seedance_generation from providers.seedance import run_seedance_generation
from _shared.reliable_queue import ReliableQueue from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,6 +27,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
QUEUE_KEY = "queue:video-render" QUEUE_KEY = "queue:video-render"
PAUSED_KEY = "queue:paused" PAUSED_KEY = "queue:paused"
stats = WorkerStats()
# string names so `unittest.mock.patch` on `worker.<name>` is correctly intercepted # string names so `unittest.mock.patch` on `worker.<name>` is correctly intercepted
_DISPATCH_TABLE = { _DISPATCH_TABLE = {
"sora_generation": "run_sora_generation", "sora_generation": "run_sora_generation",
@@ -60,14 +63,21 @@ async def poll_once(queue: ReliableQueue) -> bool:
if result is None: if result is None:
return False return False
payload, raw = result payload, raw = result
stats.busy = True
try: try:
await asyncio.to_thread(_dispatch, payload) await asyncio.to_thread(_dispatch, payload)
except Exception: except Exception:
logger.exception("dispatch unhandled exception task_id=%s", logger.exception("dispatch unhandled exception task_id=%s",
payload.get("task_id")) payload.get("task_id"))
await queue.fail(raw, payload) await queue.fail(raw, payload)
stats.jobs_failed += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True
await queue.ack(raw) await queue.ack(raw)
stats.jobs_done += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True