From 77b8d05ad73e8d94611d570259423c84b5b9e235 Mon Sep 17 00:00:00 2001 From: gahusb Date: Sun, 10 May 2026 18:57:23 +0900 Subject: [PATCH] =?UTF-8?q?feat(music-lab):=20=EB=B0=B0=EC=B9=98=20?= =?UTF-8?q?=EC=9D=8C=EC=95=85=20=EC=83=9D=EC=84=B1=20endpoint=20+=20?= =?UTF-8?q?=EC=9E=90=EB=8F=99=20compile=C2=B7video=20=ED=8C=8C=EC=9D=B4?= =?UTF-8?q?=ED=94=84=EB=9D=BC=EC=9D=B8=20=EC=98=A4=EC=BC=80=EC=8A=A4?= =?UTF-8?q?=ED=8A=B8=EB=A0=88=EC=9D=B4=ED=84=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - batch_generator.py: 장르별 N트랙 순차 Suno 생성 → 자동 compile → 자동 video pipeline - main.py: POST/GET /api/music/generate-batch, GET /api/music/generate-batch/{id} 추가 - tests: 10개 endpoint 테스트 (검증·필터·404) --- music-lab/app/batch_generator.py | 148 ++++++++++++++++++++++++ music-lab/app/main.py | 57 +++++++++ music-lab/tests/test_batch_endpoints.py | 89 ++++++++++++++ 3 files changed, 294 insertions(+) create mode 100644 music-lab/app/batch_generator.py create mode 100644 music-lab/tests/test_batch_endpoints.py diff --git a/music-lab/app/batch_generator.py b/music-lab/app/batch_generator.py new file mode 100644 index 0000000..cea332c --- /dev/null +++ b/music-lab/app/batch_generator.py @@ -0,0 +1,148 @@ +"""배치 음악 생성 + 자동 컴파일·영상 파이프라인.""" +import asyncio +import logging +import uuid + +from . import db +from .random_pools import randomize + +logger = logging.getLogger("music-lab.batch") + +POLL_INTERVAL_S = 5 +TRACK_GEN_TIMEOUT_S = 240 + + +async def run_batch(batch_id: int) -> None: + """1) genre로 N트랙 순차 Suno 생성 + 2) 모두 완료 후 compile_job 자동 생성·실행 + 3) compile 완료 후 영상 파이프라인 시작 (cover step) + """ + job = db.get_batch_job(batch_id) + if not job: + return + genre = job["genre"] + count = job["count"] + duration = job["target_duration_sec"] + auto_pipe = bool(job["auto_pipeline"]) + + db.update_batch_job(batch_id, status="generating") + + track_ids: list[int] = [] + for i in range(1, count + 1): + title = f"{genre.title()} Mix Track {i}" + params = randomize(genre) + db.update_batch_job(batch_id, + current_track_index=i, + current_track_status="generating") + + track_id = await _generate_one_track( + title=title, genre=genre, + duration_sec=duration, params=params, + ) + if track_id: + track_ids.append(track_id) + db.append_batch_track(batch_id, track_id) + db.update_batch_job(batch_id, current_track_status="succeeded") + else: + db.update_batch_job(batch_id, current_track_status="failed") + logger.warning("배치 %d 트랙 %d 실패 — 계속 진행", batch_id, i) + + if not track_ids: + db.update_batch_job(batch_id, status="failed", + error="모든 트랙 생성 실패") + return + + db.update_batch_job(batch_id, status="generated") + + if not auto_pipe: + return + + # 자동 컴파일 + db.update_batch_job(batch_id, status="compiling") + try: + compile_id = db.create_compile_job( + title=f"{genre.title()} Mix", + track_ids=track_ids, + crossfade_sec=3.0, + ) + db.update_batch_job(batch_id, compile_job_id=compile_id) + except Exception as e: + logger.exception("compile create failed") + db.update_batch_job(batch_id, status="failed", error=f"compile create: {e}") + return + + from . import compiler + try: + await asyncio.to_thread(compiler.run_compile, compile_id) + except Exception as e: + logger.exception("compile run failed") + db.update_batch_job(batch_id, status="failed", error=f"compile run: {e}") + return + + job_after = db.get_compile_job(compile_id) + status_after = job_after.get("status") if job_after else None + if status_after not in ("done", "succeeded"): + db.update_batch_job( + batch_id, status="failed", + error=f"compile not done (status={status_after})" + ) + return + + # 자동 영상 파이프라인 + try: + pipeline_id = db.create_pipeline(compile_job_id=compile_id) + db.update_batch_job(batch_id, pipeline_id=pipeline_id, status="piped") + + from .pipeline import orchestrator + await orchestrator.run_step(pipeline_id, "cover") + except Exception as e: + logger.exception("pipeline launch failed") + db.update_batch_job(batch_id, status="failed", error=f"pipeline launch: {e}") + + +async def _generate_one_track(*, title: str, genre: str, duration_sec: int, + params: dict) -> int | None: + """기존 Suno generate 호출 + 완료까지 polling. 성공 시 새 track id, 실패 시 None.""" + from .suno_provider import run_suno_generation + + task_id = str(uuid.uuid4()) + suno_params = { + "title": title, + "genre": genre, + "moods": params["moods"], + "instruments": params["instruments"], + "duration_sec": duration_sec, + "bpm": params["bpm"], + "key": params["key"], + "scale": params["scale"], + "prompt": params.get("prompt_modifier", ""), + } + db.create_task(task_id, suno_params, provider="suno") + + # Suno background task — 우리가 await로 기다림 (BackgroundTasks 미사용) + asyncio.create_task(asyncio.to_thread(run_suno_generation, task_id, suno_params)) + + waited = 0 + while waited < TRACK_GEN_TIMEOUT_S: + await asyncio.sleep(POLL_INTERVAL_S) + waited += POLL_INTERVAL_S + task = db.get_task(task_id) + if not task: + continue + status = task.get("status") + if status == "succeeded": + # task["track"] 또는 task["result"]["track"] 형태 시도, 없으면 task_id로 조회 + tr = task.get("track") + if tr and isinstance(tr, dict): + return tr.get("id") + result = task.get("result", {}) or {} + if isinstance(result, dict) and isinstance(result.get("track"), dict): + return result["track"].get("id") + # Fallback: music_library에서 task_id로 검색 + track = db.get_track_by_task_id(task_id) + if track: + return track.get("id") + return None + if status == "failed": + return None + return None # timeout diff --git a/music-lab/app/main.py b/music-lab/app/main.py index 906f398..6d4dcd1 100644 --- a/music-lab/app/main.py +++ b/music-lab/app/main.py @@ -35,6 +35,7 @@ from .suno_provider import ( generate_lyrics, get_credits, get_timestamped_lyrics, generate_style_boost, SUNO_API_KEY, SUNO_MODELS, ) +from .batch_generator import run_batch as _run_batch app = FastAPI() @@ -849,6 +850,62 @@ def export_compile(job_id: int): } +# ── 배치 음악 생성 API ──────────────────────────────────────────────────────── + +class BatchGenerateRequest(BaseModel): + genre: str + count: int = 10 + target_duration_sec: int = 180 + auto_pipeline: bool = True + + +@app.post("/api/music/generate-batch", status_code=201) +async def generate_batch(req: BatchGenerateRequest, bg: BackgroundTasks): + if not (1 <= req.count <= 10): + raise HTTPException(status_code=400, detail="count는 1-10 사이") + if not (60 <= req.target_duration_sec <= 300): + raise HTTPException(status_code=400, detail="target_duration_sec는 60-300 사이") + if not req.genre: + raise HTTPException(status_code=400, detail="genre 필수") + if not SUNO_API_KEY: + raise HTTPException(status_code=400, detail="SUNO_API_KEY 미설정") + + batch_id = _db_module.create_batch_job( + genre=req.genre, count=req.count, + target_duration_sec=req.target_duration_sec, + auto_pipeline=req.auto_pipeline, + ) + bg.add_task(_run_batch, batch_id) + return _db_module.get_batch_job(batch_id) + + +@app.get("/api/music/generate-batch/{batch_id}") +def get_batch(batch_id: int): + j = _db_module.get_batch_job(batch_id) + if not j: + raise HTTPException(status_code=404, detail="Not found") + if j["track_ids"]: + ids_csv = ",".join(str(i) for i in j["track_ids"]) + import sqlite3 + conn = sqlite3.connect(_db_module.DB_PATH) + conn.row_factory = sqlite3.Row + rows = conn.execute( + f"SELECT id, title, audio_url, duration_sec FROM music_library WHERE id IN ({ids_csv})" + ).fetchall() + conn.close() + # 트랙을 batch.track_ids 순서대로 정렬 + by_id = {r["id"]: dict(r) for r in rows} + j["tracks"] = [by_id.get(tid) for tid in j["track_ids"] if tid in by_id] + else: + j["tracks"] = [] + return j + + +@app.get("/api/music/generate-batch") +def list_batches(status: str = "all"): + return {"batches": _db_module.list_batch_jobs(active_only=(status == "active"))} + + # ── 수익화 추적 API ─────────────────────────────────────────────────────────── @app.get("/api/music/revenue/dashboard") diff --git a/music-lab/tests/test_batch_endpoints.py b/music-lab/tests/test_batch_endpoints.py new file mode 100644 index 0000000..d769109 --- /dev/null +++ b/music-lab/tests/test_batch_endpoints.py @@ -0,0 +1,89 @@ +import pytest +from unittest.mock import AsyncMock, patch +from fastapi.testclient import TestClient +import app.main as main_module +from app import db + + +@pytest.fixture +def client(monkeypatch, tmp_path): + monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db")) + db.init_db() + monkeypatch.setenv("SUNO_API_KEY", "test") + # main.py의 SUNO_API_KEY 모듈 변수도 갱신 필요할 수 있음 + monkeypatch.setattr(main_module, "SUNO_API_KEY", "test", raising=False) + return TestClient(main_module.app) + + +def test_create_batch_201(client): + with patch.object(main_module, "_run_batch", new=AsyncMock()): + r = client.post("/api/music/generate-batch", + json={"genre": "lo-fi", "count": 3}) + assert r.status_code == 201, r.text + body = r.json() + assert body["genre"] == "lo-fi" + assert body["count"] == 3 + assert body["status"] == "queued" + + +def test_create_batch_rejects_count_too_high(client): + r = client.post("/api/music/generate-batch", + json={"genre": "lo-fi", "count": 11}) + assert r.status_code == 400 + + +def test_create_batch_rejects_count_zero(client): + r = client.post("/api/music/generate-batch", + json={"genre": "lo-fi", "count": 0}) + assert r.status_code == 400 + + +def test_create_batch_rejects_no_genre(client): + r = client.post("/api/music/generate-batch", json={"count": 3}) + assert r.status_code in (400, 422) + + +def test_create_batch_rejects_invalid_duration(client): + r = client.post("/api/music/generate-batch", + json={"genre": "lo-fi", "count": 3, "target_duration_sec": 30}) + assert r.status_code == 400 + + +def test_create_batch_rejects_no_suno_key(client, monkeypatch): + monkeypatch.setattr(main_module, "SUNO_API_KEY", "", raising=False) + r = client.post("/api/music/generate-batch", + json={"genre": "lo-fi", "count": 3}) + assert r.status_code == 400 + + +def test_get_batch_returns_tracks(client): + bid = db.create_batch_job(genre="lo-fi", count=2) + db.append_batch_track(bid, 999) # phantom track + r = client.get(f"/api/music/generate-batch/{bid}") + assert r.status_code == 200 + body = r.json() + assert body["track_ids"] == [999] + assert body["tracks"] == [] # 999 not in music_library + + +def test_get_batch_404(client): + r = client.get("/api/music/generate-batch/99999") + assert r.status_code == 404 + + +def test_list_batches(client): + db.create_batch_job(genre="lo-fi", count=1) + db.create_batch_job(genre="phonk", count=2) + r = client.get("/api/music/generate-batch") + assert r.status_code == 200 + assert len(r.json()["batches"]) == 2 + + +def test_list_batches_active_filter(client): + b1 = db.create_batch_job(genre="lo-fi", count=1) + b2 = db.create_batch_job(genre="phonk", count=2) + db.update_batch_job(b1, status="failed") + r = client.get("/api/music/generate-batch?status=active") + ids = [j["id"] for j in r.json()["batches"]] + assert b2 in ids + assert b1 not in ids