diff --git a/services/music-render/tests/test_worker.py b/services/music-render/tests/test_worker.py new file mode 100644 index 0000000..d1fd189 --- /dev/null +++ b/services/music-render/tests/test_worker.py @@ -0,0 +1,109 @@ +"""worker.py — job_type 디스패처 + paused 체크.""" +import json +import pytest +from unittest.mock import MagicMock, patch + +import worker + + +def test_dispatch_suno_generation_calls_run_suno_generation(): + payload = { + "task_id": "t1", + "job_type": "suno_generation", + "params": {"genre": "lofi", "title": "x"}, + } + with patch("worker.run_suno_generation") as m: + worker._dispatch(payload) + m.assert_called_once_with("t1", {"genre": "lofi", "title": "x"}) + + +def test_dispatch_local_generation_calls_run_local_generation(): + payload = { + "task_id": "t2", + "job_type": "local_generation", + "params": {"genre": "ambient"}, + } + with patch("worker.run_local_generation") as m: + worker._dispatch(payload) + m.assert_called_once_with("t2", {"genre": "ambient"}) + + +def test_dispatch_unknown_job_type_logs_error(): + payload = {"task_id": "t3", "job_type": "weird_type", "params": {}} + with patch("worker.webhook_update_task") as m: + worker._dispatch(payload) + # 알 수 없는 job_type은 failed로 보고 + m.assert_called_once() + args = m.call_args[0] + assert args[0] == "t3" + assert args[1] == "failed" + + +def test_dispatch_suno_extend_calls_run_suno_extend(): + payload = {"task_id": "t4", "job_type": "suno_extend", "params": {"suno_id": "abc"}} + with patch("worker.run_suno_extend") as m: + worker._dispatch(payload) + m.assert_called_once_with("t4", {"suno_id": "abc"}) + + +def test_dispatch_vocal_removal_calls_run_vocal_removal(): + payload = {"task_id": "t5", "job_type": "vocal_removal", "params": {"suno_id": "abc"}} + with patch("worker.run_vocal_removal") as m: + worker._dispatch(payload) + m.assert_called_once_with("t5", {"suno_id": "abc"}) + + +def test_dispatch_cover_image_calls_run_cover_image(): + payload = {"task_id": "t6", "job_type": "cover_image", "params": {"suno_task_id": "x"}} + with patch("worker.run_cover_image") as m: + worker._dispatch(payload) + m.assert_called_once_with("t6", {"suno_task_id": "x"}) + + +def test_dispatch_wav_convert_calls_run_wav_convert(): + payload = {"task_id": "t7", "job_type": "wav_convert", "params": {"suno_task_id": "x", "suno_id": "y"}} + with patch("worker.run_wav_convert") as m: + worker._dispatch(payload) + m.assert_called_once_with("t7", {"suno_task_id": "x", "suno_id": "y"}) + + +def test_dispatch_stem_split_calls_run_stem_split(): + payload = {"task_id": "t8", "job_type": "stem_split", "params": {"suno_task_id": "x", "suno_id": "y"}} + with patch("worker.run_stem_split") as m: + worker._dispatch(payload) + m.assert_called_once_with("t8", {"suno_task_id": "x", "suno_id": "y"}) + + +def test_dispatch_video_generate_calls_run_video_generate(): + payload = {"task_id": "t9", "job_type": "video_generate", "params": {"suno_task_id": "x", "suno_id": "y"}} + with patch("worker.run_video_generate") as m: + worker._dispatch(payload) + m.assert_called_once_with("t9", {"suno_task_id": "x", "suno_id": "y"}) + + +def test_dispatch_upload_cover_calls_run_upload_cover(): + payload = {"task_id": "t10", "job_type": "upload_cover", "params": {"upload_url": "u"}} + with patch("worker.run_upload_cover") as m: + worker._dispatch(payload) + m.assert_called_once_with("t10", {"upload_url": "u"}) + + +def test_dispatch_upload_extend_calls_run_upload_extend(): + payload = {"task_id": "t11", "job_type": "upload_extend", "params": {"upload_url": "u"}} + with patch("worker.run_upload_extend") as m: + worker._dispatch(payload) + m.assert_called_once_with("t11", {"upload_url": "u"}) + + +def test_dispatch_add_vocals_calls_run_add_vocals(): + payload = {"task_id": "t12", "job_type": "add_vocals", "params": {"upload_url": "u"}} + with patch("worker.run_add_vocals") as m: + worker._dispatch(payload) + m.assert_called_once_with("t12", {"upload_url": "u"}) + + +def test_dispatch_add_instrumental_calls_run_add_instrumental(): + payload = {"task_id": "t13", "job_type": "add_instrumental", "params": {"upload_url": "u"}} + with patch("worker.run_add_instrumental") as m: + worker._dispatch(payload) + m.assert_called_once_with("t13", {"upload_url": "u"}) diff --git a/services/music-render/worker.py b/services/music-render/worker.py new file mode 100644 index 0000000..b84034c --- /dev/null +++ b/services/music-render/worker.py @@ -0,0 +1,90 @@ +"""Redis BLPOP worker — queue:music-render → job_type 디스패치 → NAS webhook. + +queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set). +""" +from __future__ import annotations + +import asyncio +import json +import logging +import os +from typing import Any + +import redis.asyncio as aioredis + +from nas_client import webhook_update_task +from providers.suno import ( + run_suno_generation, run_suno_extend, run_vocal_removal, + run_cover_image, run_wav_convert, run_stem_split, + run_upload_cover, run_upload_extend, run_add_vocals, + run_add_instrumental, run_video_generate, +) +from providers.local import run_local_generation + +logger = logging.getLogger(__name__) + +REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379") +QUEUE_KEY = "queue:music-render" +PAUSED_KEY = "queue:paused" + +# Maps job_type → module-level function name (string). +# _dispatch resolves the name via globals() at call time so unittest.mock.patch +# on "worker." is correctly intercepted. +_DISPATCH_TABLE: dict[str, str] = { + "suno_generation": "run_suno_generation", + "local_generation": "run_local_generation", + "suno_extend": "run_suno_extend", + "vocal_removal": "run_vocal_removal", + "cover_image": "run_cover_image", + "wav_convert": "run_wav_convert", + "stem_split": "run_stem_split", + "upload_cover": "run_upload_cover", + "upload_extend": "run_upload_extend", + "add_vocals": "run_add_vocals", + "add_instrumental": "run_add_instrumental", + "video_generate": "run_video_generate", +} + + +def _dispatch(payload: dict) -> None: + """payload[job_type] → provider 함수 호출 (sync, asyncio.to_thread로 래핑).""" + import sys + _self = sys.modules[__name__] + job_type = payload.get("job_type", "") + task_id = payload.get("task_id", "") + params = payload.get("params", {}) + fn_name = _DISPATCH_TABLE.get(job_type) + if fn_name is None: + logger.error("unknown job_type=%s task=%s", job_type, task_id) + webhook_update_task(task_id, "failed", 0, "", error=f"unknown job_type: {job_type}") + return + fn = getattr(_self, fn_name) + fn(task_id, params) + + +async def worker_loop(): + redis = aioredis.from_url(REDIS_URL, decode_responses=False) + logger.info("music-render worker started (queue=%s)", QUEUE_KEY) + while True: + try: + paused = await redis.get(PAUSED_KEY) + if paused == b"1": + await asyncio.sleep(10) + continue + item = await redis.blpop(QUEUE_KEY, timeout=1) + if item is None: + continue + _, raw = item + try: + payload = json.loads(raw) + except json.JSONDecodeError: + logger.error("invalid queue payload: %r", raw[:200]) + continue + # sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지 + await asyncio.to_thread(_dispatch, payload) + except asyncio.CancelledError: + logger.info("worker_loop cancelled") + raise + except Exception: + logger.exception("worker_loop iteration 실패, 5초 후 재시도") + await asyncio.sleep(5)