diff --git a/music-lab/app/batch_generator.py b/music-lab/app/batch_generator.py index cea332c..6e630fe 100644 --- a/music-lab/app/batch_generator.py +++ b/music-lab/app/batch_generator.py @@ -102,8 +102,10 @@ async def run_batch(batch_id: int) -> None: async def _generate_one_track(*, title: str, genre: str, duration_sec: int, params: dict) -> int | None: - """기존 Suno generate 호출 + 완료까지 polling. 성공 시 새 track id, 실패 시 None.""" - from .suno_provider import run_suno_generation + """Redis 큐에 push + task 상태 polling. 성공 시 새 track id, 실패 시 None.""" + import json + from datetime import datetime, timezone, timedelta + from .main import redis_client # 같은 컨테이너 — 동일 redis 클라이언트 공유 task_id = str(uuid.uuid4()) suno_params = { @@ -116,11 +118,23 @@ async def _generate_one_track(*, title: str, genre: str, duration_sec: int, "key": params["key"], "scale": params["scale"], "prompt": params.get("prompt_modifier", ""), + "provider": "suno", + "model": "V4", + "instrumental": False, + "lyrics": "", } db.create_task(task_id, suno_params, provider="suno") - # Suno background task — 우리가 await로 기다림 (BackgroundTasks 미사용) - asyncio.create_task(asyncio.to_thread(run_suno_generation, task_id, suno_params)) + # Redis push (Windows music-render가 BLPOP 처리) + kst = timezone(timedelta(hours=9)) + payload = { + "task_id": task_id, + "kind": "music", + "job_type": "suno_generation", + "params": suno_params, + "submitted_at": datetime.now(kst).isoformat(), + } + await redis_client.rpush("queue:music-render", json.dumps(payload)) waited = 0 while waited < TRACK_GEN_TIMEOUT_S: @@ -131,14 +145,7 @@ async def _generate_one_track(*, title: str, genre: str, duration_sec: int, continue status = task.get("status") if status == "succeeded": - # task["track"] 또는 task["result"]["track"] 형태 시도, 없으면 task_id로 조회 - tr = task.get("track") - if tr and isinstance(tr, dict): - return tr.get("id") - result = task.get("result", {}) or {} - if isinstance(result, dict) and isinstance(result.get("track"), dict): - return result["track"].get("id") - # Fallback: music_library에서 task_id로 검색 + # Windows webhook이 add_track 했으므로 task_id로 검색 track = db.get_track_by_task_id(task_id) if track: return track.get("id")