feat(music-lab): 배치 음악 생성 endpoint + 자동 compile·video 파이프라인 오케스트레이터

- batch_generator.py: 장르별 N트랙 순차 Suno 생성 → 자동 compile → 자동 video pipeline
- main.py: POST/GET /api/music/generate-batch, GET /api/music/generate-batch/{id} 추가
- tests: 10개 endpoint 테스트 (검증·필터·404)
This commit is contained in:
2026-05-10 18:57:23 +09:00
parent f0cb06268e
commit 77b8d05ad7
3 changed files with 294 additions and 0 deletions

View File

@@ -0,0 +1,148 @@
"""배치 음악 생성 + 자동 컴파일·영상 파이프라인."""
import asyncio
import logging
import uuid
from . import db
from .random_pools import randomize
logger = logging.getLogger("music-lab.batch")
POLL_INTERVAL_S = 5
TRACK_GEN_TIMEOUT_S = 240
async def run_batch(batch_id: int) -> None:
"""1) genre로 N트랙 순차 Suno 생성
2) 모두 완료 후 compile_job 자동 생성·실행
3) compile 완료 후 영상 파이프라인 시작 (cover step)
"""
job = db.get_batch_job(batch_id)
if not job:
return
genre = job["genre"]
count = job["count"]
duration = job["target_duration_sec"]
auto_pipe = bool(job["auto_pipeline"])
db.update_batch_job(batch_id, status="generating")
track_ids: list[int] = []
for i in range(1, count + 1):
title = f"{genre.title()} Mix Track {i}"
params = randomize(genre)
db.update_batch_job(batch_id,
current_track_index=i,
current_track_status="generating")
track_id = await _generate_one_track(
title=title, genre=genre,
duration_sec=duration, params=params,
)
if track_id:
track_ids.append(track_id)
db.append_batch_track(batch_id, track_id)
db.update_batch_job(batch_id, current_track_status="succeeded")
else:
db.update_batch_job(batch_id, current_track_status="failed")
logger.warning("배치 %d 트랙 %d 실패 — 계속 진행", batch_id, i)
if not track_ids:
db.update_batch_job(batch_id, status="failed",
error="모든 트랙 생성 실패")
return
db.update_batch_job(batch_id, status="generated")
if not auto_pipe:
return
# 자동 컴파일
db.update_batch_job(batch_id, status="compiling")
try:
compile_id = db.create_compile_job(
title=f"{genre.title()} Mix",
track_ids=track_ids,
crossfade_sec=3.0,
)
db.update_batch_job(batch_id, compile_job_id=compile_id)
except Exception as e:
logger.exception("compile create failed")
db.update_batch_job(batch_id, status="failed", error=f"compile create: {e}")
return
from . import compiler
try:
await asyncio.to_thread(compiler.run_compile, compile_id)
except Exception as e:
logger.exception("compile run failed")
db.update_batch_job(batch_id, status="failed", error=f"compile run: {e}")
return
job_after = db.get_compile_job(compile_id)
status_after = job_after.get("status") if job_after else None
if status_after not in ("done", "succeeded"):
db.update_batch_job(
batch_id, status="failed",
error=f"compile not done (status={status_after})"
)
return
# 자동 영상 파이프라인
try:
pipeline_id = db.create_pipeline(compile_job_id=compile_id)
db.update_batch_job(batch_id, pipeline_id=pipeline_id, status="piped")
from .pipeline import orchestrator
await orchestrator.run_step(pipeline_id, "cover")
except Exception as e:
logger.exception("pipeline launch failed")
db.update_batch_job(batch_id, status="failed", error=f"pipeline launch: {e}")
async def _generate_one_track(*, title: str, genre: str, duration_sec: int,
params: dict) -> int | None:
"""기존 Suno generate 호출 + 완료까지 polling. 성공 시 새 track id, 실패 시 None."""
from .suno_provider import run_suno_generation
task_id = str(uuid.uuid4())
suno_params = {
"title": title,
"genre": genre,
"moods": params["moods"],
"instruments": params["instruments"],
"duration_sec": duration_sec,
"bpm": params["bpm"],
"key": params["key"],
"scale": params["scale"],
"prompt": params.get("prompt_modifier", ""),
}
db.create_task(task_id, suno_params, provider="suno")
# Suno background task — 우리가 await로 기다림 (BackgroundTasks 미사용)
asyncio.create_task(asyncio.to_thread(run_suno_generation, task_id, suno_params))
waited = 0
while waited < TRACK_GEN_TIMEOUT_S:
await asyncio.sleep(POLL_INTERVAL_S)
waited += POLL_INTERVAL_S
task = db.get_task(task_id)
if not task:
continue
status = task.get("status")
if status == "succeeded":
# task["track"] 또는 task["result"]["track"] 형태 시도, 없으면 task_id로 조회
tr = task.get("track")
if tr and isinstance(tr, dict):
return tr.get("id")
result = task.get("result", {}) or {}
if isinstance(result, dict) and isinstance(result.get("track"), dict):
return result["track"].get("id")
# Fallback: music_library에서 task_id로 검색
track = db.get_track_by_task_id(task_id)
if track:
return track.get("id")
return None
if status == "failed":
return None
return None # timeout

View File

@@ -35,6 +35,7 @@ from .suno_provider import (
generate_lyrics, get_credits, get_timestamped_lyrics, generate_style_boost, generate_lyrics, get_credits, get_timestamped_lyrics, generate_style_boost,
SUNO_API_KEY, SUNO_MODELS, SUNO_API_KEY, SUNO_MODELS,
) )
from .batch_generator import run_batch as _run_batch
app = FastAPI() app = FastAPI()
@@ -849,6 +850,62 @@ def export_compile(job_id: int):
} }
# ── 배치 음악 생성 API ────────────────────────────────────────────────────────
class BatchGenerateRequest(BaseModel):
genre: str
count: int = 10
target_duration_sec: int = 180
auto_pipeline: bool = True
@app.post("/api/music/generate-batch", status_code=201)
async def generate_batch(req: BatchGenerateRequest, bg: BackgroundTasks):
if not (1 <= req.count <= 10):
raise HTTPException(status_code=400, detail="count는 1-10 사이")
if not (60 <= req.target_duration_sec <= 300):
raise HTTPException(status_code=400, detail="target_duration_sec는 60-300 사이")
if not req.genre:
raise HTTPException(status_code=400, detail="genre 필수")
if not SUNO_API_KEY:
raise HTTPException(status_code=400, detail="SUNO_API_KEY 미설정")
batch_id = _db_module.create_batch_job(
genre=req.genre, count=req.count,
target_duration_sec=req.target_duration_sec,
auto_pipeline=req.auto_pipeline,
)
bg.add_task(_run_batch, batch_id)
return _db_module.get_batch_job(batch_id)
@app.get("/api/music/generate-batch/{batch_id}")
def get_batch(batch_id: int):
j = _db_module.get_batch_job(batch_id)
if not j:
raise HTTPException(status_code=404, detail="Not found")
if j["track_ids"]:
ids_csv = ",".join(str(i) for i in j["track_ids"])
import sqlite3
conn = sqlite3.connect(_db_module.DB_PATH)
conn.row_factory = sqlite3.Row
rows = conn.execute(
f"SELECT id, title, audio_url, duration_sec FROM music_library WHERE id IN ({ids_csv})"
).fetchall()
conn.close()
# 트랙을 batch.track_ids 순서대로 정렬
by_id = {r["id"]: dict(r) for r in rows}
j["tracks"] = [by_id.get(tid) for tid in j["track_ids"] if tid in by_id]
else:
j["tracks"] = []
return j
@app.get("/api/music/generate-batch")
def list_batches(status: str = "all"):
return {"batches": _db_module.list_batch_jobs(active_only=(status == "active"))}
# ── 수익화 추적 API ─────────────────────────────────────────────────────────── # ── 수익화 추적 API ───────────────────────────────────────────────────────────
@app.get("/api/music/revenue/dashboard") @app.get("/api/music/revenue/dashboard")

View File

@@ -0,0 +1,89 @@
import pytest
from unittest.mock import AsyncMock, patch
from fastapi.testclient import TestClient
import app.main as main_module
from app import db
@pytest.fixture
def client(monkeypatch, tmp_path):
monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
db.init_db()
monkeypatch.setenv("SUNO_API_KEY", "test")
# main.py의 SUNO_API_KEY 모듈 변수도 갱신 필요할 수 있음
monkeypatch.setattr(main_module, "SUNO_API_KEY", "test", raising=False)
return TestClient(main_module.app)
def test_create_batch_201(client):
with patch.object(main_module, "_run_batch", new=AsyncMock()):
r = client.post("/api/music/generate-batch",
json={"genre": "lo-fi", "count": 3})
assert r.status_code == 201, r.text
body = r.json()
assert body["genre"] == "lo-fi"
assert body["count"] == 3
assert body["status"] == "queued"
def test_create_batch_rejects_count_too_high(client):
r = client.post("/api/music/generate-batch",
json={"genre": "lo-fi", "count": 11})
assert r.status_code == 400
def test_create_batch_rejects_count_zero(client):
r = client.post("/api/music/generate-batch",
json={"genre": "lo-fi", "count": 0})
assert r.status_code == 400
def test_create_batch_rejects_no_genre(client):
r = client.post("/api/music/generate-batch", json={"count": 3})
assert r.status_code in (400, 422)
def test_create_batch_rejects_invalid_duration(client):
r = client.post("/api/music/generate-batch",
json={"genre": "lo-fi", "count": 3, "target_duration_sec": 30})
assert r.status_code == 400
def test_create_batch_rejects_no_suno_key(client, monkeypatch):
monkeypatch.setattr(main_module, "SUNO_API_KEY", "", raising=False)
r = client.post("/api/music/generate-batch",
json={"genre": "lo-fi", "count": 3})
assert r.status_code == 400
def test_get_batch_returns_tracks(client):
bid = db.create_batch_job(genre="lo-fi", count=2)
db.append_batch_track(bid, 999) # phantom track
r = client.get(f"/api/music/generate-batch/{bid}")
assert r.status_code == 200
body = r.json()
assert body["track_ids"] == [999]
assert body["tracks"] == [] # 999 not in music_library
def test_get_batch_404(client):
r = client.get("/api/music/generate-batch/99999")
assert r.status_code == 404
def test_list_batches(client):
db.create_batch_job(genre="lo-fi", count=1)
db.create_batch_job(genre="phonk", count=2)
r = client.get("/api/music/generate-batch")
assert r.status_code == 200
assert len(r.json()["batches"]) == 2
def test_list_batches_active_filter(client):
b1 = db.create_batch_job(genre="lo-fi", count=1)
b2 = db.create_batch_job(genre="phonk", count=2)
db.update_batch_job(b1, status="failed")
r = client.get("/api/music/generate-batch?status=active")
ids = [j["id"] for j in r.json()["batches"]]
assert b2 in ids
assert b1 not in ids