diff --git a/music-lab/app/main.py b/music-lab/app/main.py index 07e4fd0..568c0aa 100644 --- a/music-lab/app/main.py +++ b/music-lab/app/main.py @@ -929,7 +929,11 @@ def market_suggest(limit: int = 5): # ── Pipeline endpoints ──────────────────────────────────────────────────────── class PipelineCreate(BaseModel): - track_id: int + track_id: int | None = None + compile_job_id: int | None = None + visual_style: str | None = None # single | essential + background_mode: str | None = None # static | video_loop + background_keyword: str | None = None class FeedbackRequest(BaseModel): @@ -940,10 +944,34 @@ class FeedbackRequest(BaseModel): @app.post("/api/music/pipeline", status_code=201) def create_pipeline(req: PipelineCreate): + # XOR 검증 + if (req.track_id is None) == (req.compile_job_id is None): + raise HTTPException(400, "track_id 또는 compile_job_id 중 정확히 하나를 지정") + + # compile_job 상태 확인 + if req.compile_job_id is not None: + job = _db_module.get_compile_job(req.compile_job_id) + if not job: + raise HTTPException(404, f"compile job {req.compile_job_id} 없음") + if job.get("status") != "succeeded": + raise HTTPException(400, f"compile job {req.compile_job_id} not ready (status={job.get('status')})") + + # 동일 입력으로 이미 active 파이프라인 있으면 409 actives = _db_module.list_pipelines(active_only=True) - if any(p["track_id"] == req.track_id for p in actives): - raise HTTPException(409, "이미 진행 중인 파이프라인이 있습니다") - pid = _db_module.create_pipeline(req.track_id) + for p in actives: + if (req.track_id and p.get("track_id") == req.track_id) or \ + (req.compile_job_id and p.get("compile_job_id") == req.compile_job_id): + raise HTTPException(409, "이미 진행 중인 파이프라인이 있습니다") + + setup = _db_module.get_youtube_setup() + vd = setup["visual_defaults"] + pid = _db_module.create_pipeline( + track_id=req.track_id, + compile_job_id=req.compile_job_id, + visual_style=req.visual_style or vd.get("default_visual_style", "essential"), + background_mode=req.background_mode or vd.get("default_background_mode", "static"), + background_keyword=req.background_keyword or vd.get("default_background_keyword") or None, + ) return _db_module.get_pipeline(pid) diff --git a/music-lab/tests/test_pipeline_endpoints.py b/music-lab/tests/test_pipeline_endpoints.py index 93263af..13e19b8 100644 --- a/music-lab/tests/test_pipeline_endpoints.py +++ b/music-lab/tests/test_pipeline_endpoints.py @@ -108,3 +108,69 @@ def test_youtube_status_when_disconnected(client): r = client.get("/api/music/youtube/status") assert r.status_code == 200 assert r.json() == {"connected": False} + + +def test_create_pipeline_with_compile_job(client, monkeypatch): + import sqlite3 + conn = sqlite3.connect(db.DB_PATH) + cur = conn.cursor() + try: + cur.execute(""" + INSERT INTO compile_jobs (title, track_ids_json, crossfade_sec, + audio_path, status, created_at) + VALUES ('Test Mix', '[1,2,3]', 3, '/app/data/compiles/9.mp3', + 'succeeded', datetime()) + """) + except sqlite3.OperationalError: + pytest.skip("compile_jobs schema mismatch") + conn.commit() + cid = cur.lastrowid + conn.close() + + r = client.post("/api/music/pipeline", json={"compile_job_id": cid}) + assert r.status_code == 201 + body = r.json() + assert body["track_id"] is None + assert body["compile_job_id"] == cid + assert body["visual_style"] == "essential" + + +def test_create_pipeline_rejects_both_inputs(client): + r = client.post("/api/music/pipeline", json={"track_id": 1, "compile_job_id": 1}) + assert r.status_code == 400 + + +def test_create_pipeline_rejects_neither(client): + r = client.post("/api/music/pipeline", json={}) + assert r.status_code == 400 + + +def test_create_pipeline_rejects_compile_not_ready(client): + import sqlite3 + conn = sqlite3.connect(db.DB_PATH) + cur = conn.cursor() + try: + cur.execute(""" + INSERT INTO compile_jobs (title, status, created_at) + VALUES ('Pending', 'rendering', datetime()) + """) + except sqlite3.OperationalError: + pytest.skip("compile_jobs schema mismatch") + conn.commit() + cid = cur.lastrowid + conn.close() + + r = client.post("/api/music/pipeline", json={"compile_job_id": cid}) + assert r.status_code == 400 + + +def test_create_pipeline_with_visual_options(client): + r = client.post("/api/music/pipeline", json={ + "track_id": 1, "visual_style": "single", + "background_mode": "video_loop", "background_keyword": "rain", + }) + assert r.status_code == 201 + body = r.json() + assert body["visual_style"] == "single" + assert body["background_mode"] == "video_loop" + assert body["background_keyword"] == "rain"