Files
ai-trade/services/image-render/worker.py
gahusb 2ff31b2e76 feat(render-workers): 4 render 워커 heartbeat 배선 + poll_once 카운터
- services/_shared/heartbeat.py (A1) WorkerStats/utc_now_iso/heartbeat_loop 소비
- image-render / video-render / music-render / insta-render 각 worker.py:
  stats = WorkerStats() 모듈 레벨 추가, poll_once에서 dispatch 전 busy=True,
  ack 후 jobs_done+1 / fail 후 jobs_failed+1 + last_job_at + busy=False
- 각 main.py: lifespan에 aioredis(decode_responses=False) + heartbeat_loop 태스크 spawn,
  종료 시 cancel + aclose
- 각 tests/test_worker.py: test_poll_once_increments_jobs_done 추가
  (image:flux / video:sora / music:suno / insta:_process_one mock)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019LV86jBozkNhSFXJA412fq
2026-07-01 00:52:57 +09:00

113 lines
3.8 KiB
Python

"""Redis ReliableQueue worker — F6 신뢰성 패턴 (BLMOVE + ack/fail + recovery).
queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
string-based dispatch + getattr (테스트 patch 호환).
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import sys
import redis.asyncio as aioredis
from nas_client import webhook_update_task
from providers.gpt_image import run_gpt_image_generation
from providers.nano_banana import run_nano_banana_generation
from providers.flux import run_flux_generation
from _shared.reliable_queue import ReliableQueue
from _shared.heartbeat import WorkerStats, utc_now_iso
# Module-scoped logger used by every function in this worker.
logger = logging.getLogger(__name__)

# Redis connection URL; the REDIS_URL environment variable overrides the default.
REDIS_URL = os.environ.get("REDIS_URL", "redis://192.168.45.54:6379")

# Queue key handed to ReliableQueue for this worker's jobs.
QUEUE_KEY = "queue:image-render"

# Key polled on every loop iteration; while its value is b"1" the worker idles.
PAUSED_KEY = "queue:paused"
# Process-wide job counters and busy flag; poll_once updates them around
# every job it handles.
stats = WorkerStats()

# Maps payload["job_type"] to the *name* of the provider function, not the
# function object.  _dispatch resolves the name on this module at call time
# via getattr(sys.modules[__name__], ...), so `unittest.mock.patch` /
# `monkeypatch.setattr` on `worker.<name>` is correctly intercepted.
_DISPATCH_TABLE: dict[str, str] = {
    "gpt_image_generation": "run_gpt_image_generation",
    "nano_banana_generation": "run_nano_banana_generation",
    "flux_generation": "run_flux_generation",
}
def _dispatch(payload: dict) -> None:
    """Route a job payload to its provider function and run it synchronously.

    The provider is resolved by name on this module at call time (instead of
    being held as a function reference) so that tests patching
    ``worker.<name>`` take effect.  This function blocks; ``poll_once`` runs
    it through ``asyncio.to_thread``.

    Args:
        payload: Job dict.  Reads ``job_type``, ``task_id`` and ``params``,
            each falling back to an empty default when missing.

    An unknown ``job_type``, or a dispatch-table entry whose name does not
    resolve on this module, is logged and reported via
    ``webhook_update_task`` with status ``"failed"``; the function then
    returns normally.  Exceptions raised by the provider itself propagate.
    """
    task_id = payload.get("task_id", "")
    job_type = payload.get("job_type", "")
    params = payload.get("params", {})

    provider_name = _DISPATCH_TABLE.get(job_type)
    if provider_name is None:
        logger.error("unknown job_type=%s task=%s", job_type, task_id)
        webhook_update_task(task_id, "failed", 0, "", error=f"unknown job_type: {job_type}")
        return

    # A private sentinel distinguishes "attribute missing" from an attribute
    # that exists but happens to be None.
    not_found = object()
    provider = getattr(sys.modules[__name__], provider_name, not_found)
    if provider is not_found:
        logger.error("dispatch table typo for job_type=%s name=%s task=%s", job_type, provider_name, task_id)
        webhook_update_task(task_id, "failed", 0, "", error=f"internal dispatch error: {provider_name}")
        return

    provider(task_id, params)
async def poll_once(queue: ReliableQueue) -> bool:
    """Run one dequeue -> dispatch -> ack/fail cycle.

    Args:
        queue: Queue to pull from.  Uses ``dequeue(timeout=5)``, ``ack(raw)``
            and ``fail(raw, payload)``.

    Returns:
        ``False`` when ``dequeue`` returned ``None`` (nothing to do),
        ``True`` when a job was handled, whether it succeeded or failed.

    Raises:
        Whatever ``queue.dequeue`` / ``queue.ack`` / ``queue.fail`` raise, and
        ``asyncio.CancelledError`` if the task is cancelled mid-job.  In all
        of these cases ``stats.busy`` is still reset to ``False``.
    """
    result = await queue.dequeue(timeout=5)
    if result is None:
        return False

    payload, raw = result
    stats.busy = True
    try:
        try:
            # _dispatch is blocking, so run it off the event loop.
            await asyncio.to_thread(_dispatch, payload)
        except Exception:
            logger.exception("dispatch unhandled exception task_id=%s",
                             payload.get("task_id"))
            await queue.fail(raw, payload)
            stats.jobs_failed += 1
        else:
            # Kept in ``else`` so an ack failure is not mistaken for a
            # dispatch failure and routed to queue.fail().
            await queue.ack(raw)
            stats.jobs_done += 1
        stats.last_job_at = utc_now_iso()
    finally:
        # Always clear the flag: without this, an exception from ack()/fail()
        # or a cancellation would leave the worker reporting busy while idle.
        stats.busy = False
    return True
async def worker_loop() -> None:
    """Run the image-render worker until the task is cancelled.

    Connects to Redis at ``REDIS_URL``, asks the queue to recover orphaned
    items once at startup (a failure there is logged and does not stop the
    worker), then loops forever:

    * while ``PAUSED_KEY`` holds ``b"1"``, sleep 10 seconds and re-check;
    * otherwise run one ``poll_once`` cycle.

    Any non-cancellation exception inside an iteration is logged and followed
    by a 5 second back-off.  The Redis client is closed when the loop exits.

    Raises:
        asyncio.CancelledError: re-raised after logging when the task is
            cancelled.
    """
    # decode_responses=False: values come back as bytes, hence the b"1"
    # comparison below.
    redis = aioredis.from_url(REDIS_URL, decode_responses=False)
    try:
        queue = ReliableQueue(redis, queue_key=QUEUE_KEY)
        logger.info("image-render worker started worker_id=%s queue=%s",
                    queue.worker_id, QUEUE_KEY)

        try:
            recovered = await queue.recover()
            if recovered:
                logger.info("recovered %d orphaned items at startup", recovered)
        except Exception:
            logger.exception("startup recover failed")

        while True:
            try:
                paused = await redis.get(PAUSED_KEY)
                if paused == b"1":
                    await asyncio.sleep(10)
                    continue
                await poll_once(queue)
            except asyncio.CancelledError:
                logger.info("worker_loop cancelled")
                raise
            except Exception:
                logger.exception("worker_loop iteration 실패, 5초 후 재시도")
                await asyncio.sleep(5)
    finally:
        # Release the connection pool on shutdown.  Guarded so a close error
        # cannot replace the CancelledError (or other exception) in flight.
        try:
            await redis.aclose()
        except Exception:
            logger.exception("redis close failed")
if __name__ == "__main__":
    # Script entry point: configure root logging at INFO, then run the worker
    # loop on a fresh event loop until it exits.
    logging.basicConfig(level=logging.INFO)
    entrypoint = worker_loop()
    asyncio.run(entrypoint)