Compare commits
2 Commits
f074cbec2d
...
77b8d05ad7
| Author | SHA1 | Date | |
|---|---|---|---|
| 77b8d05ad7 | |||
| f0cb06268e |
148
music-lab/app/batch_generator.py
Normal file
148
music-lab/app/batch_generator.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""배치 음악 생성 + 자동 컴파일·영상 파이프라인."""
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from . import db
|
||||
from .random_pools import randomize
|
||||
|
||||
logger = logging.getLogger("music-lab.batch")
|
||||
|
||||
POLL_INTERVAL_S = 5
|
||||
TRACK_GEN_TIMEOUT_S = 240
|
||||
|
||||
|
||||
async def run_batch(batch_id: int) -> None:
|
||||
"""1) genre로 N트랙 순차 Suno 생성
|
||||
2) 모두 완료 후 compile_job 자동 생성·실행
|
||||
3) compile 완료 후 영상 파이프라인 시작 (cover step)
|
||||
"""
|
||||
job = db.get_batch_job(batch_id)
|
||||
if not job:
|
||||
return
|
||||
genre = job["genre"]
|
||||
count = job["count"]
|
||||
duration = job["target_duration_sec"]
|
||||
auto_pipe = bool(job["auto_pipeline"])
|
||||
|
||||
db.update_batch_job(batch_id, status="generating")
|
||||
|
||||
track_ids: list[int] = []
|
||||
for i in range(1, count + 1):
|
||||
title = f"{genre.title()} Mix Track {i}"
|
||||
params = randomize(genre)
|
||||
db.update_batch_job(batch_id,
|
||||
current_track_index=i,
|
||||
current_track_status="generating")
|
||||
|
||||
track_id = await _generate_one_track(
|
||||
title=title, genre=genre,
|
||||
duration_sec=duration, params=params,
|
||||
)
|
||||
if track_id:
|
||||
track_ids.append(track_id)
|
||||
db.append_batch_track(batch_id, track_id)
|
||||
db.update_batch_job(batch_id, current_track_status="succeeded")
|
||||
else:
|
||||
db.update_batch_job(batch_id, current_track_status="failed")
|
||||
logger.warning("배치 %d 트랙 %d 실패 — 계속 진행", batch_id, i)
|
||||
|
||||
if not track_ids:
|
||||
db.update_batch_job(batch_id, status="failed",
|
||||
error="모든 트랙 생성 실패")
|
||||
return
|
||||
|
||||
db.update_batch_job(batch_id, status="generated")
|
||||
|
||||
if not auto_pipe:
|
||||
return
|
||||
|
||||
# 자동 컴파일
|
||||
db.update_batch_job(batch_id, status="compiling")
|
||||
try:
|
||||
compile_id = db.create_compile_job(
|
||||
title=f"{genre.title()} Mix",
|
||||
track_ids=track_ids,
|
||||
crossfade_sec=3.0,
|
||||
)
|
||||
db.update_batch_job(batch_id, compile_job_id=compile_id)
|
||||
except Exception as e:
|
||||
logger.exception("compile create failed")
|
||||
db.update_batch_job(batch_id, status="failed", error=f"compile create: {e}")
|
||||
return
|
||||
|
||||
from . import compiler
|
||||
try:
|
||||
await asyncio.to_thread(compiler.run_compile, compile_id)
|
||||
except Exception as e:
|
||||
logger.exception("compile run failed")
|
||||
db.update_batch_job(batch_id, status="failed", error=f"compile run: {e}")
|
||||
return
|
||||
|
||||
job_after = db.get_compile_job(compile_id)
|
||||
status_after = job_after.get("status") if job_after else None
|
||||
if status_after not in ("done", "succeeded"):
|
||||
db.update_batch_job(
|
||||
batch_id, status="failed",
|
||||
error=f"compile not done (status={status_after})"
|
||||
)
|
||||
return
|
||||
|
||||
# 자동 영상 파이프라인
|
||||
try:
|
||||
pipeline_id = db.create_pipeline(compile_job_id=compile_id)
|
||||
db.update_batch_job(batch_id, pipeline_id=pipeline_id, status="piped")
|
||||
|
||||
from .pipeline import orchestrator
|
||||
await orchestrator.run_step(pipeline_id, "cover")
|
||||
except Exception as e:
|
||||
logger.exception("pipeline launch failed")
|
||||
db.update_batch_job(batch_id, status="failed", error=f"pipeline launch: {e}")
|
||||
|
||||
|
||||
async def _generate_one_track(*, title: str, genre: str, duration_sec: int,
|
||||
params: dict) -> int | None:
|
||||
"""기존 Suno generate 호출 + 완료까지 polling. 성공 시 새 track id, 실패 시 None."""
|
||||
from .suno_provider import run_suno_generation
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
suno_params = {
|
||||
"title": title,
|
||||
"genre": genre,
|
||||
"moods": params["moods"],
|
||||
"instruments": params["instruments"],
|
||||
"duration_sec": duration_sec,
|
||||
"bpm": params["bpm"],
|
||||
"key": params["key"],
|
||||
"scale": params["scale"],
|
||||
"prompt": params.get("prompt_modifier", ""),
|
||||
}
|
||||
db.create_task(task_id, suno_params, provider="suno")
|
||||
|
||||
# Suno background task — 우리가 await로 기다림 (BackgroundTasks 미사용)
|
||||
asyncio.create_task(asyncio.to_thread(run_suno_generation, task_id, suno_params))
|
||||
|
||||
waited = 0
|
||||
while waited < TRACK_GEN_TIMEOUT_S:
|
||||
await asyncio.sleep(POLL_INTERVAL_S)
|
||||
waited += POLL_INTERVAL_S
|
||||
task = db.get_task(task_id)
|
||||
if not task:
|
||||
continue
|
||||
status = task.get("status")
|
||||
if status == "succeeded":
|
||||
# task["track"] 또는 task["result"]["track"] 형태 시도, 없으면 task_id로 조회
|
||||
tr = task.get("track")
|
||||
if tr and isinstance(tr, dict):
|
||||
return tr.get("id")
|
||||
result = task.get("result", {}) or {}
|
||||
if isinstance(result, dict) and isinstance(result.get("track"), dict):
|
||||
return result["track"].get("id")
|
||||
# Fallback: music_library에서 task_id로 검색
|
||||
track = db.get_track_by_task_id(task_id)
|
||||
if track:
|
||||
return track.get("id")
|
||||
return None
|
||||
if status == "failed":
|
||||
return None
|
||||
return None # timeout
|
||||
@@ -264,6 +264,27 @@ def init_db() -> None:
|
||||
)
|
||||
""")
|
||||
|
||||
# ── music_batch_jobs 테이블 ──────────────────────────────────────
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS music_batch_jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
genre TEXT NOT NULL,
|
||||
count INTEGER NOT NULL,
|
||||
target_duration_sec INTEGER NOT NULL DEFAULT 180,
|
||||
auto_pipeline INTEGER NOT NULL DEFAULT 1,
|
||||
completed INTEGER NOT NULL DEFAULT 0,
|
||||
track_ids_json TEXT NOT NULL DEFAULT '[]',
|
||||
current_track_index INTEGER NOT NULL DEFAULT 0,
|
||||
current_track_status TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
error TEXT,
|
||||
compile_job_id INTEGER,
|
||||
pipeline_id INTEGER,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# ── YouTube pipeline 테이블 (5개) ─────────────────────────────────
|
||||
# track_id는 nullable: compile_job_id로 입력하는 essential mix 모드 지원
|
||||
conn.execute("""
|
||||
@@ -1257,3 +1278,85 @@ def get_oauth_token() -> Optional[Dict[str, Any]]:
|
||||
def delete_oauth_token() -> None:
|
||||
with _conn() as conn:
|
||||
conn.execute("DELETE FROM youtube_oauth_tokens")
|
||||
|
||||
|
||||
# ── music_batch_jobs CRUD ─────────────────────────────────────────────────────
|
||||
|
||||
_BATCH_ALLOWED_COLS = frozenset([
|
||||
"completed", "track_ids_json", "current_track_index",
|
||||
"current_track_status", "status", "error",
|
||||
"compile_job_id", "pipeline_id",
|
||||
])
|
||||
|
||||
|
||||
def create_batch_job(genre: str, count: int, target_duration_sec: int = 180,
|
||||
auto_pipeline: bool = True) -> int:
|
||||
with _conn() as conn:
|
||||
now = _now()
|
||||
cur = conn.cursor()
|
||||
cur.execute("""
|
||||
INSERT INTO music_batch_jobs
|
||||
(genre, count, target_duration_sec, auto_pipeline,
|
||||
status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, 'queued', ?, ?)
|
||||
""", (genre, count, target_duration_sec, 1 if auto_pipeline else 0, now, now))
|
||||
return cur.lastrowid
|
||||
|
||||
|
||||
def get_batch_job(batch_id: int) -> dict | None:
|
||||
with _conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT * FROM music_batch_jobs WHERE id = ?", (batch_id,)
|
||||
).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
d = dict(row)
|
||||
d["track_ids"] = json.loads(d.get("track_ids_json") or "[]")
|
||||
return d
|
||||
|
||||
|
||||
def update_batch_job(batch_id: int, **fields) -> None:
|
||||
unknown = set(fields) - _BATCH_ALLOWED_COLS
|
||||
if unknown:
|
||||
raise ValueError(f"unknown batch job columns: {unknown}")
|
||||
if not fields:
|
||||
return
|
||||
cols = ", ".join(f"{k} = ?" for k in fields)
|
||||
vals = list(fields.values()) + [_now(), batch_id]
|
||||
with _conn() as conn:
|
||||
conn.execute(
|
||||
f"UPDATE music_batch_jobs SET {cols}, updated_at = ? WHERE id = ?",
|
||||
vals,
|
||||
)
|
||||
|
||||
|
||||
def append_batch_track(batch_id: int, track_id: int) -> None:
|
||||
"""track_ids_json에 새 track_id 추가 + completed 증가 (atomic)."""
|
||||
with _conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT track_ids_json, completed FROM music_batch_jobs WHERE id = ?",
|
||||
(batch_id,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return
|
||||
ids = json.loads(row["track_ids_json"] or "[]")
|
||||
ids.append(track_id)
|
||||
conn.execute(
|
||||
"UPDATE music_batch_jobs SET track_ids_json = ?, completed = ?, updated_at = ? WHERE id = ?",
|
||||
(json.dumps(ids), row["completed"] + 1, _now(), batch_id),
|
||||
)
|
||||
|
||||
|
||||
def list_batch_jobs(active_only: bool = False) -> list[dict]:
|
||||
sql = "SELECT * FROM music_batch_jobs"
|
||||
if active_only:
|
||||
sql += " WHERE status NOT IN ('failed','cancelled','piped')"
|
||||
sql += " ORDER BY created_at DESC"
|
||||
with _conn() as conn:
|
||||
rows = conn.execute(sql).fetchall()
|
||||
out = []
|
||||
for r in rows:
|
||||
d = dict(r)
|
||||
d["track_ids"] = json.loads(d.get("track_ids_json") or "[]")
|
||||
out.append(d)
|
||||
return out
|
||||
|
||||
@@ -35,6 +35,7 @@ from .suno_provider import (
|
||||
generate_lyrics, get_credits, get_timestamped_lyrics, generate_style_boost,
|
||||
SUNO_API_KEY, SUNO_MODELS,
|
||||
)
|
||||
from .batch_generator import run_batch as _run_batch
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -849,6 +850,62 @@ def export_compile(job_id: int):
|
||||
}
|
||||
|
||||
|
||||
# ── 배치 음악 생성 API ────────────────────────────────────────────────────────
|
||||
|
||||
class BatchGenerateRequest(BaseModel):
|
||||
genre: str
|
||||
count: int = 10
|
||||
target_duration_sec: int = 180
|
||||
auto_pipeline: bool = True
|
||||
|
||||
|
||||
@app.post("/api/music/generate-batch", status_code=201)
|
||||
async def generate_batch(req: BatchGenerateRequest, bg: BackgroundTasks):
|
||||
if not (1 <= req.count <= 10):
|
||||
raise HTTPException(status_code=400, detail="count는 1-10 사이")
|
||||
if not (60 <= req.target_duration_sec <= 300):
|
||||
raise HTTPException(status_code=400, detail="target_duration_sec는 60-300 사이")
|
||||
if not req.genre:
|
||||
raise HTTPException(status_code=400, detail="genre 필수")
|
||||
if not SUNO_API_KEY:
|
||||
raise HTTPException(status_code=400, detail="SUNO_API_KEY 미설정")
|
||||
|
||||
batch_id = _db_module.create_batch_job(
|
||||
genre=req.genre, count=req.count,
|
||||
target_duration_sec=req.target_duration_sec,
|
||||
auto_pipeline=req.auto_pipeline,
|
||||
)
|
||||
bg.add_task(_run_batch, batch_id)
|
||||
return _db_module.get_batch_job(batch_id)
|
||||
|
||||
|
||||
@app.get("/api/music/generate-batch/{batch_id}")
|
||||
def get_batch(batch_id: int):
|
||||
j = _db_module.get_batch_job(batch_id)
|
||||
if not j:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
if j["track_ids"]:
|
||||
ids_csv = ",".join(str(i) for i in j["track_ids"])
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(_db_module.DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute(
|
||||
f"SELECT id, title, audio_url, duration_sec FROM music_library WHERE id IN ({ids_csv})"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
# 트랙을 batch.track_ids 순서대로 정렬
|
||||
by_id = {r["id"]: dict(r) for r in rows}
|
||||
j["tracks"] = [by_id.get(tid) for tid in j["track_ids"] if tid in by_id]
|
||||
else:
|
||||
j["tracks"] = []
|
||||
return j
|
||||
|
||||
|
||||
@app.get("/api/music/generate-batch")
|
||||
def list_batches(status: str = "all"):
|
||||
return {"batches": _db_module.list_batch_jobs(active_only=(status == "active"))}
|
||||
|
||||
|
||||
# ── 수익화 추적 API ───────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/music/revenue/dashboard")
|
||||
|
||||
69
music-lab/app/random_pools.py
Normal file
69
music-lab/app/random_pools.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""장르별 음악 파라미터 랜덤 풀 — 음악적으로 어울리는 결과 유도."""
|
||||
import random
|
||||
|
||||
POOLS = {
|
||||
"lo-fi": {
|
||||
"moods": ["chill", "relaxing", "dreamy", "melancholic", "mellow", "nostalgic", "peaceful"],
|
||||
"instruments_pool": ["piano", "synth", "drums", "vinyl", "rhodes", "soft bass", "ambient pads"],
|
||||
"instruments_count": (3, 4),
|
||||
"bpm": (70, 90),
|
||||
"keys": ["C", "D", "F", "G", "A"],
|
||||
"scales": ["minor", "major"],
|
||||
"prompt_modifiers": ["cozy bedroom vibes", "rainy night", "late night study", "cafe ambience"],
|
||||
},
|
||||
"phonk": {
|
||||
"moods": ["dark", "aggressive", "moody", "intense", "hypnotic"],
|
||||
"instruments_pool": ["808 bass", "hi-hat", "synth lead", "vocal chops", "bass drops", "trap drums"],
|
||||
"instruments_count": (3, 4),
|
||||
"bpm": (130, 160),
|
||||
"keys": ["C", "D", "F", "G"],
|
||||
"scales": ["minor"],
|
||||
"prompt_modifiers": ["drift atmosphere", "dark neon", "midnight drive"],
|
||||
},
|
||||
"ambient": {
|
||||
"moods": ["peaceful", "meditative", "ethereal", "spacious", "dreamy"],
|
||||
"instruments_pool": ["pad synths", "atmospheric guitar", "soft strings", "field recordings", "drone bass"],
|
||||
"instruments_count": (2, 3),
|
||||
"bpm": (50, 75),
|
||||
"keys": ["C", "D", "E", "G", "A"],
|
||||
"scales": ["major", "minor"],
|
||||
"prompt_modifiers": ["misty mountain morning", "deep space", "still water", "forest dawn"],
|
||||
},
|
||||
"pop": {
|
||||
"moods": ["uplifting", "happy", "energetic", "romantic", "catchy"],
|
||||
"instruments_pool": ["acoustic guitar", "piano", "drums", "bass", "synth", "vocals harmonies"],
|
||||
"instruments_count": (3, 5),
|
||||
"bpm": (95, 130),
|
||||
"keys": ["C", "D", "E", "F", "G", "A"],
|
||||
"scales": ["major"],
|
||||
"prompt_modifiers": ["radio-ready", "summer vibe", "feel-good"],
|
||||
},
|
||||
"default": {
|
||||
"moods": ["chill", "relaxing", "uplifting", "mellow"],
|
||||
"instruments_pool": ["piano", "synth", "drums", "guitar", "bass", "strings"],
|
||||
"instruments_count": (3, 4),
|
||||
"bpm": (80, 110),
|
||||
"keys": ["C", "D", "F", "G", "A"],
|
||||
"scales": ["minor", "major"],
|
||||
"prompt_modifiers": [""],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def randomize(genre: str, rng=None) -> dict:
|
||||
"""장르 → 랜덤 음악 파라미터 1세트.
|
||||
|
||||
반환: {moods, instruments, bpm, key, scale, prompt_modifier}
|
||||
"""
|
||||
rng = rng or random.Random()
|
||||
pool = POOLS.get(genre.lower(), POOLS["default"])
|
||||
n_instr = rng.randint(*pool["instruments_count"])
|
||||
instruments = rng.sample(pool["instruments_pool"], min(n_instr, len(pool["instruments_pool"])))
|
||||
return {
|
||||
"moods": [rng.choice(pool["moods"])],
|
||||
"instruments": instruments,
|
||||
"bpm": rng.randint(*pool["bpm"]),
|
||||
"key": rng.choice(pool["keys"]),
|
||||
"scale": rng.choice(pool["scales"]),
|
||||
"prompt_modifier": rng.choice(pool["prompt_modifiers"]),
|
||||
}
|
||||
96
music-lab/tests/test_batch_db.py
Normal file
96
music-lab/tests/test_batch_db.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import pytest
|
||||
from app import db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fresh_db(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
|
||||
db.init_db()
|
||||
return db
|
||||
|
||||
|
||||
def test_create_batch_job(fresh_db):
|
||||
bid = db.create_batch_job(genre="lo-fi", count=10)
|
||||
j = db.get_batch_job(bid)
|
||||
assert j["genre"] == "lo-fi"
|
||||
assert j["count"] == 10
|
||||
assert j["status"] == "queued"
|
||||
assert j["track_ids"] == []
|
||||
assert j["auto_pipeline"] == 1
|
||||
assert j["target_duration_sec"] == 180
|
||||
|
||||
|
||||
def test_create_batch_job_no_auto_pipeline(fresh_db):
|
||||
bid = db.create_batch_job(genre="phonk", count=5, auto_pipeline=False)
|
||||
j = db.get_batch_job(bid)
|
||||
assert j["auto_pipeline"] == 0
|
||||
|
||||
|
||||
def test_update_batch_job(fresh_db):
|
||||
bid = db.create_batch_job(genre="phonk", count=5)
|
||||
db.update_batch_job(bid, status="generating", current_track_index=2)
|
||||
j = db.get_batch_job(bid)
|
||||
assert j["status"] == "generating"
|
||||
assert j["current_track_index"] == 2
|
||||
|
||||
|
||||
def test_update_batch_rejects_unknown_col(fresh_db):
|
||||
bid = db.create_batch_job(genre="lo-fi", count=1)
|
||||
with pytest.raises(ValueError):
|
||||
db.update_batch_job(bid, evil_col="x")
|
||||
|
||||
|
||||
def test_append_batch_track(fresh_db):
|
||||
bid = db.create_batch_job(genre="lo-fi", count=3)
|
||||
db.append_batch_track(bid, 101)
|
||||
db.append_batch_track(bid, 102)
|
||||
j = db.get_batch_job(bid)
|
||||
assert j["track_ids"] == [101, 102]
|
||||
assert j["completed"] == 2
|
||||
|
||||
|
||||
def test_list_batch_jobs_active_filter(fresh_db):
|
||||
b1 = db.create_batch_job(genre="lo-fi", count=1)
|
||||
b2 = db.create_batch_job(genre="phonk", count=1)
|
||||
db.update_batch_job(b1, status="failed")
|
||||
actives = db.list_batch_jobs(active_only=True)
|
||||
assert all(j["status"] not in ("failed",) for j in actives)
|
||||
assert any(j["id"] == b2 for j in actives)
|
||||
assert not any(j["id"] == b1 for j in actives)
|
||||
|
||||
|
||||
def test_random_pools_lofi():
|
||||
from app.random_pools import randomize, POOLS
|
||||
import random
|
||||
rng = random.Random(42)
|
||||
result = randomize("lo-fi", rng)
|
||||
assert result["bpm"] in range(70, 91)
|
||||
assert result["key"] in POOLS["lo-fi"]["keys"]
|
||||
assert result["scale"] in POOLS["lo-fi"]["scales"]
|
||||
assert len(result["moods"]) == 1
|
||||
assert result["moods"][0] in POOLS["lo-fi"]["moods"]
|
||||
assert 3 <= len(result["instruments"]) <= 4
|
||||
|
||||
|
||||
def test_random_pools_phonk():
|
||||
from app.random_pools import randomize
|
||||
import random
|
||||
rng = random.Random(0)
|
||||
result = randomize("phonk", rng)
|
||||
assert result["bpm"] in range(130, 161)
|
||||
assert result["scale"] == "minor"
|
||||
|
||||
|
||||
def test_random_pools_unknown_genre_uses_default():
|
||||
from app.random_pools import randomize
|
||||
import random
|
||||
result = randomize("nonexistent", random.Random(0))
|
||||
assert result["bpm"] in range(80, 111)
|
||||
|
||||
|
||||
def test_random_pools_seed_reproducible():
|
||||
from app.random_pools import randomize
|
||||
import random
|
||||
a = randomize("lo-fi", random.Random(123))
|
||||
b = randomize("lo-fi", random.Random(123))
|
||||
assert a == b
|
||||
89
music-lab/tests/test_batch_endpoints.py
Normal file
89
music-lab/tests/test_batch_endpoints.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
import app.main as main_module
|
||||
from app import db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
|
||||
db.init_db()
|
||||
monkeypatch.setenv("SUNO_API_KEY", "test")
|
||||
# main.py의 SUNO_API_KEY 모듈 변수도 갱신 필요할 수 있음
|
||||
monkeypatch.setattr(main_module, "SUNO_API_KEY", "test", raising=False)
|
||||
return TestClient(main_module.app)
|
||||
|
||||
|
||||
def test_create_batch_201(client):
|
||||
with patch.object(main_module, "_run_batch", new=AsyncMock()):
|
||||
r = client.post("/api/music/generate-batch",
|
||||
json={"genre": "lo-fi", "count": 3})
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
assert body["genre"] == "lo-fi"
|
||||
assert body["count"] == 3
|
||||
assert body["status"] == "queued"
|
||||
|
||||
|
||||
def test_create_batch_rejects_count_too_high(client):
|
||||
r = client.post("/api/music/generate-batch",
|
||||
json={"genre": "lo-fi", "count": 11})
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_create_batch_rejects_count_zero(client):
|
||||
r = client.post("/api/music/generate-batch",
|
||||
json={"genre": "lo-fi", "count": 0})
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_create_batch_rejects_no_genre(client):
|
||||
r = client.post("/api/music/generate-batch", json={"count": 3})
|
||||
assert r.status_code in (400, 422)
|
||||
|
||||
|
||||
def test_create_batch_rejects_invalid_duration(client):
|
||||
r = client.post("/api/music/generate-batch",
|
||||
json={"genre": "lo-fi", "count": 3, "target_duration_sec": 30})
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_create_batch_rejects_no_suno_key(client, monkeypatch):
|
||||
monkeypatch.setattr(main_module, "SUNO_API_KEY", "", raising=False)
|
||||
r = client.post("/api/music/generate-batch",
|
||||
json={"genre": "lo-fi", "count": 3})
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_get_batch_returns_tracks(client):
|
||||
bid = db.create_batch_job(genre="lo-fi", count=2)
|
||||
db.append_batch_track(bid, 999) # phantom track
|
||||
r = client.get(f"/api/music/generate-batch/{bid}")
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["track_ids"] == [999]
|
||||
assert body["tracks"] == [] # 999 not in music_library
|
||||
|
||||
|
||||
def test_get_batch_404(client):
|
||||
r = client.get("/api/music/generate-batch/99999")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_list_batches(client):
|
||||
db.create_batch_job(genre="lo-fi", count=1)
|
||||
db.create_batch_job(genre="phonk", count=2)
|
||||
r = client.get("/api/music/generate-batch")
|
||||
assert r.status_code == 200
|
||||
assert len(r.json()["batches"]) == 2
|
||||
|
||||
|
||||
def test_list_batches_active_filter(client):
|
||||
b1 = db.create_batch_job(genre="lo-fi", count=1)
|
||||
b2 = db.create_batch_job(genre="phonk", count=2)
|
||||
db.update_batch_job(b1, status="failed")
|
||||
r = client.get("/api/music/generate-batch?status=active")
|
||||
ids = [j["id"] for j in r.json()["batches"]]
|
||||
assert b2 in ids
|
||||
assert b1 not in ids
|
||||
Reference in New Issue
Block a user