feat(music-lab): pipeline 오케스트레이터 + 14 엔드포인트

This commit is contained in:
2026-05-07 17:11:29 +09:00
parent 4755e34c14
commit fe60c8d330
3 changed files with 487 additions and 0 deletions

View File

@@ -22,9 +22,12 @@ from .db import (
create_compile_job, get_compile_jobs, get_compile_job, create_compile_job, get_compile_jobs, get_compile_job,
update_compile_job, delete_compile_job, update_compile_job, delete_compile_job,
) )
from . import db as _db_module
from .compiler import run_compile from .compiler import run_compile
from .market import ingest_trends, get_suggestions from .market import ingest_trends, get_suggestions
from .local_provider import run_local_generation from .local_provider import run_local_generation
from .pipeline import orchestrator
from .pipeline import youtube as yt_module
from .suno_provider import ( from .suno_provider import (
run_suno_generation, run_suno_extend, run_vocal_removal, run_suno_generation, run_suno_extend, run_vocal_removal,
run_cover_image, run_wav_convert, run_stem_split, run_cover_image, run_wav_convert, run_stem_split,
@@ -921,3 +924,194 @@ def list_market_reports(limit: int = 10):
@app.get("/api/music/market/suggest") @app.get("/api/music/market/suggest")
def market_suggest(limit: int = 5): def market_suggest(limit: int = 5):
return {"suggestions": get_suggestions(limit)} return {"suggestions": get_suggestions(limit)}
# ── Pipeline endpoints ────────────────────────────────────────────────────────
class PipelineCreate(BaseModel):
track_id: int
class FeedbackRequest(BaseModel):
step: str
intent: str # approve | reject
feedback_text: Optional[str] = None
@app.post("/api/music/pipeline", status_code=201)
def create_pipeline(req: PipelineCreate):
actives = _db_module.list_pipelines(active_only=True)
if any(p["track_id"] == req.track_id for p in actives):
raise HTTPException(409, "이미 진행 중인 파이프라인이 있습니다")
pid = _db_module.create_pipeline(req.track_id)
return _db_module.get_pipeline(pid)
@app.get("/api/music/pipeline")
def list_pipelines_endpoint(status: str = "all"):
pipelines = _db_module.list_pipelines(active_only=(status == "active"))
return {"pipelines": pipelines}
@app.get("/api/music/pipeline/lookup-by-msg/{msg_id}")
def lookup_by_msg(msg_id: int):
for p in _db_module.list_pipelines(active_only=True):
for step, mid in p["last_telegram_msg_ids"].items():
if mid == msg_id:
return {"pipeline_id": p["id"], "step": step}
raise HTTPException(404)
@app.get("/api/music/pipeline/{pid}")
def get_pipeline_endpoint(pid: int):
p = _db_module.get_pipeline(pid)
if not p:
raise HTTPException(404)
p["jobs"] = _db_module.list_pipeline_jobs(pid)
p["feedback"] = _db_module.get_feedback_history(pid)
return p
@app.post("/api/music/pipeline/{pid}/start", status_code=202)
async def start_pipeline(pid: int, bg: BackgroundTasks):
p = _db_module.get_pipeline(pid)
if not p:
raise HTTPException(404)
if p["state"] != "created":
raise HTTPException(409, f"이미 시작됨 ({p['state']})")
bg.add_task(orchestrator.run_step, pid, "cover")
return {"ok": True}
def _state_to_step(state: str) -> Optional[str]:
return {
"video_pending": "video",
"thumb_pending": "thumb",
"meta_pending": "meta",
"ai_review": "review",
"publish_pending": None, # 사용자 명시 발행 호출 필요
"publishing": "publish",
}.get(state)
@app.post("/api/music/pipeline/{pid}/feedback", status_code=202)
async def feedback(pid: int, req: FeedbackRequest, bg: BackgroundTasks):
p = _db_module.get_pipeline(pid)
if not p:
raise HTTPException(404)
if p["state"] == "awaiting_manual":
raise HTTPException(409, "수동 개입 대기 중")
state = p["state"]
expected = f"{req.step}_pending"
if state != expected:
# 멱등 처리 — 이미 다음 단계로 넘어갔으면 무시
return {"ok": True, "skipped": True}
if req.intent == "approve":
from .pipeline.state_machine import next_state_on_approve
next_st = next_state_on_approve(state)
_db_module.update_pipeline_state(pid, next_st)
next_step = _state_to_step(next_st)
if next_step:
bg.add_task(orchestrator.run_step, pid, next_step)
return {"ok": True}
elif req.intent == "reject":
count = _db_module.increment_feedback_count(pid, req.step)
if count > 5:
_db_module.update_pipeline_state(pid, "awaiting_manual")
raise HTTPException(409, "재생성 한도 초과")
if req.feedback_text:
_db_module.record_feedback(pid, req.step, req.feedback_text)
bg.add_task(orchestrator.run_step, pid, req.step, req.feedback_text or "")
return {"ok": True}
else:
raise HTTPException(400, f"unknown intent: {req.intent}")
@app.post("/api/music/pipeline/{pid}/cancel")
def cancel_pipeline(pid: int):
p = _db_module.get_pipeline(pid)
if not p:
raise HTTPException(404)
_db_module.update_pipeline_state(pid, "cancelled", cancelled_at=_db_module._now())
return {"ok": True}
@app.post("/api/music/pipeline/{pid}/publish", status_code=202)
async def publish_pipeline(pid: int, bg: BackgroundTasks):
p = _db_module.get_pipeline(pid)
if not p:
raise HTTPException(404)
if p["state"] != "publish_pending":
raise HTTPException(409, f"발행 단계 아님 ({p['state']})")
_db_module.update_pipeline_state(pid, "publishing")
bg.add_task(orchestrator.run_step, pid, "publish")
return {"ok": True}
# Telegram 메시지 매칭용 엔드포인트 (agent-office용)
class TelegramMsgPatch(BaseModel):
step: str
message_id: int
@app.patch("/api/music/pipeline/{pid}/telegram-msg")
def save_telegram_msg(pid: int, req: TelegramMsgPatch):
p = _db_module.get_pipeline(pid)
if not p:
raise HTTPException(404)
ids = p["last_telegram_msg_ids"]
ids[req.step] = req.message_id
_db_module.update_pipeline_state(
pid, p["state"], last_telegram_msg_ids=json.dumps(ids)
)
return {"ok": True}
# ── Setup endpoints ───────────────────────────────────────────────────────────
class SetupRequest(BaseModel):
metadata_template: Optional[Dict[str, Any]] = None
cover_prompts: Optional[Dict[str, Any]] = None
review_weights: Optional[Dict[str, Any]] = None
review_threshold: Optional[int] = None
visual_defaults: Optional[Dict[str, Any]] = None
publish_policy: Optional[Dict[str, Any]] = None
@app.get("/api/music/setup")
def get_setup():
return _db_module.get_youtube_setup()
@app.put("/api/music/setup")
def put_setup(req: SetupRequest):
payload = {k: v for k, v in req.dict().items() if v is not None}
_db_module.update_youtube_setup(**payload)
return _db_module.get_youtube_setup()
# ── YouTube OAuth endpoints ───────────────────────────────────────────────────
@app.get("/api/music/youtube/auth-url")
def youtube_auth_url():
return {"url": yt_module.get_auth_url()}
@app.get("/api/music/youtube/callback")
async def youtube_callback(code: str):
return await yt_module.exchange_code(code)
@app.post("/api/music/youtube/disconnect")
def youtube_disconnect():
yt_module.disconnect()
return {"ok": True}
@app.get("/api/music/youtube/status")
def youtube_status():
return yt_module.get_status() or {"connected": False}

View File

@@ -0,0 +1,183 @@
"""파이프라인 오케스트레이터 — 단계별 BackgroundTask 등록 및 산출물 → DB 반영."""
import json
import logging
import os
import sqlite3
from app import db
from . import cover, video, thumb, metadata, review, youtube
logger = logging.getLogger("music-lab.orchestrator")
async def run_step(pipeline_id: int, step: str, feedback: str = "") -> None:
"""단계 실행 → 결과를 DB에 반영하고 *_pending 또는 다음 단계로 전이.
호출 직후 _running 상태로 전환, 끝나면 _pending(사용자 게이트) 또는 자동 다음.
실패 시 failed 상태 + reason.
"""
job_id = db.create_pipeline_job(pipeline_id, step)
db.update_pipeline_job(job_id, status="running")
p = db.get_pipeline(pipeline_id)
track = _get_track(p["track_id"])
try:
if step == "cover":
result = await _run_cover(p, track, feedback)
elif step == "video":
result = await _run_video(p, track)
elif step == "thumb":
result = await _run_thumb(p, track, feedback)
elif step == "meta":
result = await _run_meta(p, track, feedback)
elif step == "review":
result = await _run_review(p, track)
elif step == "publish":
result = await _run_publish(p, track)
else:
raise ValueError(f"unknown step: {step}")
db.update_pipeline_job(job_id, status="succeeded")
db.update_pipeline_state(pipeline_id, result["next_state"], **result.get("fields", {}))
except Exception as e:
logger.exception("step %s failed for pipeline %s", step, pipeline_id)
db.update_pipeline_job(job_id, status="failed", error=str(e))
db.update_pipeline_state(pipeline_id, "failed", failed_reason=f"{step}: {e}")
def _get_track(track_id: int) -> dict:
# tracks 테이블 헬퍼 — 기존 db에 있는 함수 사용
t = None
if hasattr(db, "get_track_by_id"):
t = db.get_track_by_id(track_id)
elif hasattr(db, "get_track"):
t = db.get_track(track_id)
if not t:
# 폴백: music_library 테이블에서 직접 (스키마 확인 필요)
t = _fetch_track_fallback(track_id)
if not t:
raise ValueError(f"트랙 {track_id} 없음")
return t
def _fetch_track_fallback(track_id: int) -> dict | None:
"""db 모듈에 get_track이 없을 때 대비 — music_library 테이블 직접 조회."""
try:
conn = sqlite3.connect(db.DB_PATH)
conn.row_factory = sqlite3.Row
# 가능한 테이블/컬럼 시도 (music_library 또는 tracks)
for table in ("music_library", "tracks"):
try:
row = conn.execute(f"SELECT * FROM {table} WHERE id = ?", (track_id,)).fetchone()
if row:
d = dict(row)
# JSON 컬럼 파싱 (있으면)
for k in ("moods", "instruments"):
if k in d and isinstance(d[k], str):
try:
d[k] = json.loads(d[k])
except (json.JSONDecodeError, TypeError):
d[k] = []
conn.close()
return d
except sqlite3.OperationalError:
continue
conn.close()
except Exception as e:
logger.warning("track fallback fetch 실패: %s", e)
return None
async def _run_cover(p, track, feedback):
setup = db.get_youtube_setup()
prompts = setup["cover_prompts"]
template = prompts.get(track.get("genre", "default").lower(), prompts.get("default", ""))
out = await cover.generate(
pipeline_id=p["id"], genre=track.get("genre", "default"),
prompt_template=template,
mood=", ".join(track.get("moods", []) or []),
track_title=track.get("title", ""),
feedback=feedback,
)
return {"next_state": "cover_pending", "fields": {"cover_url": out["url"]}}
async def _run_video(p, track):
setup = db.get_youtube_setup()
vd = setup["visual_defaults"]
audio_path = _local_path(track.get("audio_url", ""))
cover_path = _local_path(p["cover_url"])
out = video.generate(
pipeline_id=p["id"], audio_path=audio_path, cover_path=cover_path,
genre=track.get("genre", "default"),
duration_sec=track.get("duration_sec", 120),
resolution=vd["resolution"], style=vd["style"],
)
return {"next_state": "video_pending", "fields": {"video_url": out["url"]}}
async def _run_thumb(p, track, feedback):
video_path = _local_path(p["video_url"])
out = thumb.generate(pipeline_id=p["id"], video_path=video_path,
track_title=track.get("title", ""), overlay_text=True)
return {"next_state": "thumb_pending", "fields": {"thumbnail_url": out["url"]}}
async def _run_meta(p, track, feedback):
setup = db.get_youtube_setup()
trend_top = _get_trend_top()
out = await metadata.generate(
track=track, template=setup["metadata_template"],
trend_keywords=trend_top, feedback=feedback,
)
return {"next_state": "meta_pending",
"fields": {"metadata_json": json.dumps(out, ensure_ascii=False)}}
async def _run_review(p, track):
setup = db.get_youtube_setup()
meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {}
result = await review.run_4_axis(
pipeline=p, track=track,
video_meta={"length_sec": track.get("duration_sec", 120),
"resolution": setup["visual_defaults"]["resolution"]},
metadata=meta, thumbnail_url=p.get("thumbnail_url", ""),
trend_top=_get_trend_top(),
weights=setup["review_weights"], threshold=setup["review_threshold"],
)
return {"next_state": "publish_pending",
"fields": {"review_json": json.dumps(result, ensure_ascii=False)}}
async def _run_publish(p, track):
setup = db.get_youtube_setup()
meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {}
privacy = setup["publish_policy"].get("privacy", "private")
result = youtube.upload_video(
video_path=_local_path(p["video_url"]),
thumbnail_path=_local_path(p["thumbnail_url"]) if p.get("thumbnail_url") else None,
metadata=meta, privacy=privacy,
)
return {"next_state": "published",
"fields": {"youtube_video_id": result["video_id"]}}
def _local_path(media_url: str) -> str:
""" /media/videos/123/cover.jpg → /app/data/videos/123/cover.jpg """
if not media_url:
return ""
base_media = os.getenv("VIDEO_MEDIA_BASE", "/media/videos")
base_data = os.getenv("VIDEO_DATA_DIR", "/app/data/videos")
if media_url.startswith(base_media):
return media_url.replace(base_media, base_data, 1)
# /media/music/abc.mp3 → /app/data/music/abc.mp3
return media_url.replace("/media/", "/app/data/", 1)
def _get_trend_top(n: int = 10) -> list[str]:
try:
if hasattr(db, "get_market_trends"):
rows = db.get_market_trends(days=7)
return [r.get("genre", "") for r in rows[:n] if r.get("genre")]
except Exception:
pass
return []

View File

@@ -0,0 +1,110 @@
import sqlite3
import pytest
from unittest.mock import AsyncMock, patch
from fastapi.testclient import TestClient
from app.main import app
from app import db
@pytest.fixture
def client(monkeypatch, tmp_path):
monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
db.init_db()
# 최소 트랙 1개 — music_library 테이블에 직접 삽입
conn = sqlite3.connect(db.DB_PATH)
cur = conn.cursor()
cur.execute(
"""INSERT INTO music_library
(id, title, genre, moods, instruments, duration_sec, bpm, key, scale,
prompt, audio_url, file_path, task_id, tags)
VALUES (1, 'T', 'lo-fi', '["chill"]', '["piano"]', 120, 85, 'C', 'maj',
'p', '/media/music/x.mp3', '/app/data/music/x.mp3', NULL, '[]')""",
)
conn.commit()
conn.close()
return TestClient(app)
def test_create_pipeline(client):
r = client.post("/api/music/pipeline", json={"track_id": 1})
assert r.status_code == 201
assert r.json()["state"] == "created"
def test_create_duplicate_pipeline_returns_409(client):
client.post("/api/music/pipeline", json={"track_id": 1})
r = client.post("/api/music/pipeline", json={"track_id": 1})
assert r.status_code == 409
def test_get_pipeline_returns_jobs_and_feedback(client):
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
r = client.get(f"/api/music/pipeline/{pid}")
assert "jobs" in r.json()
assert "feedback" in r.json()
def test_list_pipelines_active_filter(client):
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
db.update_pipeline_state(pid, "published")
r = client.get("/api/music/pipeline?status=active")
assert all(p["state"] != "published" for p in r.json()["pipelines"])
def test_feedback_reject_records_feedback_and_increments_count(client):
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
db.update_pipeline_state(pid, "cover_pending")
# orchestrator.run_step를 mock해서 백그라운드 작업이 cover_pending을 변경하지 않도록
with patch("app.main.orchestrator.run_step", new=AsyncMock()):
r = client.post(
f"/api/music/pipeline/{pid}/feedback",
json={"step": "cover", "intent": "reject", "feedback_text": "더 어둡게"},
)
assert r.status_code == 202
p = db.get_pipeline(pid)
assert p["feedback_count_per_step"]["cover"] == 1
history = db.get_feedback_history(pid)
assert history[0]["feedback_text"] == "더 어둡게"
def test_feedback_after_5_rejects_marks_awaiting_manual(client):
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
db.update_pipeline_state(pid, "cover_pending")
with patch("app.main.orchestrator.run_step", new=AsyncMock()):
for i in range(5):
client.post(
f"/api/music/pipeline/{pid}/feedback",
json={"step": "cover", "intent": "reject", "feedback_text": f"again {i}"},
)
r = client.post(
f"/api/music/pipeline/{pid}/feedback",
json={"step": "cover", "intent": "reject", "feedback_text": "6th"},
)
assert r.status_code == 409
assert db.get_pipeline(pid)["state"] == "awaiting_manual"
def test_cancel_pipeline(client):
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
r = client.post(f"/api/music/pipeline/{pid}/cancel")
assert r.status_code == 200
assert db.get_pipeline(pid)["state"] == "cancelled"
def test_setup_get_returns_defaults(client):
r = client.get("/api/music/setup")
assert r.status_code == 200
assert r.json()["review_threshold"] == 60
def test_setup_put_updates(client):
r = client.put("/api/music/setup", json={"review_threshold": 70})
assert r.status_code == 200
assert r.json()["review_threshold"] == 70
def test_youtube_status_when_disconnected(client):
r = client.get("/api/music/youtube/status")
assert r.status_code == 200
assert r.json() == {"connected": False}