feat(render-workers): 4 render 워커 heartbeat 배선 + poll_once 카운터

- services/_shared/heartbeat.py (A1) WorkerStats/utc_now_iso/heartbeat_loop 소비
- image-render / video-render / music-render / insta-render 각 worker.py:
  stats = WorkerStats() 모듈 레벨 추가, poll_once에서 dispatch 전 busy=True,
  ack 후 jobs_done+1 / fail 후 jobs_failed+1 + last_job_at + busy=False
- 각 main.py: lifespan에 aioredis(decode_responses=False) + heartbeat_loop 태스크 spawn,
  종료 시 cancel + aclose
- 각 tests/test_worker.py: test_poll_once_increments_jobs_done 추가
  (image:flux / video:sora / music:suno / insta:_process_one mock)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019LV86jBozkNhSFXJA412fq
This commit is contained in:
2026-07-01 00:52:57 +09:00
parent d1b9ff570d
commit 2ff31b2e76
12 changed files with 177 additions and 20 deletions

View File

@@ -3,11 +3,14 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import redis.asyncio as aioredis
from fastapi import FastAPI from fastapi import FastAPI
import worker import worker
from _shared.heartbeat import heartbeat_loop
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,15 +19,19 @@ logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
worker_task = asyncio.create_task(worker.worker_loop()) worker_task = asyncio.create_task(worker.worker_loop())
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "image-render", "render", worker.stats))
logger.info("image-render lifespan 시작") logger.info("image-render lifespan 시작")
try: try:
yield yield
finally: finally:
worker_task.cancel() for t in (worker_task, hb_task):
try: t.cancel()
await worker_task try:
except asyncio.CancelledError: await t
pass except asyncio.CancelledError:
pass
await hb_redis.aclose()
logger.info("image-render lifespan 종료") logger.info("image-render lifespan 종료")

View File

@@ -67,3 +67,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
assert handled is False assert handled is False
fake_queue.ack.assert_not_awaited() fake_queue.ack.assert_not_awaited()
fake_queue.fail.assert_not_awaited() fake_queue.fail.assert_not_awaited()
# ----- heartbeat stats 카운터 -----
class _OneJobQueue:
def __init__(self): self.acked = False
async def dequeue(self, timeout=5):
if self.acked: return None
return ({"job_type": "flux_generation", "task_id": "t1", "params": {}}, b"raw")
async def ack(self, raw): self.acked = True
async def fail(self, raw, payload): pass
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
worker.stats.jobs_done = 0
monkeypatch.setattr(worker, "run_flux_generation", lambda task_id, params: None)
handled = await worker.poll_once(_OneJobQueue())
assert handled is True
assert worker.stats.jobs_done == 1
assert worker.stats.busy is False
assert worker.stats.last_job_at is not None

View File

@@ -18,6 +18,7 @@ from providers.gpt_image import run_gpt_image_generation
from providers.nano_banana import run_nano_banana_generation from providers.nano_banana import run_nano_banana_generation
from providers.flux import run_flux_generation from providers.flux import run_flux_generation
from _shared.reliable_queue import ReliableQueue from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,6 +26,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
QUEUE_KEY = "queue:image-render" QUEUE_KEY = "queue:image-render"
PAUSED_KEY = "queue:paused" PAUSED_KEY = "queue:paused"
stats = WorkerStats()
# string names so `unittest.mock.patch` / `monkeypatch.setattr` on `worker.<name>` # string names so `unittest.mock.patch` / `monkeypatch.setattr` on `worker.<name>`
# is correctly intercepted by getattr(sys.modules[__name__], ...) # is correctly intercepted by getattr(sys.modules[__name__], ...)
_DISPATCH_TABLE = { _DISPATCH_TABLE = {
@@ -59,14 +62,21 @@ async def poll_once(queue: ReliableQueue) -> bool:
if result is None: if result is None:
return False return False
payload, raw = result payload, raw = result
stats.busy = True
try: try:
await asyncio.to_thread(_dispatch, payload) await asyncio.to_thread(_dispatch, payload)
except Exception: except Exception:
logger.exception("dispatch unhandled exception task_id=%s", logger.exception("dispatch unhandled exception task_id=%s",
payload.get("task_id")) payload.get("task_id"))
await queue.fail(raw, payload) await queue.fail(raw, payload)
stats.jobs_failed += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True
await queue.ack(raw) await queue.ack(raw)
stats.jobs_done += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True

View File

@@ -3,12 +3,15 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import redis.asyncio as aioredis
from fastapi import FastAPI from fastapi import FastAPI
import card_renderer import card_renderer
import worker import worker
from _shared.heartbeat import heartbeat_loop
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,15 +23,19 @@ async def lifespan(app: FastAPI):
await card_renderer.init_browser() await card_renderer.init_browser()
# 큐 워커 백그라운드 시작 # 큐 워커 백그라운드 시작
worker_task = asyncio.create_task(worker.worker_loop()) worker_task = asyncio.create_task(worker.worker_loop())
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "insta-render", "render", worker.stats))
logger.info("insta-render lifespan 시작") logger.info("insta-render lifespan 시작")
try: try:
yield yield
finally: finally:
worker_task.cancel() for t in (worker_task, hb_task):
try: t.cancel()
await worker_task try:
except asyncio.CancelledError: await t
pass except asyncio.CancelledError:
pass
await hb_redis.aclose()
await card_renderer.shutdown_browser() await card_renderer.shutdown_browser()
logger.info("insta-render lifespan 종료") logger.info("insta-render lifespan 종료")

View File

@@ -230,3 +230,27 @@ def test_make_queue_redis_socket_timeout_exceeds_block():
c = worker.make_queue_redis() c = worker.make_queue_redis()
st = c.connection_pool.connection_kwargs.get("socket_timeout") st = c.connection_pool.connection_kwargs.get("socket_timeout")
assert st is not None and st > 5 # blmove 블록(5s)보다 커야 안정 assert st is not None and st > 5 # blmove 블록(5s)보다 커야 안정
# ----- heartbeat stats 카운터 -----
class _OneJobQueueInsta:
def __init__(self): self.acked = False
async def dequeue(self, timeout=5):
if self.acked: return None
return ({"task_id": "t1", "params": {"slate_id": 1, "theme": "default"}}, b"raw")
async def ack(self, raw): self.acked = True
async def fail(self, raw, payload): pass
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
worker.stats.jobs_done = 0
async def fake_process(client, payload): pass
monkeypatch.setattr(worker, "_process_one", fake_process)
async with httpx.AsyncClient() as client:
handled = await worker.poll_once(_OneJobQueueInsta(), client)
assert handled is True
assert worker.stats.jobs_done == 1
assert worker.stats.busy is False
assert worker.stats.last_job_at is not None

View File

@@ -14,9 +14,11 @@ import redis.asyncio as aioredis
from card_renderer import render_slate from card_renderer import render_slate
from _shared.reliable_queue import ReliableQueue from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
stats = WorkerStats()
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379") REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
NAS_BASE_URL = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18700") NAS_BASE_URL = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18700")
@@ -89,12 +91,19 @@ async def poll_once(queue: ReliableQueue, client: httpx.AsyncClient) -> bool:
if result is None: if result is None:
return False return False
payload, raw = result payload, raw = result
stats.busy = True
try: try:
await _process_one(client, payload) await _process_one(client, payload)
except Exception: except Exception:
await queue.fail(raw, payload) await queue.fail(raw, payload)
stats.jobs_failed += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True
await queue.ack(raw) await queue.ack(raw)
stats.jobs_done += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True

View File

@@ -7,12 +7,15 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import redis.asyncio as aioredis
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import worker import worker
from _shared.heartbeat import heartbeat_loop
from providers.sync_ops import ( from providers.sync_ops import (
generate_lyrics, get_credits, generate_lyrics, get_credits,
get_timestamped_lyrics, generate_style_boost, get_timestamped_lyrics, generate_style_boost,
@@ -25,15 +28,19 @@ logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
worker_task = asyncio.create_task(worker.worker_loop()) worker_task = asyncio.create_task(worker.worker_loop())
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "music-render", "render", worker.stats))
logger.info("music-render lifespan 시작") logger.info("music-render lifespan 시작")
try: try:
yield yield
finally: finally:
worker_task.cancel() for t in (worker_task, hb_task):
try: t.cancel()
await worker_task try:
except asyncio.CancelledError: await t
pass except asyncio.CancelledError:
pass
await hb_redis.aclose()
logger.info("music-render lifespan 종료") logger.info("music-render lifespan 종료")

View File

@@ -167,3 +167,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
dispatch_mock.assert_not_called() dispatch_mock.assert_not_called()
fake_queue.ack.assert_not_awaited() fake_queue.ack.assert_not_awaited()
fake_queue.fail.assert_not_awaited() fake_queue.fail.assert_not_awaited()
# ----- heartbeat stats 카운터 -----
class _OneJobQueue:
def __init__(self): self.acked = False
async def dequeue(self, timeout=5):
if self.acked: return None
return ({"job_type": "suno_generation", "task_id": "t1", "params": {}}, b"raw")
async def ack(self, raw): self.acked = True
async def fail(self, raw, payload): pass
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
worker.stats.jobs_done = 0
monkeypatch.setattr(worker, "run_suno_generation", lambda task_id, params: None)
handled = await worker.poll_once(_OneJobQueue())
assert handled is True
assert worker.stats.jobs_done == 1
assert worker.stats.busy is False
assert worker.stats.last_job_at is not None

View File

@@ -21,6 +21,7 @@ from providers.suno import (
) )
from providers.local import run_local_generation from providers.local import run_local_generation
from _shared.reliable_queue import ReliableQueue from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -28,6 +29,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
QUEUE_KEY = "queue:music-render" QUEUE_KEY = "queue:music-render"
PAUSED_KEY = "queue:paused" PAUSED_KEY = "queue:paused"
stats = WorkerStats()
# Maps job_type → module-level function name (string). # Maps job_type → module-level function name (string).
# _dispatch resolves the name via globals() at call time so unittest.mock.patch # _dispatch resolves the name via globals() at call time so unittest.mock.patch
# on "worker.<name>" is correctly intercepted. # on "worker.<name>" is correctly intercepted.
@@ -74,6 +77,7 @@ async def poll_once(queue: ReliableQueue) -> bool:
if result is None: if result is None:
return False return False
payload, raw = result payload, raw = result
stats.busy = True
try: try:
# sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지 # sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지
await asyncio.to_thread(_dispatch, payload) await asyncio.to_thread(_dispatch, payload)
@@ -81,8 +85,14 @@ async def poll_once(queue: ReliableQueue) -> bool:
logger.exception("dispatch unhandled exception task_id=%s", logger.exception("dispatch unhandled exception task_id=%s",
payload.get("task_id")) payload.get("task_id"))
await queue.fail(raw, payload) await queue.fail(raw, payload)
stats.jobs_failed += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True
await queue.ack(raw) await queue.ack(raw)
stats.jobs_done += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True

View File

@@ -3,11 +3,14 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import redis.asyncio as aioredis
from fastapi import FastAPI from fastapi import FastAPI
import worker import worker
from _shared.heartbeat import heartbeat_loop
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,15 +19,19 @@ logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
worker_task = asyncio.create_task(worker.worker_loop()) worker_task = asyncio.create_task(worker.worker_loop())
hb_redis = aioredis.from_url(os.getenv("REDIS_URL", "redis://192.168.45.54:6379"), decode_responses=False)
hb_task = asyncio.create_task(heartbeat_loop(hb_redis, "video-render", "render", worker.stats))
logger.info("video-render lifespan 시작") logger.info("video-render lifespan 시작")
try: try:
yield yield
finally: finally:
worker_task.cancel() for t in (worker_task, hb_task):
try: t.cancel()
await worker_task try:
except asyncio.CancelledError: await t
pass except asyncio.CancelledError:
pass
await hb_redis.aclose()
logger.info("video-render lifespan 종료") logger.info("video-render lifespan 종료")

View File

@@ -94,3 +94,25 @@ async def test_poll_once_returns_false_on_timeout(monkeypatch):
assert handled is False assert handled is False
fake_queue.ack.assert_not_awaited() fake_queue.ack.assert_not_awaited()
fake_queue.fail.assert_not_awaited() fake_queue.fail.assert_not_awaited()
# ----- heartbeat stats 카운터 -----
class _OneJobQueue:
def __init__(self): self.acked = False
async def dequeue(self, timeout=5):
if self.acked: return None
return ({"job_type": "sora_generation", "task_id": "t1", "params": {}}, b"raw")
async def ack(self, raw): self.acked = True
async def fail(self, raw, payload): pass
@pytest.mark.asyncio
async def test_poll_once_increments_jobs_done(monkeypatch):
worker.stats.jobs_done = 0
monkeypatch.setattr(worker, "run_sora_generation", lambda task_id, params: None)
handled = await worker.poll_once(_OneJobQueue())
assert handled is True
assert worker.stats.jobs_done == 1
assert worker.stats.busy is False
assert worker.stats.last_job_at is not None

View File

@@ -19,6 +19,7 @@ from providers.veo import run_veo_generation
from providers.kling import run_kling_generation from providers.kling import run_kling_generation
from providers.seedance import run_seedance_generation from providers.seedance import run_seedance_generation
from _shared.reliable_queue import ReliableQueue from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,6 +27,8 @@ REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
QUEUE_KEY = "queue:video-render" QUEUE_KEY = "queue:video-render"
PAUSED_KEY = "queue:paused" PAUSED_KEY = "queue:paused"
stats = WorkerStats()
# string names so `unittest.mock.patch` on `worker.<name>` is correctly intercepted # string names so `unittest.mock.patch` on `worker.<name>` is correctly intercepted
_DISPATCH_TABLE = { _DISPATCH_TABLE = {
"sora_generation": "run_sora_generation", "sora_generation": "run_sora_generation",
@@ -60,14 +63,21 @@ async def poll_once(queue: ReliableQueue) -> bool:
if result is None: if result is None:
return False return False
payload, raw = result payload, raw = result
stats.busy = True
try: try:
await asyncio.to_thread(_dispatch, payload) await asyncio.to_thread(_dispatch, payload)
except Exception: except Exception:
logger.exception("dispatch unhandled exception task_id=%s", logger.exception("dispatch unhandled exception task_id=%s",
payload.get("task_id")) payload.get("task_id"))
await queue.fail(raw, payload) await queue.fail(raw, payload)
stats.jobs_failed += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True
await queue.ack(raw) await queue.ack(raw)
stats.jobs_done += 1
stats.last_job_at = utc_now_iso()
stats.busy = False
return True return True