Files
web-page-backend/music-lab/app/pipeline/orchestrator.py

304 lines
12 KiB
Python

"""파이프라인 오케스트레이터 — 단계별 BackgroundTask 등록 및 산출물 → DB 반영."""
import asyncio
import json
import logging
import os
import sqlite3
from app import db
from . import cover, video, thumb, metadata, review, youtube, background, storage
from .gradient import make_gradient_with_title
logger = logging.getLogger("music-lab.orchestrator")
async def run_step(pipeline_id: int, step: str, feedback: str = "") -> None:
    """Run one pipeline step and persist its outcome to the DB.

    A job row is created for the step and marked ``running``. On success the
    job is marked ``succeeded`` and the pipeline moves to the ``next_state``
    returned by the step runner, together with any produced ``fields``. On
    failure both the job and the pipeline are marked ``failed`` with a
    ``"<step>: <error>"`` reason; nothing is re-raised to the caller.

    Args:
        pipeline_id: ID of the pipeline to advance.
        step: One of "cover", "video", "thumb", "meta", "review", "publish".
        feedback: Optional feedback text, forwarded to the cover, thumb and
            meta runners only.
    """
    job_id = db.create_pipeline_job(pipeline_id, step)
    db.update_pipeline_job(job_id, status="running")

    try:
        p = db.get_pipeline(pipeline_id)
        # A missing pipeline used to surface as an AttributeError inside
        # _resolve_input, bypassing this handler and leaving the job "running".
        if not p:
            raise ValueError(f"pipeline {pipeline_id} not found")
        ctx = _resolve_input(p)
    except ValueError as e:
        db.update_pipeline_job(job_id, status="failed", error=str(e))
        db.update_pipeline_state(pipeline_id, "failed", failed_reason=f"{step}: {e}")
        return

    # Lambdas defer the call so only the selected runner's coroutine is created.
    runners = {
        "cover": lambda: _run_cover(p, ctx, feedback),
        "video": lambda: _run_video(p, ctx),
        "thumb": lambda: _run_thumb(p, ctx, feedback),
        "meta": lambda: _run_meta(p, ctx, feedback),
        "review": lambda: _run_review(p, ctx),
        "publish": lambda: _run_publish(p, ctx),
    }

    try:
        runner = runners.get(step)
        if runner is None:
            raise ValueError(f"unknown step: {step}")
        result = await runner()
        db.update_pipeline_job(job_id, status="succeeded")
        db.update_pipeline_state(pipeline_id, result["next_state"], **result.get("fields", {}))
    except Exception as e:
        # Top-level boundary of the background task: record the failure
        # instead of letting it escape.
        logger.exception("step %s failed for pipeline %s", step, pipeline_id)
        db.update_pipeline_job(job_id, status="failed", error=str(e))
        db.update_pipeline_state(pipeline_id, "failed", failed_reason=f"{step}: {e}")
def _resolve_input(p: dict) -> dict:
    """Resolve the pipeline input: either a compile (mix) result or a single track.

    A compile job takes precedence when ``compile_job_id`` is set.

    Returns:
        A dict with the keys:
            "audio_path": str   (described upstream as a container absolute path)
            "duration_sec": int
            "tracks": list of {"id", "title", "start_offset_sec", "duration_sec"}
            "title": str
            "genre": str        ("mix" for a compile job)
            "moods": list[str]

    Raises:
        ValueError: if neither ``track_id`` nor ``compile_job_id`` is set, the
            compile job is missing or not ``succeeded``, or the track is missing.
    """
    track_id = p.get("track_id")
    compile_id = p.get("compile_job_id")
    if track_id is None and compile_id is None:
        raise ValueError("track_id 또는 compile_job_id 중 하나는 필요")

    if compile_id is not None:
        job = db.get_compile_job(compile_id)
        if not job or job.get("status") != "succeeded":
            raise ValueError(
                f"compile job {compile_id} not ready "
                f"(status={job.get('status') if job else None})"
            )

        crossfade = job.get("crossfade_sec", 0) or 0
        tracks = []
        offset = 0.0
        for tid in job.get("track_ids") or []:
            t = db.get_track_by_id(tid)
            if not t:
                # Tracks that no longer resolve are skipped, not treated as fatal.
                continue
            # `or` (not a .get default) so a key present with value None also
            # falls back; a None duration would break the arithmetic below.
            dur = t.get("duration_sec") or 0
            tracks.append({
                "id": tid,
                "title": t.get("title") or "",
                "start_offset_sec": int(offset),
                "duration_sec": dur,
            })
            # Each track overlaps the next one by the crossfade length.
            offset += dur - crossfade

        # The last track plays in full, so add back the crossfade subtracted for it.
        total = int(offset + crossfade) if tracks else 0
        return {
            "audio_path": job.get("audio_path") or job.get("output_path") or "",
            "duration_sec": total,
            "tracks": tracks,
            "title": job.get("title") or "Mix",
            "genre": "mix",
            "moods": [],
        }

    # Single track
    t = db.get_track_by_id(track_id)
    if not t:
        raise ValueError(f"track {track_id} 없음")

    title = t.get("title") or ""
    duration = t.get("duration_sec") or 0
    return {
        "audio_path": t.get("file_path") or _local_path(t.get("audio_url", "")),
        "duration_sec": duration,
        "tracks": [{
            "id": t["id"],
            "title": title,
            "start_offset_sec": 0,
            "duration_sec": duration,
        }],
        "title": title,
        # A None genre would later fail on `.lower()` when picking a cover prompt.
        "genre": t.get("genre") or "default",
        "moods": t.get("moods", []) or [],
    }
def _get_track(track_id: int) -> dict:
    """Load a track through whichever accessor the db module exposes.

    Tries ``db.get_track_by_id`` first, otherwise ``db.get_track``. If neither
    exists or the lookup comes back empty, falls back to a direct table query
    via ``_fetch_track_fallback``.

    Raises:
        ValueError: if the track cannot be found by any route.
    """
    found = None
    for accessor in ("get_track_by_id", "get_track"):
        if hasattr(db, accessor):
            # Only the first accessor that exists is consulted.
            found = getattr(db, accessor)(track_id)
            break

    # Fallback: query the table directly (the original note says the schema
    # still needs to be confirmed).
    found = found or _fetch_track_fallback(track_id)
    if found:
        return found
    raise ValueError(f"트랙 {track_id} 없음")
def _fetch_track_fallback(track_id: int) -> dict | None:
    """Query the track row directly from SQLite when the db module has no getter.

    Tries the ``music_library`` table, then ``tracks``. A table that cannot be
    queried (``sqlite3.OperationalError``) is skipped. String values in the
    ``moods`` and ``instruments`` columns are parsed as JSON, falling back to
    an empty list when they do not parse.

    Returns:
        The row as a dict, or None when no row is found or the lookup fails.
        Failures are logged as warnings rather than raised.
    """
    conn = None
    try:
        conn = sqlite3.connect(db.DB_PATH)
        conn.row_factory = sqlite3.Row
        # Table names come from this fixed tuple, never from input, so the
        # f-string below is not an injection vector.
        for table in ("music_library", "tracks"):
            try:
                row = conn.execute(
                    f"SELECT * FROM {table} WHERE id = ?", (track_id,)
                ).fetchone()
            except sqlite3.OperationalError:
                continue
            if not row:
                continue

            record = dict(row)
            # Parse JSON columns when present.
            for key in ("moods", "instruments"):
                if isinstance(record.get(key), str):
                    try:
                        record[key] = json.loads(record[key])
                    except (json.JSONDecodeError, TypeError):
                        record[key] = []
            return record
    except Exception as e:
        # Best-effort lookup: report and fall through to None.
        logger.warning("track fallback fetch 실패: %s", e)
    finally:
        # Close on every path; previously an unexpected exception leaked the
        # connection.
        if conn is not None:
            conn.close()
    return None
async def _run_cover(p, ctx, feedback):
    """Produce the cover image for the pipeline and move it to ``cover_pending``.

    The background mode and keyword come from the pipeline row, falling back
    to the ``visual_defaults`` of the YouTube setup. In "video_loop" mode a
    loop video is fetched and the cover is a gradient with the title; in every
    other mode the cover comes from ``cover.generate``.

    Returns:
        {"next_state": "cover_pending", "fields": {"cover_url": ...}}
    """
    setup = db.get_youtube_setup()
    visual = setup["visual_defaults"]
    mode = p.get("background_mode") or visual.get("default_background_mode", "static")
    search_term = p.get("background_keyword") or visual.get("default_background_keyword", "")

    if mode != "video_loop":
        # Static mode: the regular cover.generate flow.
        prompts = setup["cover_prompts"]
        genre = ctx["genre"]
        template = prompts.get(genre.lower(), prompts.get("default", ""))
        source = visual.get("background_image_source", "ai")
        generated = await cover.generate(
            pipeline_id=p["id"],
            genre=genre,
            prompt_template=template,
            mood=", ".join(ctx["moods"] or []),
            track_title=ctx["title"],
            feedback=feedback,
            image_source=source,
            background_keyword=search_term,
        )
        return {"next_state": "cover_pending", "fields": {"cover_url": generated["url"]}}

    # Video-loop mode: fetch the loop video, then always render a gradient
    # cover.jpg as well. _run_video reads cover_url and only uses loop.mp4
    # when that file exists.
    await background.fetch_video_loop(p["id"], search_term)
    cover_file = os.path.join(storage.pipeline_dir(p["id"]), "cover.jpg")
    make_gradient_with_title(ctx["genre"], ctx["title"], cover_file)
    cover_url = storage.media_url(p["id"], "cover.jpg")
    return {"next_state": "cover_pending", "fields": {"cover_url": cover_url}}
async def _run_video(p, ctx):
    """Render the pipeline video in a worker thread and move to ``video_pending``.

    Style and background mode come from the pipeline row, falling back to the
    setup's ``visual_defaults``. In "video_loop" mode the pipeline's
    ``loop.mp4`` is passed as the background only if the file exists. The
    track list is passed only when there is more than one track.

    Returns:
        {"next_state": "video_pending", "fields": {"video_url": ...}}
    """
    setup = db.get_youtube_setup()
    defaults = setup["visual_defaults"]
    audio_file = ctx["audio_path"]
    cover_file = _local_path(p["cover_url"])
    visual_style = p.get("visual_style") or defaults.get("default_visual_style", "essential")
    mode = p.get("background_mode") or defaults.get("default_background_mode", "static")

    loop_file = None
    if mode == "video_loop":
        candidate = os.path.join(storage.pipeline_dir(p["id"]), "loop.mp4")
        if os.path.isfile(candidate):
            loop_file = candidate

    track_list = ctx["tracks"]
    render_args = {
        "pipeline_id": p["id"],
        "audio_path": audio_file,
        "cover_path": cover_file,
        "genre": ctx["genre"],
        "duration_sec": ctx["duration_sec"],
        "resolution": defaults.get("resolution", "1920x1080"),
        "style": visual_style,
        "background_mode": mode,
        "background_path": loop_file,
        "tracks": track_list if len(track_list) > 1 else None,
    }
    # video.generate is a synchronous call, so it runs off the event loop.
    rendered = await asyncio.to_thread(video.generate, **render_args)
    return {"next_state": "video_pending", "fields": {"video_url": rendered["url"]}}
async def _run_thumb(p, ctx, feedback):
    """Generate the thumbnail from the rendered video and move to ``thumb_pending``.

    ``feedback`` is accepted for signature parity with the other runners but
    is not used here.

    Returns:
        {"next_state": "thumb_pending", "fields": {"thumbnail_url": ...}}
    """
    thumb_args = {
        "pipeline_id": p["id"],
        "video_path": _local_path(p["video_url"]),
        "track_title": ctx["title"],
        "overlay_text": True,
    }
    # thumb.generate is a synchronous call, so it runs in a worker thread.
    generated = await asyncio.to_thread(thumb.generate, **thumb_args)
    return {"next_state": "thumb_pending", "fields": {"thumbnail_url": generated["url"]}}
async def _run_meta(p, ctx, feedback):
    """Generate the upload metadata and move to ``meta_pending``.

    Passes the track summary, the setup's metadata template, the current trend
    keywords and the feedback to ``metadata.generate``. The track list is
    passed only when there is more than one track.

    Returns:
        {"next_state": "meta_pending", "fields": {"metadata_json": <JSON str>}}
    """
    setup = db.get_youtube_setup()
    trending = _get_trend_top()
    track_list = ctx["tracks"]
    track_info = {key: ctx[key] for key in ("title", "genre", "duration_sec", "moods")}

    generated = await metadata.generate(
        track=track_info,
        template=setup["metadata_template"],
        trend_keywords=trending,
        feedback=feedback,
        tracks=track_list if len(track_list) > 1 else None,
    )
    # ensure_ascii=False keeps non-ASCII text readable in the stored JSON.
    serialized = json.dumps(generated, ensure_ascii=False)
    return {"next_state": "meta_pending", "fields": {"metadata_json": serialized}}
async def _run_review(p, ctx):
    """Run the 4-axis review and move the pipeline to ``publish_pending``.

    The review weights and threshold from the setup are handed to
    ``review.run_4_axis``. This function does not evaluate the result itself:
    it always returns ``publish_pending`` as the next state.

    Returns:
        {"next_state": "publish_pending", "fields": {"review_json": <JSON str>}}
    """
    setup = db.get_youtube_setup()
    meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {}
    # Same fallback as _run_video, so a setup without an explicit resolution
    # does not fail the review step alone.
    resolution = setup["visual_defaults"].get("resolution", "1920x1080")

    result = await review.run_4_axis(
        pipeline=p,
        track={
            "title": ctx["title"],
            "genre": ctx["genre"],
            "duration_sec": ctx["duration_sec"],
        },
        video_meta={"length_sec": ctx["duration_sec"], "resolution": resolution},
        metadata=meta,
        thumbnail_url=p.get("thumbnail_url", ""),
        trend_top=_get_trend_top(),
        weights=setup["review_weights"],
        threshold=setup["review_threshold"],
    )
    return {
        "next_state": "publish_pending",
        "fields": {"review_json": json.dumps(result, ensure_ascii=False)},
    }
async def _run_publish(p, ctx):
    """Upload the video to YouTube in a worker thread and move to ``published``.

    The privacy setting comes from the setup's ``publish_policy`` and defaults
    to "private". The thumbnail path is passed only when the pipeline has a
    thumbnail URL.

    Returns:
        {"next_state": "published", "fields": {"youtube_video_id": ...}}
    """
    setup = db.get_youtube_setup()
    raw_meta = p.get("metadata_json")
    meta = json.loads(raw_meta) if raw_meta else {}
    privacy = setup["publish_policy"].get("privacy", "private")

    video_file = _local_path(p["video_url"])
    thumb_url = p.get("thumbnail_url")
    thumb_file = _local_path(thumb_url) if thumb_url else None

    # youtube.upload_video is a synchronous call, so it runs off the event loop.
    uploaded = await asyncio.to_thread(
        youtube.upload_video,
        video_path=video_file,
        thumbnail_path=thumb_file,
        metadata=meta,
        privacy=privacy,
    )
    return {"next_state": "published", "fields": {"youtube_video_id": uploaded["video_id"]}}
def _local_path(media_url: str) -> str:
    """Map a served media URL path to its on-disk path.

    Examples (with the default environment):
        /media/videos/123/cover.jpg -> /app/data/videos/123/cover.jpg
        /media/music/abc.mp3        -> /app/data/abc.mp3
            (music is mounted at /app/data with no subdirectory)
        /media/other/x.png          -> /app/data/other/x.png

    The video prefix and its target directory can be overridden with the
    ``VIDEO_MEDIA_BASE`` and ``VIDEO_DATA_DIR`` environment variables.

    Returns:
        The mapped path, "" for an empty input, or the input unchanged when it
        does not start with a known media prefix.
    """
    if not media_url:
        return ""

    base_media = os.getenv("VIDEO_MEDIA_BASE", "/media/videos")
    base_data = os.getenv("VIDEO_DATA_DIR", "/app/data/videos")
    if media_url.startswith(base_media):
        return media_url.replace(base_media, base_data, 1)

    if media_url.startswith("/media/music/"):
        return "/app/data/" + media_url[len("/media/music/"):]

    # Rewrite only a leading "/media/". The previous unanchored replace also
    # rewrote "/media/" in the middle of an already-local path.
    if media_url.startswith("/media/"):
        return "/app/data/" + media_url[len("/media/"):]
    return media_url
def _get_trend_top(n: int = 10) -> list[str]:
    """Return the non-empty genres among the first ``n`` market-trend rows.

    Reads ``db.get_market_trends(days=7)`` when the db module provides it.
    Rows without a genre are dropped after slicing, so the result may hold
    fewer than ``n`` entries.

    Returns:
        A list of genre strings, or [] when the db function is missing or the
        lookup fails. The lookup is best-effort and never raises.
    """
    try:
        if hasattr(db, "get_market_trends"):
            rows = db.get_market_trends(days=7)
            return [genre for row in rows[:n] if (genre := row.get("genre"))]
    except Exception as e:
        # Still best-effort, but leave a trace instead of swallowing silently.
        logger.warning("market trend lookup failed: %s", e)
    return []