From fe60c8d330902ffd34e097f93134d9897a3b349b Mon Sep 17 00:00:00 2001 From: gahusb Date: Thu, 7 May 2026 17:11:29 +0900 Subject: [PATCH] =?UTF-8?q?feat(music-lab):=20pipeline=20=EC=98=A4?= =?UTF-8?q?=EC=BC=80=EC=8A=A4=ED=8A=B8=EB=A0=88=EC=9D=B4=ED=84=B0=20+=2014?= =?UTF-8?q?=20=EC=97=94=EB=93=9C=ED=8F=AC=EC=9D=B8=ED=8A=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- music-lab/app/main.py | 194 +++++++++++++++++++++ music-lab/app/pipeline/orchestrator.py | 183 +++++++++++++++++++ music-lab/tests/test_pipeline_endpoints.py | 110 ++++++++++++ 3 files changed, 487 insertions(+) create mode 100644 music-lab/app/pipeline/orchestrator.py create mode 100644 music-lab/tests/test_pipeline_endpoints.py diff --git a/music-lab/app/main.py b/music-lab/app/main.py index f0a3666..2db9944 100644 --- a/music-lab/app/main.py +++ b/music-lab/app/main.py @@ -22,9 +22,12 @@ from .db import ( create_compile_job, get_compile_jobs, get_compile_job, update_compile_job, delete_compile_job, ) +from . import db as _db_module from .compiler import run_compile from .market import ingest_trends, get_suggestions from .local_provider import run_local_generation +from .pipeline import orchestrator +from .pipeline import youtube as yt_module from .suno_provider import ( run_suno_generation, run_suno_extend, run_vocal_removal, run_cover_image, run_wav_convert, run_stem_split, @@ -921,3 +924,194 @@ def list_market_reports(limit: int = 10): @app.get("/api/music/market/suggest") def market_suggest(limit: int = 5): return {"suggestions": get_suggestions(limit)} + + +# ── Pipeline endpoints ──────────────────────────────────────────────────────── + +class PipelineCreate(BaseModel): + track_id: int + + +class FeedbackRequest(BaseModel): + step: str + intent: str # approve | reject + feedback_text: Optional[str] = None + + +@app.post("/api/music/pipeline", status_code=201) +def create_pipeline(req: PipelineCreate): + actives = _db_module.list_pipelines(active_only=True) + if any(p["track_id"] == req.track_id for p in actives): + raise HTTPException(409, "이미 진행 중인 파이프라인이 있습니다") + pid = _db_module.create_pipeline(req.track_id) + return _db_module.get_pipeline(pid) + + +@app.get("/api/music/pipeline") +def list_pipelines_endpoint(status: str = "all"): + pipelines = _db_module.list_pipelines(active_only=(status == "active")) + return {"pipelines": pipelines} + + +@app.get("/api/music/pipeline/lookup-by-msg/{msg_id}") +def lookup_by_msg(msg_id: int): + for p in _db_module.list_pipelines(active_only=True): + for step, mid in p["last_telegram_msg_ids"].items(): + if mid == msg_id: + return {"pipeline_id": p["id"], "step": step} + raise HTTPException(404) + + +@app.get("/api/music/pipeline/{pid}") +def get_pipeline_endpoint(pid: int): + p = _db_module.get_pipeline(pid) + if not p: + raise HTTPException(404) + p["jobs"] = _db_module.list_pipeline_jobs(pid) + p["feedback"] = _db_module.get_feedback_history(pid) + return p + + +@app.post("/api/music/pipeline/{pid}/start", status_code=202) +async def start_pipeline(pid: int, bg: BackgroundTasks): + p = _db_module.get_pipeline(pid) + if not p: + raise HTTPException(404) + if p["state"] != "created": + raise HTTPException(409, f"이미 시작됨 ({p['state']})") + bg.add_task(orchestrator.run_step, pid, "cover") + return {"ok": True} + + +def _state_to_step(state: str) -> Optional[str]: + return { + "video_pending": "video", + "thumb_pending": "thumb", + "meta_pending": "meta", + "ai_review": "review", + "publish_pending": None, # 사용자 명시 발행 호출 필요 + "publishing": "publish", + }.get(state) + + +@app.post("/api/music/pipeline/{pid}/feedback", status_code=202) +async def feedback(pid: int, req: FeedbackRequest, bg: BackgroundTasks): + p = _db_module.get_pipeline(pid) + if not p: + raise HTTPException(404) + if p["state"] == "awaiting_manual": + raise HTTPException(409, "수동 개입 대기 중") + state = p["state"] + expected = f"{req.step}_pending" + if state != expected: + # 멱등 처리 — 이미 다음 단계로 넘어갔으면 무시 + return {"ok": True, "skipped": True} + + if req.intent == "approve": + from .pipeline.state_machine import next_state_on_approve + next_st = next_state_on_approve(state) + _db_module.update_pipeline_state(pid, next_st) + next_step = _state_to_step(next_st) + if next_step: + bg.add_task(orchestrator.run_step, pid, next_step) + return {"ok": True} + + elif req.intent == "reject": + count = _db_module.increment_feedback_count(pid, req.step) + if count > 5: + _db_module.update_pipeline_state(pid, "awaiting_manual") + raise HTTPException(409, "재생성 한도 초과") + if req.feedback_text: + _db_module.record_feedback(pid, req.step, req.feedback_text) + bg.add_task(orchestrator.run_step, pid, req.step, req.feedback_text or "") + return {"ok": True} + + else: + raise HTTPException(400, f"unknown intent: {req.intent}") + + +@app.post("/api/music/pipeline/{pid}/cancel") +def cancel_pipeline(pid: int): + p = _db_module.get_pipeline(pid) + if not p: + raise HTTPException(404) + _db_module.update_pipeline_state(pid, "cancelled", cancelled_at=_db_module._now()) + return {"ok": True} + + +@app.post("/api/music/pipeline/{pid}/publish", status_code=202) +async def publish_pipeline(pid: int, bg: BackgroundTasks): + p = _db_module.get_pipeline(pid) + if not p: + raise HTTPException(404) + if p["state"] != "publish_pending": + raise HTTPException(409, f"발행 단계 아님 ({p['state']})") + _db_module.update_pipeline_state(pid, "publishing") + bg.add_task(orchestrator.run_step, pid, "publish") + return {"ok": True} + + +# Telegram 메시지 매칭용 엔드포인트 (agent-office용) + +class TelegramMsgPatch(BaseModel): + step: str + message_id: int + + +@app.patch("/api/music/pipeline/{pid}/telegram-msg") +def save_telegram_msg(pid: int, req: TelegramMsgPatch): + p = _db_module.get_pipeline(pid) + if not p: + raise HTTPException(404) + ids = p["last_telegram_msg_ids"] + ids[req.step] = req.message_id + _db_module.update_pipeline_state( + pid, p["state"], last_telegram_msg_ids=json.dumps(ids) + ) + return {"ok": True} + + +# ── Setup endpoints ─────────────────────────────────────────────────────────── + +class SetupRequest(BaseModel): + metadata_template: Optional[Dict[str, Any]] = None + cover_prompts: Optional[Dict[str, Any]] = None + review_weights: Optional[Dict[str, Any]] = None + review_threshold: Optional[int] = None + visual_defaults: Optional[Dict[str, Any]] = None + publish_policy: Optional[Dict[str, Any]] = None + + +@app.get("/api/music/setup") +def get_setup(): + return _db_module.get_youtube_setup() + + +@app.put("/api/music/setup") +def put_setup(req: SetupRequest): + payload = {k: v for k, v in req.dict().items() if v is not None} + _db_module.update_youtube_setup(**payload) + return _db_module.get_youtube_setup() + + +# ── YouTube OAuth endpoints ─────────────────────────────────────────────────── + +@app.get("/api/music/youtube/auth-url") +def youtube_auth_url(): + return {"url": yt_module.get_auth_url()} + + +@app.get("/api/music/youtube/callback") +async def youtube_callback(code: str): + return await yt_module.exchange_code(code) + + +@app.post("/api/music/youtube/disconnect") +def youtube_disconnect(): + yt_module.disconnect() + return {"ok": True} + + +@app.get("/api/music/youtube/status") +def youtube_status(): + return yt_module.get_status() or {"connected": False} diff --git a/music-lab/app/pipeline/orchestrator.py b/music-lab/app/pipeline/orchestrator.py new file mode 100644 index 0000000..cebe7e2 --- /dev/null +++ b/music-lab/app/pipeline/orchestrator.py @@ -0,0 +1,183 @@ +"""파이프라인 오케스트레이터 — 단계별 BackgroundTask 등록 및 산출물 → DB 반영.""" +import json +import logging +import os +import sqlite3 + +from app import db +from . import cover, video, thumb, metadata, review, youtube + +logger = logging.getLogger("music-lab.orchestrator") + + +async def run_step(pipeline_id: int, step: str, feedback: str = "") -> None: + """단계 실행 → 결과를 DB에 반영하고 *_pending 또는 다음 단계로 전이. + + 호출 직후 _running 상태로 전환, 끝나면 _pending(사용자 게이트) 또는 자동 다음. + 실패 시 failed 상태 + reason. + """ + job_id = db.create_pipeline_job(pipeline_id, step) + db.update_pipeline_job(job_id, status="running") + p = db.get_pipeline(pipeline_id) + track = _get_track(p["track_id"]) + + try: + if step == "cover": + result = await _run_cover(p, track, feedback) + elif step == "video": + result = await _run_video(p, track) + elif step == "thumb": + result = await _run_thumb(p, track, feedback) + elif step == "meta": + result = await _run_meta(p, track, feedback) + elif step == "review": + result = await _run_review(p, track) + elif step == "publish": + result = await _run_publish(p, track) + else: + raise ValueError(f"unknown step: {step}") + db.update_pipeline_job(job_id, status="succeeded") + db.update_pipeline_state(pipeline_id, result["next_state"], **result.get("fields", {})) + except Exception as e: + logger.exception("step %s failed for pipeline %s", step, pipeline_id) + db.update_pipeline_job(job_id, status="failed", error=str(e)) + db.update_pipeline_state(pipeline_id, "failed", failed_reason=f"{step}: {e}") + + +def _get_track(track_id: int) -> dict: + # tracks 테이블 헬퍼 — 기존 db에 있는 함수 사용 + t = None + if hasattr(db, "get_track_by_id"): + t = db.get_track_by_id(track_id) + elif hasattr(db, "get_track"): + t = db.get_track(track_id) + if not t: + # 폴백: music_library 테이블에서 직접 (스키마 확인 필요) + t = _fetch_track_fallback(track_id) + if not t: + raise ValueError(f"트랙 {track_id} 없음") + return t + + +def _fetch_track_fallback(track_id: int) -> dict | None: + """db 모듈에 get_track이 없을 때 대비 — music_library 테이블 직접 조회.""" + try: + conn = sqlite3.connect(db.DB_PATH) + conn.row_factory = sqlite3.Row + # 가능한 테이블/컬럼 시도 (music_library 또는 tracks) + for table in ("music_library", "tracks"): + try: + row = conn.execute(f"SELECT * FROM {table} WHERE id = ?", (track_id,)).fetchone() + if row: + d = dict(row) + # JSON 컬럼 파싱 (있으면) + for k in ("moods", "instruments"): + if k in d and isinstance(d[k], str): + try: + d[k] = json.loads(d[k]) + except (json.JSONDecodeError, TypeError): + d[k] = [] + conn.close() + return d + except sqlite3.OperationalError: + continue + conn.close() + except Exception as e: + logger.warning("track fallback fetch 실패: %s", e) + return None + + +async def _run_cover(p, track, feedback): + setup = db.get_youtube_setup() + prompts = setup["cover_prompts"] + template = prompts.get(track.get("genre", "default").lower(), prompts.get("default", "")) + out = await cover.generate( + pipeline_id=p["id"], genre=track.get("genre", "default"), + prompt_template=template, + mood=", ".join(track.get("moods", []) or []), + track_title=track.get("title", ""), + feedback=feedback, + ) + return {"next_state": "cover_pending", "fields": {"cover_url": out["url"]}} + + +async def _run_video(p, track): + setup = db.get_youtube_setup() + vd = setup["visual_defaults"] + audio_path = _local_path(track.get("audio_url", "")) + cover_path = _local_path(p["cover_url"]) + out = video.generate( + pipeline_id=p["id"], audio_path=audio_path, cover_path=cover_path, + genre=track.get("genre", "default"), + duration_sec=track.get("duration_sec", 120), + resolution=vd["resolution"], style=vd["style"], + ) + return {"next_state": "video_pending", "fields": {"video_url": out["url"]}} + + +async def _run_thumb(p, track, feedback): + video_path = _local_path(p["video_url"]) + out = thumb.generate(pipeline_id=p["id"], video_path=video_path, + track_title=track.get("title", ""), overlay_text=True) + return {"next_state": "thumb_pending", "fields": {"thumbnail_url": out["url"]}} + + +async def _run_meta(p, track, feedback): + setup = db.get_youtube_setup() + trend_top = _get_trend_top() + out = await metadata.generate( + track=track, template=setup["metadata_template"], + trend_keywords=trend_top, feedback=feedback, + ) + return {"next_state": "meta_pending", + "fields": {"metadata_json": json.dumps(out, ensure_ascii=False)}} + + +async def _run_review(p, track): + setup = db.get_youtube_setup() + meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {} + result = await review.run_4_axis( + pipeline=p, track=track, + video_meta={"length_sec": track.get("duration_sec", 120), + "resolution": setup["visual_defaults"]["resolution"]}, + metadata=meta, thumbnail_url=p.get("thumbnail_url", ""), + trend_top=_get_trend_top(), + weights=setup["review_weights"], threshold=setup["review_threshold"], + ) + return {"next_state": "publish_pending", + "fields": {"review_json": json.dumps(result, ensure_ascii=False)}} + + +async def _run_publish(p, track): + setup = db.get_youtube_setup() + meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {} + privacy = setup["publish_policy"].get("privacy", "private") + result = youtube.upload_video( + video_path=_local_path(p["video_url"]), + thumbnail_path=_local_path(p["thumbnail_url"]) if p.get("thumbnail_url") else None, + metadata=meta, privacy=privacy, + ) + return {"next_state": "published", + "fields": {"youtube_video_id": result["video_id"]}} + + +def _local_path(media_url: str) -> str: + """ /media/videos/123/cover.jpg → /app/data/videos/123/cover.jpg """ + if not media_url: + return "" + base_media = os.getenv("VIDEO_MEDIA_BASE", "/media/videos") + base_data = os.getenv("VIDEO_DATA_DIR", "/app/data/videos") + if media_url.startswith(base_media): + return media_url.replace(base_media, base_data, 1) + # /media/music/abc.mp3 → /app/data/music/abc.mp3 + return media_url.replace("/media/", "/app/data/", 1) + + +def _get_trend_top(n: int = 10) -> list[str]: + try: + if hasattr(db, "get_market_trends"): + rows = db.get_market_trends(days=7) + return [r.get("genre", "") for r in rows[:n] if r.get("genre")] + except Exception: + pass + return [] diff --git a/music-lab/tests/test_pipeline_endpoints.py b/music-lab/tests/test_pipeline_endpoints.py new file mode 100644 index 0000000..93263af --- /dev/null +++ b/music-lab/tests/test_pipeline_endpoints.py @@ -0,0 +1,110 @@ +import sqlite3 +import pytest +from unittest.mock import AsyncMock, patch +from fastapi.testclient import TestClient + +from app.main import app +from app import db + + +@pytest.fixture +def client(monkeypatch, tmp_path): + monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db")) + db.init_db() + # 최소 트랙 1개 — music_library 테이블에 직접 삽입 + conn = sqlite3.connect(db.DB_PATH) + cur = conn.cursor() + cur.execute( + """INSERT INTO music_library + (id, title, genre, moods, instruments, duration_sec, bpm, key, scale, + prompt, audio_url, file_path, task_id, tags) + VALUES (1, 'T', 'lo-fi', '["chill"]', '["piano"]', 120, 85, 'C', 'maj', + 'p', '/media/music/x.mp3', '/app/data/music/x.mp3', NULL, '[]')""", + ) + conn.commit() + conn.close() + return TestClient(app) + + +def test_create_pipeline(client): + r = client.post("/api/music/pipeline", json={"track_id": 1}) + assert r.status_code == 201 + assert r.json()["state"] == "created" + + +def test_create_duplicate_pipeline_returns_409(client): + client.post("/api/music/pipeline", json={"track_id": 1}) + r = client.post("/api/music/pipeline", json={"track_id": 1}) + assert r.status_code == 409 + + +def test_get_pipeline_returns_jobs_and_feedback(client): + pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"] + r = client.get(f"/api/music/pipeline/{pid}") + assert "jobs" in r.json() + assert "feedback" in r.json() + + +def test_list_pipelines_active_filter(client): + pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"] + db.update_pipeline_state(pid, "published") + r = client.get("/api/music/pipeline?status=active") + assert all(p["state"] != "published" for p in r.json()["pipelines"]) + + +def test_feedback_reject_records_feedback_and_increments_count(client): + pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"] + db.update_pipeline_state(pid, "cover_pending") + # orchestrator.run_step를 mock해서 백그라운드 작업이 cover_pending을 변경하지 않도록 + with patch("app.main.orchestrator.run_step", new=AsyncMock()): + r = client.post( + f"/api/music/pipeline/{pid}/feedback", + json={"step": "cover", "intent": "reject", "feedback_text": "더 어둡게"}, + ) + assert r.status_code == 202 + p = db.get_pipeline(pid) + assert p["feedback_count_per_step"]["cover"] == 1 + history = db.get_feedback_history(pid) + assert history[0]["feedback_text"] == "더 어둡게" + + +def test_feedback_after_5_rejects_marks_awaiting_manual(client): + pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"] + db.update_pipeline_state(pid, "cover_pending") + with patch("app.main.orchestrator.run_step", new=AsyncMock()): + for i in range(5): + client.post( + f"/api/music/pipeline/{pid}/feedback", + json={"step": "cover", "intent": "reject", "feedback_text": f"again {i}"}, + ) + r = client.post( + f"/api/music/pipeline/{pid}/feedback", + json={"step": "cover", "intent": "reject", "feedback_text": "6th"}, + ) + assert r.status_code == 409 + assert db.get_pipeline(pid)["state"] == "awaiting_manual" + + +def test_cancel_pipeline(client): + pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"] + r = client.post(f"/api/music/pipeline/{pid}/cancel") + assert r.status_code == 200 + assert db.get_pipeline(pid)["state"] == "cancelled" + + +def test_setup_get_returns_defaults(client): + r = client.get("/api/music/setup") + assert r.status_code == 200 + assert r.json()["review_threshold"] == 60 + + +def test_setup_put_updates(client): + r = client.put("/api/music/setup", json={"review_threshold": 70}) + assert r.status_code == 200 + assert r.json()["review_threshold"] == 70 + + +def test_youtube_status_when_disconnected(client): + r = client.get("/api/music/youtube/status") + assert r.status_code == 200 + assert r.json() == {"connected": False}