Compare commits
13 Commits
e03d074222
...
4f67cd02fa
| Author | SHA1 | Date | |
|---|---|---|---|
| 4f67cd02fa | |||
| 868906b8c6 | |||
| bd97cc1e97 | |||
| 7552ce4263 | |||
| 17034ea6ea | |||
| fe60c8d330 | |||
| 4755e34c14 | |||
| ad1c721ba8 | |||
| 1c705b0ef3 | |||
| 68dec2e53d | |||
| e33a2310af | |||
| fceca88db4 | |||
| d66a321982 |
@@ -4,6 +4,7 @@ from .blog import BlogAgent
|
|||||||
from .realestate import RealestateAgent
|
from .realestate import RealestateAgent
|
||||||
from .lotto import LottoAgent
|
from .lotto import LottoAgent
|
||||||
from .youtube import YouTubeResearchAgent
|
from .youtube import YouTubeResearchAgent
|
||||||
|
from .youtube_publisher import YoutubePublisherAgent
|
||||||
|
|
||||||
AGENT_REGISTRY = {}
|
AGENT_REGISTRY = {}
|
||||||
|
|
||||||
@@ -14,6 +15,7 @@ def init_agents():
|
|||||||
AGENT_REGISTRY["realestate"] = RealestateAgent()
|
AGENT_REGISTRY["realestate"] = RealestateAgent()
|
||||||
AGENT_REGISTRY["lotto"] = LottoAgent()
|
AGENT_REGISTRY["lotto"] = LottoAgent()
|
||||||
AGENT_REGISTRY["youtube"] = YouTubeResearchAgent()
|
AGENT_REGISTRY["youtube"] = YouTubeResearchAgent()
|
||||||
|
AGENT_REGISTRY["youtube_publisher"] = YoutubePublisherAgent()
|
||||||
|
|
||||||
def get_agent(agent_id: str):
|
def get_agent(agent_id: str):
|
||||||
return AGENT_REGISTRY.get(agent_id)
|
return AGENT_REGISTRY.get(agent_id)
|
||||||
|
|||||||
75
agent-office/app/agents/classify_intent.py
Normal file
75
agent-office/app/agents/classify_intent.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""텔레그램 사용자 응답 자연어 분류 — 화이트리스트 우선, 모호 시 LLM."""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger("agent-office.classify_intent")
|
||||||
|
|
||||||
|
CLAUDE_HAIKU_DEFAULT = "claude-haiku-4-5-20251001"
|
||||||
|
|
||||||
|
APPROVE_WORDS = {
|
||||||
|
"승인", "시작", "진행", "ok", "okay", "agree",
|
||||||
|
"네", "예", "좋아", "좋아요", "go", "yes", "y",
|
||||||
|
}
|
||||||
|
REJECT_WORDS = {"반려", "거절", "취소", "no", "nope", "n"}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_api_key() -> str:
|
||||||
|
return os.getenv("ANTHROPIC_API_KEY", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model() -> str:
|
||||||
|
return os.getenv("CLAUDE_HAIKU_MODEL", CLAUDE_HAIKU_DEFAULT)
|
||||||
|
|
||||||
|
|
||||||
|
def classify(text: str) -> tuple[str, str | None]:
|
||||||
|
"""returns (intent, feedback) — intent ∈ {approve, reject, unclear}"""
|
||||||
|
if not text:
|
||||||
|
return ("unclear", None)
|
||||||
|
t = text.strip().lower()
|
||||||
|
if t in APPROVE_WORDS:
|
||||||
|
return ("approve", None)
|
||||||
|
if t in REJECT_WORDS:
|
||||||
|
return ("reject", None)
|
||||||
|
# 반려 단어로 시작 + 추가 텍스트
|
||||||
|
for w in REJECT_WORDS:
|
||||||
|
if t.startswith(w):
|
||||||
|
rest = text.strip()[len(w):].lstrip(" ,.-:").strip()
|
||||||
|
if rest:
|
||||||
|
return ("reject", rest)
|
||||||
|
# 승인 단어로 시작 (긍정 의도면 추가 텍스트 무시)
|
||||||
|
for w in APPROVE_WORDS:
|
||||||
|
if t.startswith(w + " ") or t == w:
|
||||||
|
return ("approve", None)
|
||||||
|
return _llm_classify(text)
|
||||||
|
|
||||||
|
|
||||||
|
def _llm_classify(text: str) -> tuple[str, str | None]:
|
||||||
|
api_key = _get_api_key()
|
||||||
|
if not api_key:
|
||||||
|
return ("unclear", None)
|
||||||
|
prompt = (
|
||||||
|
"사용자 응답을 분류하세요. JSON으로만 응답.\n"
|
||||||
|
f'응답: "{text}"\n\n'
|
||||||
|
'출력: {"intent":"approve|reject|unclear","feedback":"반려면 수정 방향, 아니면 빈 문자열"}'
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
resp = httpx.post(
|
||||||
|
"https://api.anthropic.com/v1/messages",
|
||||||
|
headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"},
|
||||||
|
json={"model": _get_model(), "max_tokens": 200,
|
||||||
|
"messages": [{"role": "user", "content": prompt}]},
|
||||||
|
timeout=15,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
text_out = resp.json()["content"][0]["text"]
|
||||||
|
start = text_out.find("{")
|
||||||
|
end = text_out.rfind("}") + 1
|
||||||
|
if start < 0 or end <= start:
|
||||||
|
return ("unclear", None)
|
||||||
|
data = json.loads(text_out[start:end])
|
||||||
|
return (data.get("intent", "unclear"), data.get("feedback") or None)
|
||||||
|
except (httpx.HTTPError, httpx.TimeoutException, KeyError, ValueError, json.JSONDecodeError) as e:
|
||||||
|
logger.warning("LLM 분류 실패: %s", e)
|
||||||
|
return ("unclear", None)
|
||||||
108
agent-office/app/agents/youtube_publisher.py
Normal file
108
agent-office/app/agents/youtube_publisher.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""텔레그램 단일 채널로 단계별 승인 인터랙션 오케스트레이션."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .base import BaseAgent
|
||||||
|
from . import classify_intent
|
||||||
|
from .. import service_proxy
|
||||||
|
from ..db import add_log
|
||||||
|
from ..telegram.messaging import send_raw
|
||||||
|
|
||||||
|
logger = logging.getLogger("agent-office.youtube_publisher")
|
||||||
|
|
||||||
|
|
||||||
|
_STEP_TITLES = {
|
||||||
|
"cover_pending": ("커버 아트", "cover"),
|
||||||
|
"video_pending": ("영상 비주얼", "video"),
|
||||||
|
"thumb_pending": ("썸네일", "thumb"),
|
||||||
|
"meta_pending": ("메타데이터", "meta"),
|
||||||
|
"publish_pending": ("최종 검토 + 발행", "publish"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class YoutubePublisherAgent(BaseAgent):
|
||||||
|
agent_id = "youtube_publisher"
|
||||||
|
display_name = "YouTube 퍼블리셔"
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._notified_state_per_pipeline: dict[int, str] = {}
|
||||||
|
|
||||||
|
async def poll_state_changes(self) -> None:
|
||||||
|
"""주기적으로 호출되어 *_pending 신규 진입 시 텔레그램 발송."""
|
||||||
|
try:
|
||||||
|
pipelines = await service_proxy.list_active_pipelines()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("폴링 실패: %s", e)
|
||||||
|
return
|
||||||
|
|
||||||
|
for p in pipelines:
|
||||||
|
state = p.get("state")
|
||||||
|
pid = p.get("id")
|
||||||
|
if pid is None:
|
||||||
|
continue
|
||||||
|
if state in _STEP_TITLES and self._notified_state_per_pipeline.get(pid) != state:
|
||||||
|
await self._notify_step(p)
|
||||||
|
self._notified_state_per_pipeline[pid] = state
|
||||||
|
|
||||||
|
async def _notify_step(self, pipeline: dict) -> None:
|
||||||
|
state = pipeline["state"]
|
||||||
|
title_name, step = _STEP_TITLES[state]
|
||||||
|
body = self._format_body(pipeline, step)
|
||||||
|
track_title = pipeline.get("track_title") or f"Pipeline #{pipeline['id']}"
|
||||||
|
text = (
|
||||||
|
f"🎵 [{track_title}] {title_name} 검토\n\n"
|
||||||
|
f"{body}\n\n"
|
||||||
|
f"➡️ 답장으로 알려주세요: '승인' 또는 '반려 + 수정 방향'"
|
||||||
|
)
|
||||||
|
sent = await send_raw(text=text)
|
||||||
|
if sent.get("ok"):
|
||||||
|
msg_id = sent.get("message_id")
|
||||||
|
try:
|
||||||
|
await service_proxy.save_pipeline_telegram_msg(pipeline["id"], step, msg_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("telegram-msg 저장 실패: %s", e)
|
||||||
|
add_log(self.agent_id, f"pipeline {pipeline['id']} {step} 알림 전송", "info")
|
||||||
|
|
||||||
|
def _format_body(self, p: dict, step: str) -> str:
|
||||||
|
if step == "cover":
|
||||||
|
return f"🖼️ 커버: {p.get('cover_url', '-')}"
|
||||||
|
if step == "video":
|
||||||
|
return f"🎬 영상: {p.get('video_url', '-')}"
|
||||||
|
if step == "thumb":
|
||||||
|
return f"🎴 썸네일: {p.get('thumbnail_url', '-')}"
|
||||||
|
if step == "meta":
|
||||||
|
m = p.get("metadata", {}) or {}
|
||||||
|
tags = m.get("tags", []) or []
|
||||||
|
description = (m.get("description", "") or "")
|
||||||
|
return (
|
||||||
|
f"📝 제목: {m.get('title', '')}\n"
|
||||||
|
f"🏷️ 태그: {', '.join(tags[:8])}\n"
|
||||||
|
f"📄 설명(앞부분): {description[:200]}"
|
||||||
|
)
|
||||||
|
if step == "publish":
|
||||||
|
r = p.get("review", {}) or {}
|
||||||
|
return (
|
||||||
|
f"AI 검토 결과: {r.get('verdict', '?')} "
|
||||||
|
f"(가중 {r.get('weighted_total', '?')}/100)\n"
|
||||||
|
f"{r.get('summary', '')}"
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def on_telegram_reply(self, pipeline_id: int, step: str, user_text: str) -> None:
|
||||||
|
intent, feedback = classify_intent.classify(user_text)
|
||||||
|
if intent == "unclear":
|
||||||
|
await send_raw("다시 입력해주세요. 예: '승인' 또는 '반려, 제목 짧게'")
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await service_proxy.post_pipeline_feedback(pipeline_id, step, intent, feedback)
|
||||||
|
except Exception as e:
|
||||||
|
await send_raw(f"⚠️ 처리 실패: {e}")
|
||||||
|
|
||||||
|
async def on_schedule(self) -> None:
|
||||||
|
await self.poll_state_changes()
|
||||||
|
|
||||||
|
async def on_command(self, command: str, params: dict) -> dict:
|
||||||
|
return {"ok": False, "message": f"Unknown command: {command}"}
|
||||||
|
|
||||||
|
async def on_approval(self, task_id: str, approved: bool, feedback: str = "") -> None:
|
||||||
|
pass
|
||||||
@@ -34,6 +34,11 @@ async def _send_youtube_weekly_report():
|
|||||||
if agent:
|
if agent:
|
||||||
await agent.send_weekly_report()
|
await agent.send_weekly_report()
|
||||||
|
|
||||||
|
async def _poll_pipelines():
|
||||||
|
agent = AGENT_REGISTRY.get("youtube_publisher")
|
||||||
|
if agent:
|
||||||
|
await agent.poll_state_changes()
|
||||||
|
|
||||||
def init_scheduler():
|
def init_scheduler():
|
||||||
scheduler.add_job(_run_stock_schedule, "cron", hour=7, minute=30, id="stock_news")
|
scheduler.add_job(_run_stock_schedule, "cron", hour=7, minute=30, id="stock_news")
|
||||||
scheduler.add_job(_run_blog_schedule, "cron", hour=10, minute=0, id="blog_pipeline")
|
scheduler.add_job(_run_blog_schedule, "cron", hour=10, minute=0, id="blog_pipeline")
|
||||||
@@ -41,4 +46,5 @@ def init_scheduler():
|
|||||||
scheduler.add_job(_run_youtube_research, "cron", hour=9, minute=0, id="youtube_research")
|
scheduler.add_job(_run_youtube_research, "cron", hour=9, minute=0, id="youtube_research")
|
||||||
scheduler.add_job(_send_youtube_weekly_report, "cron", day_of_week="mon", hour=8, minute=0, id="youtube_weekly_report")
|
scheduler.add_job(_send_youtube_weekly_report, "cron", day_of_week="mon", hour=8, minute=0, id="youtube_weekly_report")
|
||||||
scheduler.add_job(_check_idle_breaks, "interval", seconds=60, id="idle_check")
|
scheduler.add_job(_check_idle_breaks, "interval", seconds=60, id="idle_check")
|
||||||
|
scheduler.add_job(_poll_pipelines, "interval", seconds=30, id="pipeline_poll")
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
|
|||||||
@@ -178,3 +178,46 @@ async def lotto_save_briefing(payload: dict) -> Dict[str, Any]:
|
|||||||
resp = await _client.post(f"{LOTTO_BACKEND_URL}/api/lotto/briefing", json=payload)
|
resp = await _client.post(f"{LOTTO_BACKEND_URL}/api/lotto/briefing", json=payload)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
# --- music-lab pipeline (YouTube publisher orchestration) ---
|
||||||
|
|
||||||
|
async def list_active_pipelines() -> list[dict]:
|
||||||
|
async with httpx.AsyncClient(timeout=15) as client:
|
||||||
|
resp = await client.get(f"{MUSIC_LAB_URL}/api/music/pipeline?status=active")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json().get("pipelines", [])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_pipeline(pid: int) -> dict:
|
||||||
|
async with httpx.AsyncClient(timeout=15) as client:
|
||||||
|
resp = await client.get(f"{MUSIC_LAB_URL}/api/music/pipeline/{pid}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def post_pipeline_feedback(pid: int, step: str, intent: str,
|
||||||
|
feedback_text: Optional[str] = None) -> dict:
|
||||||
|
async with httpx.AsyncClient(timeout=15) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{MUSIC_LAB_URL}/api/music/pipeline/{pid}/feedback",
|
||||||
|
json={"step": step, "intent": intent, "feedback_text": feedback_text},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def save_pipeline_telegram_msg(pid: int, step: str, msg_id: int) -> None:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
await client.patch(
|
||||||
|
f"{MUSIC_LAB_URL}/api/music/pipeline/{pid}/telegram-msg",
|
||||||
|
json={"step": step, "message_id": msg_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def lookup_pipeline_by_msg(msg_id: int) -> Optional[dict]:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
resp = await client.get(f"{MUSIC_LAB_URL}/api/music/pipeline/lookup-by-msg/{msg_id}")
|
||||||
|
if resp.status_code == 200:
|
||||||
|
return resp.json()
|
||||||
|
return None
|
||||||
|
|||||||
@@ -103,6 +103,34 @@ def _build_messages(history: list, user_text: str) -> list:
|
|||||||
return msgs
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
|
async def maybe_route_to_pipeline(message: dict) -> bool:
|
||||||
|
"""파이프라인 텔레그램 메시지에 대한 reply 인 경우 youtube_publisher 로 라우팅.
|
||||||
|
|
||||||
|
Returns True if message was routed (caller should stop further processing).
|
||||||
|
"""
|
||||||
|
reply_to = message.get("reply_to_message") or {}
|
||||||
|
msg_id = reply_to.get("message_id")
|
||||||
|
if not msg_id:
|
||||||
|
return False
|
||||||
|
from .. import service_proxy
|
||||||
|
try:
|
||||||
|
link = await service_proxy.lookup_pipeline_by_msg(msg_id)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
if not link:
|
||||||
|
return False
|
||||||
|
from ..agents import AGENT_REGISTRY
|
||||||
|
agent = AGENT_REGISTRY.get("youtube_publisher")
|
||||||
|
if not agent:
|
||||||
|
return False
|
||||||
|
pipeline_id = link.get("pipeline_id")
|
||||||
|
step = link.get("step")
|
||||||
|
if pipeline_id is None or not step:
|
||||||
|
return False
|
||||||
|
await agent.on_telegram_reply(pipeline_id, step, message.get("text", ""))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def respond_to_message(chat_id: str, user_text: str) -> Optional[str]:
|
async def respond_to_message(chat_id: str, user_text: str) -> Optional[str]:
|
||||||
"""자연어 메시지에 응답. 실패 시 사용자에게 돌려줄 문자열 반환(또는 None = 무시)."""
|
"""자연어 메시지에 응답. 실패 시 사용자에게 돌려줄 문자열 반환(또는 None = 무시)."""
|
||||||
if not ANTHROPIC_API_KEY:
|
if not ANTHROPIC_API_KEY:
|
||||||
|
|||||||
@@ -102,6 +102,11 @@ async def _handle_message(message: dict, agent_dispatcher) -> Optional[dict]:
|
|||||||
from .router import parse_command, resolve_agent_command, HELP_TEXT
|
from .router import parse_command, resolve_agent_command, HELP_TEXT
|
||||||
from .messaging import send_raw, send_agent_message
|
from .messaging import send_raw, send_agent_message
|
||||||
from .agent_registry import AGENT_META
|
from .agent_registry import AGENT_META
|
||||||
|
from .conversational import maybe_route_to_pipeline
|
||||||
|
|
||||||
|
# 파이프라인 메시지에 대한 reply라면 youtube_publisher 로 라우팅
|
||||||
|
if await maybe_route_to_pipeline(message):
|
||||||
|
return {"handled": "pipeline_reply"}
|
||||||
|
|
||||||
text = message.get("text", "")
|
text = message.get("text", "")
|
||||||
parsed = parse_command(text)
|
parsed = parse_command(text)
|
||||||
|
|||||||
@@ -3,5 +3,6 @@ uvicorn[standard]==0.30.6
|
|||||||
apscheduler==3.10.4
|
apscheduler==3.10.4
|
||||||
websockets>=12.0
|
websockets>=12.0
|
||||||
httpx>=0.27
|
httpx>=0.27
|
||||||
|
respx>=0.21
|
||||||
google-api-python-client>=2.100.0
|
google-api-python-client>=2.100.0
|
||||||
pytrends>=4.9.2
|
pytrends>=4.9.2
|
||||||
|
|||||||
48
agent-office/tests/test_classify_intent.py
Normal file
48
agent-office/tests/test_classify_intent.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
from httpx import Response
|
||||||
|
from app.agents import classify_intent as ci
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_approve_no_llm(monkeypatch):
|
||||||
|
# Patch _llm_classify so we can assert it wasn't called
|
||||||
|
called = {"n": 0}
|
||||||
|
def fake(text):
|
||||||
|
called["n"] += 1
|
||||||
|
return ("unclear", None)
|
||||||
|
monkeypatch.setattr(ci, "_llm_classify", fake)
|
||||||
|
assert ci.classify("승인") == ("approve", None)
|
||||||
|
assert ci.classify("OK") == ("approve", None)
|
||||||
|
assert ci.classify("진행") == ("approve", None)
|
||||||
|
assert ci.classify("agree") == ("approve", None)
|
||||||
|
assert called["n"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_reject_only_no_llm(monkeypatch):
|
||||||
|
monkeypatch.setattr(ci, "_llm_classify", lambda t: ("unclear", None))
|
||||||
|
assert ci.classify("반려") == ("reject", None)
|
||||||
|
assert ci.classify("거절") == ("reject", None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reject_with_text_split(monkeypatch):
|
||||||
|
monkeypatch.setattr(ci, "_llm_classify", lambda t: ("unclear", None))
|
||||||
|
intent, fb = ci.classify("반려, 제목 짧게")
|
||||||
|
assert intent == "reject"
|
||||||
|
assert "제목 짧게" in fb
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_ambiguous_calls_llm(monkeypatch):
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "k")
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
return_value=Response(200, json={"content": [{"type": "text",
|
||||||
|
"text": '{"intent":"reject","feedback":"좀 더 화려하게"}'}]})
|
||||||
|
)
|
||||||
|
intent, fb = ci.classify("음... 좀 더 화려한 분위기가 좋겠어")
|
||||||
|
assert intent == "reject"
|
||||||
|
assert "화려하게" in fb
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_text_returns_unclear():
|
||||||
|
assert ci.classify("") == ("unclear", None)
|
||||||
|
assert ci.classify(None) == ("unclear", None)
|
||||||
110
agent-office/tests/test_pipeline_polling.py
Normal file
110
agent-office/tests/test_pipeline_polling.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
_fd, _TMP = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(_fd)
|
||||||
|
os.unlink(_TMP)
|
||||||
|
os.environ["AGENT_OFFICE_DB_PATH"] = _TMP
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _init_db():
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
if os.path.exists(_TMP):
|
||||||
|
os.remove(_TMP)
|
||||||
|
from app.db import init_db
|
||||||
|
init_db()
|
||||||
|
yield
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_notifies_once_per_state():
|
||||||
|
from app.agents.youtube_publisher import YoutubePublisherAgent
|
||||||
|
|
||||||
|
pipelines = [{
|
||||||
|
"id": 1,
|
||||||
|
"state": "cover_pending",
|
||||||
|
"cover_url": "/x.jpg",
|
||||||
|
"track_title": "Test",
|
||||||
|
}]
|
||||||
|
with patch(
|
||||||
|
"app.agents.youtube_publisher.service_proxy.list_active_pipelines",
|
||||||
|
new=AsyncMock(return_value=pipelines),
|
||||||
|
), patch(
|
||||||
|
"app.agents.youtube_publisher.send_raw",
|
||||||
|
new=AsyncMock(return_value={"ok": True, "message_id": 99}),
|
||||||
|
) as mock_send, patch(
|
||||||
|
"app.agents.youtube_publisher.service_proxy.save_pipeline_telegram_msg",
|
||||||
|
new=AsyncMock(),
|
||||||
|
):
|
||||||
|
a = YoutubePublisherAgent()
|
||||||
|
await a.poll_state_changes()
|
||||||
|
await a.poll_state_changes() # 같은 상태 — 두 번째는 알림 안 함
|
||||||
|
assert mock_send.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_telegram_reply_approve_calls_feedback():
|
||||||
|
from app.agents.youtube_publisher import YoutubePublisherAgent
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.agents.youtube_publisher.service_proxy.post_pipeline_feedback",
|
||||||
|
new=AsyncMock(),
|
||||||
|
) as mock_fb, patch(
|
||||||
|
"app.agents.youtube_publisher.send_raw",
|
||||||
|
new=AsyncMock(),
|
||||||
|
):
|
||||||
|
a = YoutubePublisherAgent()
|
||||||
|
await a.on_telegram_reply(pipeline_id=42, step="cover", user_text="승인")
|
||||||
|
mock_fb.assert_called_once_with(42, "cover", "approve", None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_telegram_reply_reject_with_feedback():
|
||||||
|
from app.agents.youtube_publisher import YoutubePublisherAgent
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.agents.youtube_publisher.service_proxy.post_pipeline_feedback",
|
||||||
|
new=AsyncMock(),
|
||||||
|
) as mock_fb, patch(
|
||||||
|
"app.agents.youtube_publisher.send_raw",
|
||||||
|
new=AsyncMock(),
|
||||||
|
):
|
||||||
|
a = YoutubePublisherAgent()
|
||||||
|
await a.on_telegram_reply(pipeline_id=43, step="meta", user_text="반려, 제목 짧게")
|
||||||
|
args = mock_fb.call_args[0]
|
||||||
|
assert args[0] == 43
|
||||||
|
assert args[1] == "meta"
|
||||||
|
assert args[2] == "reject"
|
||||||
|
assert "제목 짧게" in (args[3] or "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_telegram_reply_unclear_asks_again():
|
||||||
|
from app.agents.youtube_publisher import YoutubePublisherAgent
|
||||||
|
|
||||||
|
sent = []
|
||||||
|
|
||||||
|
async def mock_send(text=None, **kw):
|
||||||
|
sent.append(text)
|
||||||
|
return {"ok": True, "message_id": 1}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.agents.youtube_publisher.send_raw",
|
||||||
|
new=mock_send,
|
||||||
|
), patch(
|
||||||
|
"app.agents.youtube_publisher.classify_intent.classify",
|
||||||
|
return_value=("unclear", None),
|
||||||
|
):
|
||||||
|
a = YoutubePublisherAgent()
|
||||||
|
await a.on_telegram_reply(pipeline_id=44, step="cover", user_text="huh?")
|
||||||
|
assert any("다시 입력" in (s or "") for s in sent)
|
||||||
@@ -66,6 +66,12 @@ services:
|
|||||||
- CORS_ALLOW_ORIGINS=${CORS_ALLOW_ORIGINS:-http://localhost:3007,http://localhost:8080}
|
- CORS_ALLOW_ORIGINS=${CORS_ALLOW_ORIGINS:-http://localhost:3007,http://localhost:8080}
|
||||||
- PEXELS_API_KEY=${PEXELS_API_KEY:-}
|
- PEXELS_API_KEY=${PEXELS_API_KEY:-}
|
||||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
|
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||||
|
- YOUTUBE_OAUTH_CLIENT_ID=${YOUTUBE_OAUTH_CLIENT_ID:-}
|
||||||
|
- YOUTUBE_OAUTH_CLIENT_SECRET=${YOUTUBE_OAUTH_CLIENT_SECRET:-}
|
||||||
|
- YOUTUBE_OAUTH_REDIRECT_URI=${YOUTUBE_OAUTH_REDIRECT_URI:-}
|
||||||
|
- CLAUDE_HAIKU_MODEL=${CLAUDE_HAIKU_MODEL:-claude-haiku-4-5-20251001}
|
||||||
|
- CLAUDE_SONNET_MODEL=${CLAUDE_SONNET_MODEL:-claude-sonnet-4-6}
|
||||||
- VIDEO_DATA_DIR=${VIDEO_DATA_DIR:-/app/data/videos}
|
- VIDEO_DATA_DIR=${VIDEO_DATA_DIR:-/app/data/videos}
|
||||||
volumes:
|
volumes:
|
||||||
- ${RUNTIME_PATH}/data/music:/app/data
|
- ${RUNTIME_PATH}/data/music:/app/data
|
||||||
@@ -137,6 +143,8 @@ services:
|
|||||||
- TELEGRAM_WEBHOOK_URL=${TELEGRAM_WEBHOOK_URL:-}
|
- TELEGRAM_WEBHOOK_URL=${TELEGRAM_WEBHOOK_URL:-}
|
||||||
- TELEGRAM_WIFE_CHAT_ID=${TELEGRAM_WIFE_CHAT_ID:-}
|
- TELEGRAM_WIFE_CHAT_ID=${TELEGRAM_WIFE_CHAT_ID:-}
|
||||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
|
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
|
||||||
|
- CLAUDE_HAIKU_MODEL=${CLAUDE_HAIKU_MODEL:-claude-haiku-4-5-20251001}
|
||||||
|
- CLAUDE_SONNET_MODEL=${CLAUDE_SONNET_MODEL:-claude-sonnet-4-6}
|
||||||
- LOTTO_BACKEND_URL=${LOTTO_BACKEND_URL:-http://lotto:8000}
|
- LOTTO_BACKEND_URL=${LOTTO_BACKEND_URL:-http://lotto:8000}
|
||||||
- LOTTO_CURATOR_MODEL=${LOTTO_CURATOR_MODEL:-claude-sonnet-4-5}
|
- LOTTO_CURATOR_MODEL=${LOTTO_CURATOR_MODEL:-claude-sonnet-4-5}
|
||||||
- CONVERSATION_MODEL=${CONVERSATION_MODEL:-claude-haiku-4-5-20251001}
|
- CONVERSATION_MODEL=${CONVERSATION_MODEL:-claude-haiku-4-5-20251001}
|
||||||
|
|||||||
@@ -1,7 +1,13 @@
|
|||||||
FROM python:3.12-alpine
|
FROM python:3.12-alpine
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
RUN apk add --no-cache ffmpeg
|
# ffmpeg for audio/video processing, ttf-dejavu + fontconfig for PIL overlays.
|
||||||
|
# Alpine installs DejaVu fonts to /usr/share/fonts/dejavu/, but app code
|
||||||
|
# references the Debian-style path; symlink for compatibility.
|
||||||
|
RUN apk add --no-cache ffmpeg ttf-dejavu fontconfig \
|
||||||
|
&& mkdir -p /usr/share/fonts/truetype \
|
||||||
|
&& ln -sf /usr/share/fonts/dejavu /usr/share/fonts/truetype/dejavu \
|
||||||
|
&& fc-cache -f
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import json
|
import json
|
||||||
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
DB_PATH = "/app/data/music.db"
|
DB_PATH = "/app/data/music.db"
|
||||||
@@ -184,6 +185,112 @@ def init_db() -> None:
|
|||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# ── YouTube pipeline 테이블 (5개) ─────────────────────────────────
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS video_pipelines (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
track_id INTEGER NOT NULL,
|
||||||
|
state TEXT NOT NULL DEFAULT 'created',
|
||||||
|
state_started_at TEXT NOT NULL,
|
||||||
|
cover_url TEXT,
|
||||||
|
video_url TEXT,
|
||||||
|
thumbnail_url TEXT,
|
||||||
|
metadata_json TEXT,
|
||||||
|
review_json TEXT,
|
||||||
|
youtube_video_id TEXT,
|
||||||
|
feedback_count_per_step TEXT NOT NULL DEFAULT '{}',
|
||||||
|
last_telegram_msg_ids TEXT NOT NULL DEFAULT '{}',
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL,
|
||||||
|
cancelled_at TEXT,
|
||||||
|
failed_reason TEXT
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS pipeline_jobs (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
pipeline_id INTEGER NOT NULL,
|
||||||
|
step TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'queued',
|
||||||
|
error TEXT,
|
||||||
|
started_at TEXT,
|
||||||
|
finished_at TEXT,
|
||||||
|
duration_ms INTEGER,
|
||||||
|
FOREIGN KEY (pipeline_id) REFERENCES video_pipelines(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS pipeline_feedback (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
pipeline_id INTEGER NOT NULL,
|
||||||
|
step TEXT NOT NULL,
|
||||||
|
feedback_text TEXT NOT NULL,
|
||||||
|
received_at TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (pipeline_id) REFERENCES video_pipelines(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS youtube_oauth_tokens (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
channel_id TEXT NOT NULL,
|
||||||
|
channel_title TEXT,
|
||||||
|
avatar_url TEXT,
|
||||||
|
refresh_token TEXT NOT NULL,
|
||||||
|
access_token TEXT,
|
||||||
|
expires_at TEXT,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS youtube_setup (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT CHECK (id = 1),
|
||||||
|
metadata_template_json TEXT NOT NULL,
|
||||||
|
cover_prompts_json TEXT NOT NULL,
|
||||||
|
review_weights_json TEXT NOT NULL,
|
||||||
|
review_threshold INTEGER NOT NULL DEFAULT 60,
|
||||||
|
visual_defaults_json TEXT NOT NULL,
|
||||||
|
publish_policy_json TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 기본 setup 1행 보장
|
||||||
|
cnt_row = conn.execute("SELECT COUNT(*) AS c FROM youtube_setup").fetchone()
|
||||||
|
if cnt_row["c"] == 0:
|
||||||
|
_seed_default_youtube_setup(conn)
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_default_youtube_setup(conn: sqlite3.Connection) -> None:
|
||||||
|
"""youtube_setup 테이블에 기본 1행을 삽입한다.
|
||||||
|
|
||||||
|
init_db()와 get_youtube_setup() (행이 사라진 경우 self-heal)가 공유한다.
|
||||||
|
"""
|
||||||
|
defaults = (
|
||||||
|
json.dumps({
|
||||||
|
"title": "[{genre}] {title} | {bpm}BPM",
|
||||||
|
"description": "{title}\n\n장르: {genre}\nBPM: {bpm}\nKey: {key}\n",
|
||||||
|
"tags": ["lofi", "chill", "instrumental"],
|
||||||
|
"category_id": 10,
|
||||||
|
}),
|
||||||
|
json.dumps({
|
||||||
|
"lo-fi": "moody anime cityscape at dusk, lofi aesthetic",
|
||||||
|
"phonk": "dark drift car aesthetic, neon, phonk vibe",
|
||||||
|
"ambient": "ethereal mountain landscape, ambient mood",
|
||||||
|
"default": "abstract music album cover art",
|
||||||
|
}),
|
||||||
|
json.dumps({"meta": 25, "policy": 30, "viewer": 25, "trend": 20}),
|
||||||
|
60,
|
||||||
|
json.dumps({"resolution": "1920x1080", "style": "visualizer", "background": "ai_cover"}),
|
||||||
|
json.dumps({"mode": "manual", "privacy": "private", "schedule_time": None}),
|
||||||
|
datetime.utcnow().isoformat(timespec="seconds"),
|
||||||
|
)
|
||||||
|
conn.execute("""
|
||||||
|
INSERT INTO youtube_setup
|
||||||
|
(metadata_template_json, cover_prompts_json, review_weights_json,
|
||||||
|
review_threshold, visual_defaults_json, publish_policy_json, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""", defaults)
|
||||||
|
|
||||||
|
|
||||||
# ── music_tasks CRUD ──────────────────────────────────────────────────────────
|
# ── music_tasks CRUD ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -791,3 +898,253 @@ def update_compile_job(job_id: int, **kwargs) -> None:
|
|||||||
def delete_compile_job(job_id: int) -> None:
|
def delete_compile_job(job_id: int) -> None:
|
||||||
with _conn() as conn:
|
with _conn() as conn:
|
||||||
conn.execute("DELETE FROM compile_jobs WHERE id = ?", (job_id,))
|
conn.execute("DELETE FROM compile_jobs WHERE id = ?", (job_id,))
|
||||||
|
|
||||||
|
|
||||||
|
# ── YouTube pipeline helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# update_pipeline_state: state/state_started_at/updated_at은 자동, 그 외 허용 컬럼 화이트리스트
|
||||||
|
_PIPELINE_STATE_EXTRA_COLS = frozenset({
|
||||||
|
"cover_url",
|
||||||
|
"video_url",
|
||||||
|
"thumbnail_url",
|
||||||
|
"metadata_json",
|
||||||
|
"review_json",
|
||||||
|
"youtube_video_id",
|
||||||
|
"cancelled_at",
|
||||||
|
"failed_reason",
|
||||||
|
"last_telegram_msg_ids",
|
||||||
|
"feedback_count_per_step",
|
||||||
|
})
|
||||||
|
|
||||||
|
# update_pipeline_job 허용 컬럼 화이트리스트
|
||||||
|
_PIPELINE_JOB_COLS = frozenset({
|
||||||
|
"status",
|
||||||
|
"error",
|
||||||
|
"duration_ms",
|
||||||
|
"started_at",
|
||||||
|
"finished_at",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> str:
|
||||||
|
return datetime.utcnow().isoformat(timespec="seconds")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_pipeline_row(row: sqlite3.Row) -> Dict[str, Any]:
|
||||||
|
"""video_pipelines의 sqlite3.Row를 dict로 파싱.
|
||||||
|
|
||||||
|
JSON 컬럼을 디코드하고, metadata/review를 호환을 위해 추가로 노출한다.
|
||||||
|
get_pipeline / list_pipelines가 공유.
|
||||||
|
"""
|
||||||
|
d = dict(row)
|
||||||
|
d["feedback_count_per_step"] = json.loads(d["feedback_count_per_step"] or "{}")
|
||||||
|
d["last_telegram_msg_ids"] = json.loads(d["last_telegram_msg_ids"] or "{}")
|
||||||
|
if d.get("metadata_json"):
|
||||||
|
d["metadata"] = json.loads(d["metadata_json"])
|
||||||
|
if d.get("review_json"):
|
||||||
|
d["review"] = json.loads(d["review_json"])
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def create_pipeline(track_id: int) -> int:
|
||||||
|
with _conn() as conn:
|
||||||
|
now = _now()
|
||||||
|
cur = conn.execute("""
|
||||||
|
INSERT INTO video_pipelines (track_id, state, state_started_at, created_at, updated_at)
|
||||||
|
VALUES (?, 'created', ?, ?, ?)
|
||||||
|
""", (track_id, now, now, now))
|
||||||
|
return cur.lastrowid
|
||||||
|
|
||||||
|
|
||||||
|
def get_pipeline(pid: int) -> Optional[Dict[str, Any]]:
|
||||||
|
with _conn() as conn:
|
||||||
|
row = conn.execute("""
|
||||||
|
SELECT vp.*, ml.title AS track_title
|
||||||
|
FROM video_pipelines vp
|
||||||
|
LEFT JOIN music_library ml ON ml.id = vp.track_id
|
||||||
|
WHERE vp.id = ?
|
||||||
|
""", (pid,)).fetchone()
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
return _parse_pipeline_row(row)
|
||||||
|
|
||||||
|
|
||||||
|
def update_pipeline_state(pid: int, state: str, **fields) -> None:
|
||||||
|
"""파이프라인 state를 갱신하고 옵션 컬럼을 함께 업데이트한다.
|
||||||
|
|
||||||
|
허용 컬럼 화이트리스트(_PIPELINE_STATE_EXTRA_COLS)에 없는 키는 ValueError.
|
||||||
|
"""
|
||||||
|
unknown = set(fields) - _PIPELINE_STATE_EXTRA_COLS
|
||||||
|
if unknown:
|
||||||
|
raise ValueError(f"unknown columns for update_pipeline_state: {sorted(unknown)}")
|
||||||
|
|
||||||
|
now = _now()
|
||||||
|
cols = ["state = ?", "state_started_at = ?", "updated_at = ?"]
|
||||||
|
vals: List[Any] = [state, now, now]
|
||||||
|
for k, v in fields.items():
|
||||||
|
cols.append(f"{k} = ?")
|
||||||
|
vals.append(v)
|
||||||
|
vals.append(pid)
|
||||||
|
with _conn() as conn:
|
||||||
|
conn.execute(f"UPDATE video_pipelines SET {', '.join(cols)} WHERE id = ?", vals)
|
||||||
|
|
||||||
|
|
||||||
|
def list_pipelines(active_only: bool = False) -> List[Dict[str, Any]]:
|
||||||
|
sql = """
|
||||||
|
SELECT vp.*, ml.title AS track_title
|
||||||
|
FROM video_pipelines vp
|
||||||
|
LEFT JOIN music_library ml ON ml.id = vp.track_id
|
||||||
|
"""
|
||||||
|
if active_only:
|
||||||
|
sql += " WHERE vp.state NOT IN ('published','cancelled','failed','awaiting_manual')"
|
||||||
|
sql += " ORDER BY vp.created_at DESC"
|
||||||
|
with _conn() as conn:
|
||||||
|
rows = conn.execute(sql).fetchall()
|
||||||
|
return [_parse_pipeline_row(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def increment_feedback_count(pid: int, step: str) -> int:
|
||||||
|
"""원자적으로 feedback_count_per_step.<step>를 +1 한 뒤 새 값을 반환.
|
||||||
|
|
||||||
|
json1 확장(SQLite 3.38+)을 사용해 read-modify-write 경합을 제거한다.
|
||||||
|
"""
|
||||||
|
now = _now()
|
||||||
|
with _conn() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE video_pipelines
|
||||||
|
SET feedback_count_per_step = json_set(
|
||||||
|
feedback_count_per_step,
|
||||||
|
'$.' || ?,
|
||||||
|
COALESCE(json_extract(feedback_count_per_step, '$.' || ?), 0) + 1
|
||||||
|
),
|
||||||
|
updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(step, step, now, pid),
|
||||||
|
)
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT json_extract(feedback_count_per_step, '$.' || ?) AS c "
|
||||||
|
"FROM video_pipelines WHERE id = ?",
|
||||||
|
(step, pid),
|
||||||
|
).fetchone()
|
||||||
|
return int(row["c"]) if row and row["c"] is not None else 0
|
||||||
|
|
||||||
|
|
||||||
|
def record_feedback(pid: int, step: str, feedback_text: str) -> None:
|
||||||
|
with _conn() as conn:
|
||||||
|
conn.execute("""
|
||||||
|
INSERT INTO pipeline_feedback (pipeline_id, step, feedback_text, received_at)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""", (pid, step, feedback_text, _now()))
|
||||||
|
|
||||||
|
|
||||||
|
def get_feedback_history(pid: int) -> List[Dict[str, Any]]:
|
||||||
|
with _conn() as conn:
|
||||||
|
rows = conn.execute("""
|
||||||
|
SELECT * FROM pipeline_feedback
|
||||||
|
WHERE pipeline_id = ? ORDER BY id DESC
|
||||||
|
""", (pid,)).fetchall()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def create_pipeline_job(pid: int, step: str) -> int:
|
||||||
|
with _conn() as conn:
|
||||||
|
cur = conn.execute("""
|
||||||
|
INSERT INTO pipeline_jobs (pipeline_id, step, status, started_at)
|
||||||
|
VALUES (?, ?, 'queued', ?)
|
||||||
|
""", (pid, step, _now()))
|
||||||
|
return cur.lastrowid
|
||||||
|
|
||||||
|
|
||||||
|
def update_pipeline_job(job_id: int, **fields) -> None:
|
||||||
|
"""pipeline_jobs 행을 갱신. 허용 컬럼 화이트리스트 외 키는 ValueError.
|
||||||
|
|
||||||
|
status가 succeeded/failed로 바뀌면 finished_at을 자동 설정 (호출자가 미지정 시).
|
||||||
|
"""
|
||||||
|
unknown = set(fields) - _PIPELINE_JOB_COLS
|
||||||
|
if unknown:
|
||||||
|
raise ValueError(f"unknown columns for update_pipeline_job: {sorted(unknown)}")
|
||||||
|
if not fields:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
fields.get("status") in ("succeeded", "failed")
|
||||||
|
and "finished_at" not in fields
|
||||||
|
):
|
||||||
|
fields["finished_at"] = _now()
|
||||||
|
|
||||||
|
cols = ", ".join(f"{k} = ?" for k in fields)
|
||||||
|
vals = list(fields.values()) + [job_id]
|
||||||
|
with _conn() as conn:
|
||||||
|
conn.execute(f"UPDATE pipeline_jobs SET {cols} WHERE id = ?", vals)
|
||||||
|
|
||||||
|
|
||||||
|
def list_pipeline_jobs(pid: int) -> List[Dict[str, Any]]:
|
||||||
|
with _conn() as conn:
|
||||||
|
rows = conn.execute("""
|
||||||
|
SELECT * FROM pipeline_jobs WHERE pipeline_id = ? ORDER BY id ASC
|
||||||
|
""", (pid,)).fetchall()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def get_youtube_setup() -> Dict[str, Any]:
|
||||||
|
"""youtube_setup의 기본 1행을 반환. 누락 시 자동 시드 후 재조회."""
|
||||||
|
with _conn() as conn:
|
||||||
|
row = conn.execute("SELECT * FROM youtube_setup WHERE id = 1").fetchone()
|
||||||
|
if row is None:
|
||||||
|
_seed_default_youtube_setup(conn)
|
||||||
|
row = conn.execute("SELECT * FROM youtube_setup WHERE id = 1").fetchone()
|
||||||
|
d = dict(row)
|
||||||
|
for k in ("metadata_template_json", "cover_prompts_json",
|
||||||
|
"review_weights_json", "visual_defaults_json", "publish_policy_json"):
|
||||||
|
d[k.replace("_json", "")] = json.loads(d[k])
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def update_youtube_setup(**kwargs) -> None:
|
||||||
|
field_map = {
|
||||||
|
"metadata_template": "metadata_template_json",
|
||||||
|
"cover_prompts": "cover_prompts_json",
|
||||||
|
"review_weights": "review_weights_json",
|
||||||
|
"visual_defaults": "visual_defaults_json",
|
||||||
|
"publish_policy": "publish_policy_json",
|
||||||
|
}
|
||||||
|
cols = []
|
||||||
|
vals: List[Any] = []
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if k in field_map:
|
||||||
|
cols.append(f"{field_map[k]} = ?")
|
||||||
|
vals.append(json.dumps(v))
|
||||||
|
elif k == "review_threshold":
|
||||||
|
cols.append("review_threshold = ?")
|
||||||
|
vals.append(int(v))
|
||||||
|
if not cols:
|
||||||
|
return
|
||||||
|
cols.append("updated_at = ?")
|
||||||
|
vals.append(_now())
|
||||||
|
with _conn() as conn:
|
||||||
|
conn.execute(f"UPDATE youtube_setup SET {', '.join(cols)} WHERE id = 1", vals)
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_oauth_token(channel_id: str, channel_title: Optional[str],
|
||||||
|
avatar_url: Optional[str], refresh_token: str,
|
||||||
|
access_token: Optional[str], expires_at: Optional[str]) -> None:
|
||||||
|
with _conn() as conn:
|
||||||
|
conn.execute("DELETE FROM youtube_oauth_tokens")
|
||||||
|
conn.execute("""
|
||||||
|
INSERT INTO youtube_oauth_tokens
|
||||||
|
(channel_id, channel_title, avatar_url, refresh_token, access_token, expires_at, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""", (channel_id, channel_title, avatar_url, refresh_token, access_token, expires_at, _now()))
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_token() -> Optional[Dict[str, Any]]:
|
||||||
|
with _conn() as conn:
|
||||||
|
row = conn.execute("SELECT * FROM youtube_oauth_tokens ORDER BY id DESC LIMIT 1").fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def delete_oauth_token() -> None:
|
||||||
|
with _conn() as conn:
|
||||||
|
conn.execute("DELETE FROM youtube_oauth_tokens")
|
||||||
|
|||||||
@@ -22,9 +22,12 @@ from .db import (
|
|||||||
create_compile_job, get_compile_jobs, get_compile_job,
|
create_compile_job, get_compile_jobs, get_compile_job,
|
||||||
update_compile_job, delete_compile_job,
|
update_compile_job, delete_compile_job,
|
||||||
)
|
)
|
||||||
|
from . import db as _db_module
|
||||||
from .compiler import run_compile
|
from .compiler import run_compile
|
||||||
from .market import ingest_trends, get_suggestions
|
from .market import ingest_trends, get_suggestions
|
||||||
from .local_provider import run_local_generation
|
from .local_provider import run_local_generation
|
||||||
|
from .pipeline import orchestrator
|
||||||
|
from .pipeline import youtube as yt_module
|
||||||
from .suno_provider import (
|
from .suno_provider import (
|
||||||
run_suno_generation, run_suno_extend, run_vocal_removal,
|
run_suno_generation, run_suno_extend, run_vocal_removal,
|
||||||
run_cover_image, run_wav_convert, run_stem_split,
|
run_cover_image, run_wav_convert, run_stem_split,
|
||||||
@@ -921,3 +924,194 @@ def list_market_reports(limit: int = 10):
|
|||||||
@app.get("/api/music/market/suggest")
|
@app.get("/api/music/market/suggest")
|
||||||
def market_suggest(limit: int = 5):
|
def market_suggest(limit: int = 5):
|
||||||
return {"suggestions": get_suggestions(limit)}
|
return {"suggestions": get_suggestions(limit)}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pipeline endpoints ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PipelineCreate(BaseModel):
|
||||||
|
track_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackRequest(BaseModel):
|
||||||
|
step: str
|
||||||
|
intent: str # approve | reject
|
||||||
|
feedback_text: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/music/pipeline", status_code=201)
|
||||||
|
def create_pipeline(req: PipelineCreate):
|
||||||
|
actives = _db_module.list_pipelines(active_only=True)
|
||||||
|
if any(p["track_id"] == req.track_id for p in actives):
|
||||||
|
raise HTTPException(409, "이미 진행 중인 파이프라인이 있습니다")
|
||||||
|
pid = _db_module.create_pipeline(req.track_id)
|
||||||
|
return _db_module.get_pipeline(pid)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/music/pipeline")
|
||||||
|
def list_pipelines_endpoint(status: str = "all"):
|
||||||
|
pipelines = _db_module.list_pipelines(active_only=(status == "active"))
|
||||||
|
return {"pipelines": pipelines}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/music/pipeline/lookup-by-msg/{msg_id}")
|
||||||
|
def lookup_by_msg(msg_id: int):
|
||||||
|
for p in _db_module.list_pipelines(active_only=True):
|
||||||
|
for step, mid in p["last_telegram_msg_ids"].items():
|
||||||
|
if mid == msg_id:
|
||||||
|
return {"pipeline_id": p["id"], "step": step}
|
||||||
|
raise HTTPException(404)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/music/pipeline/{pid}")
|
||||||
|
def get_pipeline_endpoint(pid: int):
|
||||||
|
p = _db_module.get_pipeline(pid)
|
||||||
|
if not p:
|
||||||
|
raise HTTPException(404)
|
||||||
|
p["jobs"] = _db_module.list_pipeline_jobs(pid)
|
||||||
|
p["feedback"] = _db_module.get_feedback_history(pid)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/music/pipeline/{pid}/start", status_code=202)
|
||||||
|
async def start_pipeline(pid: int, bg: BackgroundTasks):
|
||||||
|
p = _db_module.get_pipeline(pid)
|
||||||
|
if not p:
|
||||||
|
raise HTTPException(404)
|
||||||
|
if p["state"] != "created":
|
||||||
|
raise HTTPException(409, f"이미 시작됨 ({p['state']})")
|
||||||
|
bg.add_task(orchestrator.run_step, pid, "cover")
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def _state_to_step(state: str) -> Optional[str]:
|
||||||
|
return {
|
||||||
|
"video_pending": "video",
|
||||||
|
"thumb_pending": "thumb",
|
||||||
|
"meta_pending": "meta",
|
||||||
|
"ai_review": "review",
|
||||||
|
"publish_pending": None, # 사용자 명시 발행 호출 필요
|
||||||
|
"publishing": "publish",
|
||||||
|
}.get(state)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/music/pipeline/{pid}/feedback", status_code=202)
|
||||||
|
async def feedback(pid: int, req: FeedbackRequest, bg: BackgroundTasks):
|
||||||
|
p = _db_module.get_pipeline(pid)
|
||||||
|
if not p:
|
||||||
|
raise HTTPException(404)
|
||||||
|
if p["state"] == "awaiting_manual":
|
||||||
|
raise HTTPException(409, "수동 개입 대기 중")
|
||||||
|
state = p["state"]
|
||||||
|
expected = f"{req.step}_pending"
|
||||||
|
if state != expected:
|
||||||
|
# 멱등 처리 — 이미 다음 단계로 넘어갔으면 무시
|
||||||
|
return {"ok": True, "skipped": True}
|
||||||
|
|
||||||
|
if req.intent == "approve":
|
||||||
|
from .pipeline.state_machine import next_state_on_approve
|
||||||
|
next_st = next_state_on_approve(state)
|
||||||
|
_db_module.update_pipeline_state(pid, next_st)
|
||||||
|
next_step = _state_to_step(next_st)
|
||||||
|
if next_step:
|
||||||
|
bg.add_task(orchestrator.run_step, pid, next_step)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
elif req.intent == "reject":
|
||||||
|
count = _db_module.increment_feedback_count(pid, req.step)
|
||||||
|
if count > 5:
|
||||||
|
_db_module.update_pipeline_state(pid, "awaiting_manual")
|
||||||
|
raise HTTPException(409, "재생성 한도 초과")
|
||||||
|
if req.feedback_text:
|
||||||
|
_db_module.record_feedback(pid, req.step, req.feedback_text)
|
||||||
|
bg.add_task(orchestrator.run_step, pid, req.step, req.feedback_text or "")
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise HTTPException(400, f"unknown intent: {req.intent}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/music/pipeline/{pid}/cancel")
|
||||||
|
def cancel_pipeline(pid: int):
|
||||||
|
p = _db_module.get_pipeline(pid)
|
||||||
|
if not p:
|
||||||
|
raise HTTPException(404)
|
||||||
|
_db_module.update_pipeline_state(pid, "cancelled", cancelled_at=_db_module._now())
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/music/pipeline/{pid}/publish", status_code=202)
|
||||||
|
async def publish_pipeline(pid: int, bg: BackgroundTasks):
|
||||||
|
p = _db_module.get_pipeline(pid)
|
||||||
|
if not p:
|
||||||
|
raise HTTPException(404)
|
||||||
|
if p["state"] != "publish_pending":
|
||||||
|
raise HTTPException(409, f"발행 단계 아님 ({p['state']})")
|
||||||
|
_db_module.update_pipeline_state(pid, "publishing")
|
||||||
|
bg.add_task(orchestrator.run_step, pid, "publish")
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# Telegram 메시지 매칭용 엔드포인트 (agent-office용)
|
||||||
|
|
||||||
|
class TelegramMsgPatch(BaseModel):
|
||||||
|
step: str
|
||||||
|
message_id: int
|
||||||
|
|
||||||
|
|
||||||
|
@app.patch("/api/music/pipeline/{pid}/telegram-msg")
|
||||||
|
def save_telegram_msg(pid: int, req: TelegramMsgPatch):
|
||||||
|
p = _db_module.get_pipeline(pid)
|
||||||
|
if not p:
|
||||||
|
raise HTTPException(404)
|
||||||
|
ids = p["last_telegram_msg_ids"]
|
||||||
|
ids[req.step] = req.message_id
|
||||||
|
_db_module.update_pipeline_state(
|
||||||
|
pid, p["state"], last_telegram_msg_ids=json.dumps(ids)
|
||||||
|
)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Setup endpoints ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class SetupRequest(BaseModel):
|
||||||
|
metadata_template: Optional[Dict[str, Any]] = None
|
||||||
|
cover_prompts: Optional[Dict[str, Any]] = None
|
||||||
|
review_weights: Optional[Dict[str, Any]] = None
|
||||||
|
review_threshold: Optional[int] = None
|
||||||
|
visual_defaults: Optional[Dict[str, Any]] = None
|
||||||
|
publish_policy: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/music/setup")
|
||||||
|
def get_setup():
|
||||||
|
return _db_module.get_youtube_setup()
|
||||||
|
|
||||||
|
|
||||||
|
@app.put("/api/music/setup")
|
||||||
|
def put_setup(req: SetupRequest):
|
||||||
|
payload = {k: v for k, v in req.dict().items() if v is not None}
|
||||||
|
_db_module.update_youtube_setup(**payload)
|
||||||
|
return _db_module.get_youtube_setup()
|
||||||
|
|
||||||
|
|
||||||
|
# ── YouTube OAuth endpoints ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@app.get("/api/music/youtube/auth-url")
|
||||||
|
def youtube_auth_url():
|
||||||
|
return {"url": yt_module.get_auth_url()}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/music/youtube/callback")
|
||||||
|
async def youtube_callback(code: str):
|
||||||
|
return await yt_module.exchange_code(code)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/music/youtube/disconnect")
|
||||||
|
def youtube_disconnect():
|
||||||
|
yt_module.disconnect()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/music/youtube/status")
|
||||||
|
def youtube_status():
|
||||||
|
return yt_module.get_status() or {"connected": False}
|
||||||
|
|||||||
0
music-lab/app/pipeline/__init__.py
Normal file
0
music-lab/app/pipeline/__init__.py
Normal file
88
music-lab/app/pipeline/cover.py
Normal file
88
music-lab/app/pipeline/cover.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""AI 커버 아트 생성 — DALL·E 3 / gpt-image-1 + 그라데이션 폴백."""
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from . import storage
|
||||||
|
from .gradient import make_gradient_with_title
|
||||||
|
|
||||||
|
logger = logging.getLogger("music-lab.cover")
|
||||||
|
|
||||||
|
DALLE_TIMEOUT_S = 90
|
||||||
|
|
||||||
|
|
||||||
|
def _get_api_key() -> str:
|
||||||
|
return os.getenv("OPENAI_API_KEY", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model() -> str:
|
||||||
|
return os.getenv("OPENAI_IMAGE_MODEL", "gpt-image-1")
|
||||||
|
|
||||||
|
|
||||||
|
async def generate(*, pipeline_id: int, genre: str, prompt_template: str,
|
||||||
|
mood: str = "", track_title: str = "", feedback: str = "") -> dict:
|
||||||
|
"""커버 아트 생성. 성공 시 jpg 저장 + URL 반환. 실패 시 그라데이션 폴백.
|
||||||
|
|
||||||
|
반환: {"url": str, "used_fallback": bool, "error": str | None}
|
||||||
|
"""
|
||||||
|
out_path = os.path.join(storage.pipeline_dir(pipeline_id), "cover.jpg")
|
||||||
|
used_fallback = False
|
||||||
|
error = None
|
||||||
|
|
||||||
|
api_key = _get_api_key()
|
||||||
|
model = _get_model()
|
||||||
|
if api_key:
|
||||||
|
try:
|
||||||
|
await _generate_with_dalle(prompt_template, mood, feedback, out_path,
|
||||||
|
api_key=api_key, model=model)
|
||||||
|
except (httpx.HTTPError, httpx.TimeoutException, KeyError, ValueError, OSError) as e:
|
||||||
|
logger.warning("DALL·E 실패 — 폴백: %s", e)
|
||||||
|
error = str(e)
|
||||||
|
used_fallback = True
|
||||||
|
make_gradient_with_title(genre, track_title, out_path)
|
||||||
|
else:
|
||||||
|
used_fallback = True
|
||||||
|
error = "OPENAI_API_KEY 미설정"
|
||||||
|
make_gradient_with_title(genre, track_title, out_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"url": storage.media_url(pipeline_id, "cover.jpg"),
|
||||||
|
"used_fallback": used_fallback,
|
||||||
|
"error": error,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_with_dalle(prompt_template: str, mood: str,
|
||||||
|
feedback: str, out_path: str,
|
||||||
|
*, api_key: str, model: str) -> None:
|
||||||
|
prompt = prompt_template
|
||||||
|
if mood:
|
||||||
|
prompt = f"{prompt}, {mood} mood"
|
||||||
|
if feedback:
|
||||||
|
prompt = f"{prompt}. 추가 지시: {feedback}"
|
||||||
|
prompt = f"{prompt}, no text, high quality"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=DALLE_TIMEOUT_S) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"https://api.openai.com/v1/images/generations",
|
||||||
|
headers={"Authorization": f"Bearer {api_key}"},
|
||||||
|
json={"model": model, "prompt": prompt, "size": "1024x1024", "n": 1},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()["data"][0]
|
||||||
|
if "url" in data:
|
||||||
|
img_resp = await client.get(data["url"])
|
||||||
|
img_resp.raise_for_status()
|
||||||
|
img_bytes = img_resp.content
|
||||||
|
elif "b64_json" in data:
|
||||||
|
img_bytes = base64.b64decode(data["b64_json"])
|
||||||
|
else:
|
||||||
|
raise ValueError("DALL·E response has neither url nor b64_json")
|
||||||
|
# PNG → JPG 변환
|
||||||
|
with Image.open(BytesIO(img_bytes)) as src:
|
||||||
|
img = src.convert("RGB")
|
||||||
|
img.save(out_path, "JPEG", quality=92)
|
||||||
38
music-lab/app/pipeline/gradient.py
Normal file
38
music-lab/app/pipeline/gradient.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""장르별 그라데이션 배경 + 텍스트 오버레이 — cover/video 공용."""
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
GENRE_COLORS = {
|
||||||
|
"lo-fi": ((26, 26, 46), (22, 33, 62)),
|
||||||
|
"phonk": ((26, 10, 10), (45, 0, 0)),
|
||||||
|
"ambient": ((13, 33, 55), (10, 22, 40)),
|
||||||
|
"pop": ((26, 10, 46), (45, 27, 78)),
|
||||||
|
"default": ((17, 24, 39), (31, 41, 55)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_gradient_with_title(genre: str, title: str, out_path: str,
|
||||||
|
size: tuple[int, int] = (1024, 1024),
|
||||||
|
quality: int = 92) -> None:
|
||||||
|
w, h = size
|
||||||
|
top, bot = GENRE_COLORS.get(genre.lower(), GENRE_COLORS["default"])
|
||||||
|
with Image.new("RGB", (w, h)) as img:
|
||||||
|
px = img.load()
|
||||||
|
for y in range(h):
|
||||||
|
t = y / h
|
||||||
|
r = int(top[0] + (bot[0] - top[0]) * t)
|
||||||
|
g = int(top[1] + (bot[1] - top[1]) * t)
|
||||||
|
b = int(top[2] + (bot[2] - top[2]) * t)
|
||||||
|
for x in range(w):
|
||||||
|
px[x, y] = (r, g, b)
|
||||||
|
|
||||||
|
if title:
|
||||||
|
try:
|
||||||
|
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 64)
|
||||||
|
except OSError:
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
bbox = draw.textbbox((0, 0), title, font=font)
|
||||||
|
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||||
|
draw.text(((w - tw) // 2, (h - th) // 2), title, fill=(255, 255, 255), font=font)
|
||||||
|
|
||||||
|
img.save(out_path, "JPEG", quality=quality)
|
||||||
95
music-lab/app/pipeline/metadata.py
Normal file
95
music-lab/app/pipeline/metadata.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
"""메타데이터 생성 — Claude Haiku + 템플릿 폴백."""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger("music-lab.metadata")
|
||||||
|
|
||||||
|
CLAUDE_HAIKU_MODEL_DEFAULT = "claude-haiku-4-5-20251001"
|
||||||
|
TIMEOUT_S = 30
|
||||||
|
|
||||||
|
|
||||||
|
def _get_api_key() -> str:
|
||||||
|
return os.getenv("ANTHROPIC_API_KEY", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model() -> str:
|
||||||
|
return os.getenv("CLAUDE_HAIKU_MODEL", CLAUDE_HAIKU_MODEL_DEFAULT)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate(*, track: dict, template: dict, trend_keywords: list[str],
|
||||||
|
feedback: str = "") -> dict:
|
||||||
|
"""메타데이터 생성. 성공 시 LLM, 실패/미설정 시 템플릿 치환 폴백.
|
||||||
|
|
||||||
|
반환: {"title", "description", "tags", "category_id", "used_fallback", "error"}
|
||||||
|
"""
|
||||||
|
api_key = _get_api_key()
|
||||||
|
if not api_key:
|
||||||
|
return {**_fallback_template(track, template), "used_fallback": True, "error": "no api key"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await _call_claude(track, template, trend_keywords, feedback,
|
||||||
|
api_key=api_key, model=_get_model())
|
||||||
|
return {**result, "used_fallback": False, "error": None}
|
||||||
|
except (httpx.HTTPError, httpx.TimeoutException, KeyError, ValueError, json.JSONDecodeError) as e:
|
||||||
|
logger.warning("메타데이터 LLM 실패 — 폴백: %s", e)
|
||||||
|
return {**_fallback_template(track, template), "used_fallback": True, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_template(track: dict, template: dict) -> dict:
|
||||||
|
fmt_vars = {
|
||||||
|
"title": track.get("title", ""),
|
||||||
|
"genre": track.get("genre", ""),
|
||||||
|
"bpm": track.get("bpm", ""),
|
||||||
|
"key": track.get("key", ""),
|
||||||
|
"scale": track.get("scale", ""),
|
||||||
|
}
|
||||||
|
title = template.get("title", "{title}").format(**fmt_vars)
|
||||||
|
description = template.get("description", "{title}").format(**fmt_vars)
|
||||||
|
return {
|
||||||
|
"title": title[:100],
|
||||||
|
"description": description[:5000],
|
||||||
|
"tags": (template.get("tags") or [])[:15],
|
||||||
|
"category_id": template.get("category_id", 10),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_claude(track: dict, template: dict, trend_keywords: list[str],
|
||||||
|
feedback: str, *, api_key: str, model: str) -> dict:
|
||||||
|
user_prompt = (
|
||||||
|
"다음 트랙의 YouTube 메타데이터를 생성하세요. JSON으로만 응답.\n\n"
|
||||||
|
f"트랙: {json.dumps(track, ensure_ascii=False)}\n"
|
||||||
|
f"템플릿: {json.dumps(template, ensure_ascii=False)}\n"
|
||||||
|
f"트렌드 키워드: {', '.join(trend_keywords)}\n"
|
||||||
|
)
|
||||||
|
if feedback:
|
||||||
|
user_prompt += f"\n사용자 피드백: {feedback}\n"
|
||||||
|
user_prompt += (
|
||||||
|
'\n출력 JSON: {"title": "60자 이내", "description": "1000자 이내, 3-5문단",'
|
||||||
|
' "tags": ["15개 이내"], "category_id": 10}'
|
||||||
|
)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=TIMEOUT_S) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"https://api.anthropic.com/v1/messages",
|
||||||
|
headers={
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"content-type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"messages": [{"role": "user", "content": user_prompt}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
text = resp.json()["content"][0]["text"]
|
||||||
|
# 가장 첫 JSON 블록 추출
|
||||||
|
start = text.find("{")
|
||||||
|
end = text.rfind("}") + 1
|
||||||
|
if start < 0 or end <= start:
|
||||||
|
raise ValueError("Claude 응답에 JSON 블록 없음")
|
||||||
|
return json.loads(text[start:end])
|
||||||
183
music-lab/app/pipeline/orchestrator.py
Normal file
183
music-lab/app/pipeline/orchestrator.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""파이프라인 오케스트레이터 — 단계별 BackgroundTask 등록 및 산출물 → DB 반영."""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from app import db
|
||||||
|
from . import cover, video, thumb, metadata, review, youtube
|
||||||
|
|
||||||
|
logger = logging.getLogger("music-lab.orchestrator")
|
||||||
|
|
||||||
|
|
||||||
|
async def run_step(pipeline_id: int, step: str, feedback: str = "") -> None:
|
||||||
|
"""단계 실행 → 결과를 DB에 반영하고 *_pending 또는 다음 단계로 전이.
|
||||||
|
|
||||||
|
호출 직후 _running 상태로 전환, 끝나면 _pending(사용자 게이트) 또는 자동 다음.
|
||||||
|
실패 시 failed 상태 + reason.
|
||||||
|
"""
|
||||||
|
job_id = db.create_pipeline_job(pipeline_id, step)
|
||||||
|
db.update_pipeline_job(job_id, status="running")
|
||||||
|
p = db.get_pipeline(pipeline_id)
|
||||||
|
track = _get_track(p["track_id"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
if step == "cover":
|
||||||
|
result = await _run_cover(p, track, feedback)
|
||||||
|
elif step == "video":
|
||||||
|
result = await _run_video(p, track)
|
||||||
|
elif step == "thumb":
|
||||||
|
result = await _run_thumb(p, track, feedback)
|
||||||
|
elif step == "meta":
|
||||||
|
result = await _run_meta(p, track, feedback)
|
||||||
|
elif step == "review":
|
||||||
|
result = await _run_review(p, track)
|
||||||
|
elif step == "publish":
|
||||||
|
result = await _run_publish(p, track)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown step: {step}")
|
||||||
|
db.update_pipeline_job(job_id, status="succeeded")
|
||||||
|
db.update_pipeline_state(pipeline_id, result["next_state"], **result.get("fields", {}))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("step %s failed for pipeline %s", step, pipeline_id)
|
||||||
|
db.update_pipeline_job(job_id, status="failed", error=str(e))
|
||||||
|
db.update_pipeline_state(pipeline_id, "failed", failed_reason=f"{step}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_track(track_id: int) -> dict:
|
||||||
|
# tracks 테이블 헬퍼 — 기존 db에 있는 함수 사용
|
||||||
|
t = None
|
||||||
|
if hasattr(db, "get_track_by_id"):
|
||||||
|
t = db.get_track_by_id(track_id)
|
||||||
|
elif hasattr(db, "get_track"):
|
||||||
|
t = db.get_track(track_id)
|
||||||
|
if not t:
|
||||||
|
# 폴백: music_library 테이블에서 직접 (스키마 확인 필요)
|
||||||
|
t = _fetch_track_fallback(track_id)
|
||||||
|
if not t:
|
||||||
|
raise ValueError(f"트랙 {track_id} 없음")
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_track_fallback(track_id: int) -> dict | None:
|
||||||
|
"""db 모듈에 get_track이 없을 때 대비 — music_library 테이블 직접 조회."""
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(db.DB_PATH)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
# 가능한 테이블/컬럼 시도 (music_library 또는 tracks)
|
||||||
|
for table in ("music_library", "tracks"):
|
||||||
|
try:
|
||||||
|
row = conn.execute(f"SELECT * FROM {table} WHERE id = ?", (track_id,)).fetchone()
|
||||||
|
if row:
|
||||||
|
d = dict(row)
|
||||||
|
# JSON 컬럼 파싱 (있으면)
|
||||||
|
for k in ("moods", "instruments"):
|
||||||
|
if k in d and isinstance(d[k], str):
|
||||||
|
try:
|
||||||
|
d[k] = json.loads(d[k])
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
d[k] = []
|
||||||
|
conn.close()
|
||||||
|
return d
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
continue
|
||||||
|
conn.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("track fallback fetch 실패: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_cover(p, track, feedback):
|
||||||
|
setup = db.get_youtube_setup()
|
||||||
|
prompts = setup["cover_prompts"]
|
||||||
|
template = prompts.get(track.get("genre", "default").lower(), prompts.get("default", ""))
|
||||||
|
out = await cover.generate(
|
||||||
|
pipeline_id=p["id"], genre=track.get("genre", "default"),
|
||||||
|
prompt_template=template,
|
||||||
|
mood=", ".join(track.get("moods", []) or []),
|
||||||
|
track_title=track.get("title", ""),
|
||||||
|
feedback=feedback,
|
||||||
|
)
|
||||||
|
return {"next_state": "cover_pending", "fields": {"cover_url": out["url"]}}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_video(p, track):
|
||||||
|
setup = db.get_youtube_setup()
|
||||||
|
vd = setup["visual_defaults"]
|
||||||
|
audio_path = _local_path(track.get("audio_url", ""))
|
||||||
|
cover_path = _local_path(p["cover_url"])
|
||||||
|
out = video.generate(
|
||||||
|
pipeline_id=p["id"], audio_path=audio_path, cover_path=cover_path,
|
||||||
|
genre=track.get("genre", "default"),
|
||||||
|
duration_sec=track.get("duration_sec", 120),
|
||||||
|
resolution=vd["resolution"], style=vd["style"],
|
||||||
|
)
|
||||||
|
return {"next_state": "video_pending", "fields": {"video_url": out["url"]}}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_thumb(p, track, feedback):
|
||||||
|
video_path = _local_path(p["video_url"])
|
||||||
|
out = thumb.generate(pipeline_id=p["id"], video_path=video_path,
|
||||||
|
track_title=track.get("title", ""), overlay_text=True)
|
||||||
|
return {"next_state": "thumb_pending", "fields": {"thumbnail_url": out["url"]}}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_meta(p, track, feedback):
|
||||||
|
setup = db.get_youtube_setup()
|
||||||
|
trend_top = _get_trend_top()
|
||||||
|
out = await metadata.generate(
|
||||||
|
track=track, template=setup["metadata_template"],
|
||||||
|
trend_keywords=trend_top, feedback=feedback,
|
||||||
|
)
|
||||||
|
return {"next_state": "meta_pending",
|
||||||
|
"fields": {"metadata_json": json.dumps(out, ensure_ascii=False)}}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_review(p, track):
|
||||||
|
setup = db.get_youtube_setup()
|
||||||
|
meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {}
|
||||||
|
result = await review.run_4_axis(
|
||||||
|
pipeline=p, track=track,
|
||||||
|
video_meta={"length_sec": track.get("duration_sec", 120),
|
||||||
|
"resolution": setup["visual_defaults"]["resolution"]},
|
||||||
|
metadata=meta, thumbnail_url=p.get("thumbnail_url", ""),
|
||||||
|
trend_top=_get_trend_top(),
|
||||||
|
weights=setup["review_weights"], threshold=setup["review_threshold"],
|
||||||
|
)
|
||||||
|
return {"next_state": "publish_pending",
|
||||||
|
"fields": {"review_json": json.dumps(result, ensure_ascii=False)}}
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_publish(p, track):
|
||||||
|
setup = db.get_youtube_setup()
|
||||||
|
meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {}
|
||||||
|
privacy = setup["publish_policy"].get("privacy", "private")
|
||||||
|
result = youtube.upload_video(
|
||||||
|
video_path=_local_path(p["video_url"]),
|
||||||
|
thumbnail_path=_local_path(p["thumbnail_url"]) if p.get("thumbnail_url") else None,
|
||||||
|
metadata=meta, privacy=privacy,
|
||||||
|
)
|
||||||
|
return {"next_state": "published",
|
||||||
|
"fields": {"youtube_video_id": result["video_id"]}}
|
||||||
|
|
||||||
|
|
||||||
|
def _local_path(media_url: str) -> str:
|
||||||
|
""" /media/videos/123/cover.jpg → /app/data/videos/123/cover.jpg """
|
||||||
|
if not media_url:
|
||||||
|
return ""
|
||||||
|
base_media = os.getenv("VIDEO_MEDIA_BASE", "/media/videos")
|
||||||
|
base_data = os.getenv("VIDEO_DATA_DIR", "/app/data/videos")
|
||||||
|
if media_url.startswith(base_media):
|
||||||
|
return media_url.replace(base_media, base_data, 1)
|
||||||
|
# /media/music/abc.mp3 → /app/data/music/abc.mp3
|
||||||
|
return media_url.replace("/media/", "/app/data/", 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_trend_top(n: int = 10) -> list[str]:
|
||||||
|
try:
|
||||||
|
if hasattr(db, "get_market_trends"):
|
||||||
|
rows = db.get_market_trends(days=7)
|
||||||
|
return [r.get("genre", "") for r in rows[:n] if r.get("genre")]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
120
music-lab/app/pipeline/review.py
Normal file
120
music-lab/app/pipeline/review.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""AI 최종 검토 — 4축(메타/정책/시청/트렌드) 가중 평균."""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger("music-lab.review")
|
||||||
|
|
||||||
|
CLAUDE_SONNET_MODEL_DEFAULT = "claude-sonnet-4-6"
|
||||||
|
TIMEOUT_S = 60
|
||||||
|
|
||||||
|
POLICY_BANNED = {"f-word", "n-word"} # 운영 시 별도 파일로 — 데모용 자리
|
||||||
|
|
||||||
|
|
||||||
|
def _get_api_key() -> str:
|
||||||
|
return os.getenv("ANTHROPIC_API_KEY", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model() -> str:
|
||||||
|
return os.getenv("CLAUDE_SONNET_MODEL", CLAUDE_SONNET_MODEL_DEFAULT)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_4_axis(*, pipeline: dict, track: dict, video_meta: dict,
|
||||||
|
metadata: dict, thumbnail_url: str, trend_top: list[str],
|
||||||
|
weights: dict, threshold: int) -> dict:
|
||||||
|
api_key = _get_api_key()
|
||||||
|
if not api_key:
|
||||||
|
return _heuristic(metadata, video_meta, track, trend_top, weights, threshold,
|
||||||
|
fallback_reason="no api key")
|
||||||
|
try:
|
||||||
|
scores = await _call_claude(pipeline, track, video_meta, metadata,
|
||||||
|
thumbnail_url, trend_top,
|
||||||
|
api_key=api_key, model=_get_model())
|
||||||
|
return _weighted_verdict(scores, weights, threshold, used_fallback=False)
|
||||||
|
except (httpx.HTTPError, httpx.TimeoutException, KeyError, ValueError, json.JSONDecodeError) as e:
|
||||||
|
logger.warning("검토 LLM 실패 — 휴리스틱: %s", e)
|
||||||
|
return _heuristic(metadata, video_meta, track, trend_top, weights, threshold,
|
||||||
|
fallback_reason=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def _weighted_verdict(scores: dict, weights: dict, threshold: int,
|
||||||
|
used_fallback: bool) -> dict:
|
||||||
|
total = (
|
||||||
|
weights["meta"] / 100 * scores["metadata_quality"]["score"] +
|
||||||
|
weights["policy"] / 100 * scores["policy_compliance"]["score"] +
|
||||||
|
weights["viewer"] / 100 * scores["viewer_experience"]["score"] +
|
||||||
|
weights["trend"] / 100 * scores["trend_alignment"]["score"]
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
**scores,
|
||||||
|
"weighted_total": round(total, 2),
|
||||||
|
"verdict": "pass" if total >= threshold else "fail",
|
||||||
|
"used_fallback": used_fallback,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_claude(pipeline, track, video_meta, metadata, thumbnail_url, trend_top,
|
||||||
|
*, api_key: str, model: str):
|
||||||
|
user = (
|
||||||
|
"트랙·영상·메타데이터를 4축으로 평가하고 JSON만 응답:\n"
|
||||||
|
f"트랙: {json.dumps(track, ensure_ascii=False)}\n"
|
||||||
|
f"영상: {json.dumps(video_meta)}\n"
|
||||||
|
f"메타: {json.dumps(metadata, ensure_ascii=False)}\n"
|
||||||
|
f"썸네일: {thumbnail_url}\n"
|
||||||
|
f"트렌드: {trend_top}\n"
|
||||||
|
'출력: {"metadata_quality":{"score":0-100,"notes":""},'
|
||||||
|
'"policy_compliance":{"score":0-100,"issues":[]},'
|
||||||
|
'"viewer_experience":{"score":0-100,"notes":""},'
|
||||||
|
'"trend_alignment":{"score":0-100,"matched_keywords":[]},'
|
||||||
|
'"summary":""}'
|
||||||
|
)
|
||||||
|
async with httpx.AsyncClient(timeout=TIMEOUT_S) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"https://api.anthropic.com/v1/messages",
|
||||||
|
headers={
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"content-type": "application/json",
|
||||||
|
},
|
||||||
|
json={"model": model, "max_tokens": 1024,
|
||||||
|
"messages": [{"role": "user", "content": user}]},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
text = resp.json()["content"][0]["text"]
|
||||||
|
start = text.find("{")
|
||||||
|
end = text.rfind("}") + 1
|
||||||
|
if start < 0 or end <= start:
|
||||||
|
raise ValueError("Claude 응답 JSON 없음")
|
||||||
|
return json.loads(text[start:end])
|
||||||
|
|
||||||
|
|
||||||
|
def _heuristic(metadata, video_meta, track, trend_top, weights, threshold, fallback_reason):
|
||||||
|
# 메타: 길이·태그 카운트
|
||||||
|
title_len = len(metadata.get("title", ""))
|
||||||
|
desc_len = len(metadata.get("description", ""))
|
||||||
|
tag_n = len(metadata.get("tags", []))
|
||||||
|
meta_score = 100 if 5 <= title_len <= 60 and 50 <= desc_len <= 1000 and 5 <= tag_n <= 15 else 50
|
||||||
|
|
||||||
|
# 정책: 금칙어 매치
|
||||||
|
text_blob = (metadata.get("title", "") + metadata.get("description", "")).lower()
|
||||||
|
policy_score = 100 if not any(w in text_blob for w in POLICY_BANNED) else 30
|
||||||
|
|
||||||
|
# 시청: 영상 길이가 트랙과 큰 차이 없는지 휴리스틱(±5초)
|
||||||
|
expected = track.get("duration_sec", video_meta.get("length_sec", 0))
|
||||||
|
delta = abs(video_meta.get("length_sec", 0) - expected)
|
||||||
|
viewer_score = 90 if delta <= 5 else 60
|
||||||
|
|
||||||
|
# 트렌드: 태그가 트렌드와 겹치는지
|
||||||
|
overlap = set(metadata.get("tags", [])) & set(trend_top)
|
||||||
|
trend_score = 100 if overlap else 40
|
||||||
|
|
||||||
|
scores = {
|
||||||
|
"metadata_quality": {"score": meta_score, "notes": "휴리스틱"},
|
||||||
|
"policy_compliance": {"score": policy_score, "issues": []},
|
||||||
|
"viewer_experience": {"score": viewer_score, "notes": "휴리스틱"},
|
||||||
|
"trend_alignment": {"score": trend_score, "matched_keywords": list(overlap)},
|
||||||
|
"summary": f"휴리스틱 fallback: {fallback_reason}",
|
||||||
|
}
|
||||||
|
return _weighted_verdict(scores, weights, threshold, used_fallback=True)
|
||||||
41
music-lab/app/pipeline/state_machine.py
Normal file
41
music-lab/app/pipeline/state_machine.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""파이프라인 상태 머신 — 전이 규칙 단일 소스."""
|
||||||
|
|
||||||
|
STEPS = ["cover", "video", "thumb", "meta", "review", "publish"]
|
||||||
|
USER_GATES = ["cover", "video", "thumb", "meta", "publish"] # review는 자동
|
||||||
|
|
||||||
|
_APPROVE_NEXT = {
|
||||||
|
"cover_pending": "video_pending",
|
||||||
|
"video_pending": "thumb_pending",
|
||||||
|
"thumb_pending": "meta_pending",
|
||||||
|
"meta_pending": "ai_review", # 자동 검토 단계로
|
||||||
|
"publish_pending": "publishing",
|
||||||
|
}
|
||||||
|
|
||||||
|
TERMINAL_STATES = {"published", "cancelled", "failed", "awaiting_manual"}
|
||||||
|
|
||||||
|
|
||||||
|
def next_state_on_approve(state: str) -> str:
|
||||||
|
if state not in _APPROVE_NEXT:
|
||||||
|
raise ValueError(f"승인 불가 상태: {state}")
|
||||||
|
return _APPROVE_NEXT[state]
|
||||||
|
|
||||||
|
|
||||||
|
def next_state_on_reject(state: str) -> str:
|
||||||
|
if not state.endswith("_pending"):
|
||||||
|
raise ValueError(f"반려 불가 상태: {state}")
|
||||||
|
return state # 같은 상태 유지 (재생성 후 다시 _pending)
|
||||||
|
|
||||||
|
|
||||||
|
def can_transition(from_state: str, to_state: str) -> bool:
|
||||||
|
if from_state in TERMINAL_STATES:
|
||||||
|
return False
|
||||||
|
if to_state in {"cancelled", "failed", "awaiting_manual"}:
|
||||||
|
return True
|
||||||
|
if to_state == _APPROVE_NEXT.get(from_state):
|
||||||
|
return True
|
||||||
|
# 자동 전이 (ai_review → publish_pending, publishing → published)
|
||||||
|
auto_transitions = {
|
||||||
|
("ai_review", "publish_pending"),
|
||||||
|
("publishing", "published"),
|
||||||
|
}
|
||||||
|
return (from_state, to_state) in auto_transitions
|
||||||
15
music-lab/app/pipeline/storage.py
Normal file
15
music-lab/app/pipeline/storage.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""파이프라인 산출물 디렉토리 관리."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
VIDEO_DATA_DIR = os.getenv("VIDEO_DATA_DIR", "/app/data/videos")
|
||||||
|
VIDEO_MEDIA_BASE = os.getenv("VIDEO_MEDIA_BASE", "/media/videos")
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_dir(pipeline_id: int) -> str:
|
||||||
|
path = os.path.join(VIDEO_DATA_DIR, str(pipeline_id))
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def media_url(pipeline_id: int, filename: str) -> str:
|
||||||
|
return f"{VIDEO_MEDIA_BASE}/{pipeline_id}/{filename}"
|
||||||
51
music-lab/app/pipeline/thumb.py
Normal file
51
music-lab/app/pipeline/thumb.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""썸네일 생성 — 영상 5초 프레임 추출 + 텍스트 오버레이."""
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import logging
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
from . import storage
|
||||||
|
|
||||||
|
logger = logging.getLogger("music-lab.thumb")
|
||||||
|
THUMB_TIMEOUT_S = 60
|
||||||
|
|
||||||
|
|
||||||
|
class ThumbGenerationError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def generate(*, pipeline_id: int, video_path: str,
|
||||||
|
track_title: str = "", overlay_text: bool = True) -> dict:
|
||||||
|
out_path = os.path.join(storage.pipeline_dir(pipeline_id), "thumbnail.jpg")
|
||||||
|
cmd = ["ffmpeg", "-y", "-i", video_path,
|
||||||
|
"-ss", "00:00:05", "-vframes", "1", "-q:v", "2", out_path]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=THUMB_TIMEOUT_S)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise ThumbGenerationError(f"ffmpeg 썸네일 실패: {result.stderr[:300]}")
|
||||||
|
|
||||||
|
if overlay_text and track_title:
|
||||||
|
_overlay_title(out_path, track_title)
|
||||||
|
|
||||||
|
return {"url": storage.media_url(pipeline_id, "thumbnail.jpg"), "used_fallback": False}
|
||||||
|
|
||||||
|
|
||||||
|
def _overlay_title(path: str, title: str) -> None:
|
||||||
|
try:
|
||||||
|
with Image.open(path) as src:
|
||||||
|
img = src.convert("RGB")
|
||||||
|
try:
|
||||||
|
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 80)
|
||||||
|
except OSError:
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
# 하단 30% 영역에 검정 반투명 박스 + 흰 글씨
|
||||||
|
w, h = img.size
|
||||||
|
box_h = int(h * 0.3)
|
||||||
|
with Image.new("RGBA", (w, box_h), (0, 0, 0, 160)) as overlay:
|
||||||
|
img.paste(overlay, (0, h - box_h), overlay)
|
||||||
|
bbox = draw.textbbox((0, 0), title, font=font)
|
||||||
|
tw = bbox[2] - bbox[0]
|
||||||
|
draw.text(((w - tw) // 2, h - box_h + 30), title, fill=(255, 255, 255), font=font)
|
||||||
|
img.save(path, "JPEG", quality=92)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("썸네일 오버레이 실패: %s", e)
|
||||||
55
music-lab/app/pipeline/video.py
Normal file
55
music-lab/app/pipeline/video.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""영상 비주얼 생성 — visualizer/슬라이드쇼 스타일."""
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from . import storage
|
||||||
|
|
||||||
|
logger = logging.getLogger("music-lab.video")
|
||||||
|
|
||||||
|
VIDEO_TIMEOUT_S = 300 # 5분
|
||||||
|
|
||||||
|
|
||||||
|
class VideoGenerationError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def generate(*, pipeline_id: int, audio_path: str, cover_path: str,
|
||||||
|
genre: str, duration_sec: int, resolution: str = "1920x1080",
|
||||||
|
style: str = "visualizer") -> dict:
|
||||||
|
"""영상 생성. 성공 시 mp4 저장 + URL 반환. 실패 시 예외."""
|
||||||
|
w, h = resolution.split("x")
|
||||||
|
out_path = os.path.join(storage.pipeline_dir(pipeline_id), "video.mp4")
|
||||||
|
|
||||||
|
if style == "visualizer":
|
||||||
|
cmd = _build_visualizer_cmd(audio_path, cover_path, out_path, w, h)
|
||||||
|
else:
|
||||||
|
# 차후: 슬라이드쇼 등 다른 스타일 — 현재는 visualizer 폴백
|
||||||
|
cmd = _build_visualizer_cmd(audio_path, cover_path, out_path, w, h)
|
||||||
|
|
||||||
|
logger.info("ffmpeg 실행: %s", " ".join(cmd))
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=VIDEO_TIMEOUT_S)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise VideoGenerationError(f"ffmpeg 실패: {result.stderr[:500]}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"url": storage.media_url(pipeline_id, "video.mp4"),
|
||||||
|
"used_fallback": False,
|
||||||
|
"duration_sec": duration_sec,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_visualizer_cmd(audio: str, bg: str, out: str, w: str, h: str) -> list:
|
||||||
|
return [
|
||||||
|
"ffmpeg", "-y",
|
||||||
|
"-loop", "1", "-i", bg,
|
||||||
|
"-i", audio,
|
||||||
|
"-filter_complex",
|
||||||
|
f"[0:v]scale={w}:{h}[bg];"
|
||||||
|
f"[1:a]showwaves=s={w}x200:mode=cline:colors=0xFF4444@0.8[wave];"
|
||||||
|
f"[bg][wave]overlay=0:({h}-200)[out]",
|
||||||
|
"-map", "[out]", "-map", "1:a",
|
||||||
|
"-c:v", "libx264", "-preset", "fast", "-crf", "23",
|
||||||
|
"-c:a", "aac", "-b:a", "192k",
|
||||||
|
"-shortest", out,
|
||||||
|
]
|
||||||
156
music-lab/app/pipeline/youtube.py
Normal file
156
music-lab/app/pipeline/youtube.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""YouTube OAuth flow + resumable 업로드."""
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
from googleapiclient.http import MediaFileUpload
|
||||||
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
from app import db
|
||||||
|
|
||||||
|
logger = logging.getLogger("music-lab.youtube")
|
||||||
|
|
||||||
|
SCOPES = ["https://www.googleapis.com/auth/youtube.upload",
|
||||||
|
"https://www.googleapis.com/auth/youtube.readonly"]
|
||||||
|
|
||||||
|
|
||||||
|
class NotAuthenticatedError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaExceededError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _client_id() -> str:
|
||||||
|
return os.getenv("YOUTUBE_OAUTH_CLIENT_ID", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _client_secret() -> str:
|
||||||
|
return os.getenv("YOUTUBE_OAUTH_CLIENT_SECRET", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _redirect_uri() -> str:
|
||||||
|
return os.getenv("YOUTUBE_OAUTH_REDIRECT_URI", "")
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_url() -> str:
|
||||||
|
cid = _client_id()
|
||||||
|
redirect = _redirect_uri()
|
||||||
|
if not cid or not redirect:
|
||||||
|
raise RuntimeError("OAuth 환경변수 미설정")
|
||||||
|
params = {
|
||||||
|
"client_id": cid,
|
||||||
|
"redirect_uri": redirect,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": " ".join(SCOPES),
|
||||||
|
"access_type": "offline",
|
||||||
|
"prompt": "consent",
|
||||||
|
}
|
||||||
|
return "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
|
||||||
|
|
||||||
|
|
||||||
|
async def exchange_code(code: str) -> dict:
|
||||||
|
"""code → refresh_token + access_token + 채널 정보 → DB 저장."""
|
||||||
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
|
token_resp = await client.post(
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
data={
|
||||||
|
"code": code,
|
||||||
|
"client_id": _client_id(),
|
||||||
|
"client_secret": _client_secret(),
|
||||||
|
"redirect_uri": _redirect_uri(),
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
token_resp.raise_for_status()
|
||||||
|
tok = token_resp.json()
|
||||||
|
access = tok["access_token"]
|
||||||
|
refresh = tok["refresh_token"]
|
||||||
|
expires_at = _expiry_from_seconds(tok["expires_in"])
|
||||||
|
|
||||||
|
creds = _creds(access=access, refresh=refresh)
|
||||||
|
yt = _build_youtube_client(creds)
|
||||||
|
ch = yt.channels().list(part="snippet", mine=True).execute()
|
||||||
|
item = ch["items"][0]
|
||||||
|
db.upsert_oauth_token(
|
||||||
|
channel_id=item["id"],
|
||||||
|
channel_title=item["snippet"]["title"],
|
||||||
|
avatar_url=item["snippet"]["thumbnails"]["default"]["url"],
|
||||||
|
refresh_token=refresh, access_token=access, expires_at=expires_at,
|
||||||
|
)
|
||||||
|
return {"channel_id": item["id"], "channel_title": item["snippet"]["title"]}
|
||||||
|
|
||||||
|
|
||||||
|
def get_status() -> dict | None:
|
||||||
|
tok = db.get_oauth_token()
|
||||||
|
if not tok:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"channel_id": tok["channel_id"],
|
||||||
|
"channel_title": tok["channel_title"],
|
||||||
|
"avatar_url": tok["avatar_url"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def disconnect() -> None:
|
||||||
|
db.delete_oauth_token()
|
||||||
|
|
||||||
|
|
||||||
|
def upload_video(*, video_path: str, thumbnail_path: str | None,
|
||||||
|
metadata: dict, privacy: str) -> dict:
|
||||||
|
tok = db.get_oauth_token()
|
||||||
|
if not tok:
|
||||||
|
raise NotAuthenticatedError("YouTube 인증 없음")
|
||||||
|
creds = _creds(access=tok["access_token"], refresh=tok["refresh_token"])
|
||||||
|
yt = _build_youtube_client(creds)
|
||||||
|
|
||||||
|
body = {
|
||||||
|
"snippet": {
|
||||||
|
"title": metadata["title"],
|
||||||
|
"description": metadata["description"],
|
||||||
|
"tags": metadata.get("tags", []),
|
||||||
|
"categoryId": str(metadata.get("category_id", 10)),
|
||||||
|
},
|
||||||
|
"status": {"privacyStatus": privacy, "selfDeclaredMadeForKids": False},
|
||||||
|
}
|
||||||
|
media = MediaFileUpload(video_path, chunksize=4 * 1024 * 1024, resumable=True, mimetype="video/mp4")
|
||||||
|
req = yt.videos().insert(part="snippet,status", body=body, media_body=media)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = None
|
||||||
|
while response is None:
|
||||||
|
status, response = req.next_chunk()
|
||||||
|
video_id = response["id"]
|
||||||
|
except HttpError as e:
|
||||||
|
if b"quotaExceeded" in (e.content or b""):
|
||||||
|
raise QuotaExceededError(str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
if thumbnail_path:
|
||||||
|
try:
|
||||||
|
yt.thumbnails().set(videoId=video_id, media_body=thumbnail_path).execute()
|
||||||
|
except HttpError as e:
|
||||||
|
logger.warning("썸네일 업로드 실패: %s", e)
|
||||||
|
|
||||||
|
return {"video_id": video_id}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_youtube_client(creds): # patch 포인트
|
||||||
|
return build("youtube", "v3", credentials=creds, cache_discovery=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _creds(access: str, refresh: str) -> Credentials:
|
||||||
|
return Credentials(
|
||||||
|
token=access, refresh_token=refresh,
|
||||||
|
token_uri="https://oauth2.googleapis.com/token",
|
||||||
|
client_id=_client_id(), client_secret=_client_secret(), scopes=SCOPES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _expiry_from_seconds(secs: int) -> str:
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
return (datetime.utcnow() + timedelta(seconds=secs)).isoformat(timespec="seconds")
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
testpaths = tests
|
testpaths = tests
|
||||||
pythonpath = .
|
pythonpath = .
|
||||||
|
asyncio_mode = auto
|
||||||
|
|||||||
@@ -4,6 +4,13 @@ requests==2.32.3
|
|||||||
python-multipart==0.0.12
|
python-multipart==0.0.12
|
||||||
mutagen==1.47.0
|
mutagen==1.47.0
|
||||||
anthropic>=0.40.0
|
anthropic>=0.40.0
|
||||||
|
openai>=1.20.0
|
||||||
Pillow>=11.0.0
|
Pillow>=11.0.0
|
||||||
pytest>=8.0.0
|
pytest>=8.0.0
|
||||||
|
pytest-asyncio>=0.21
|
||||||
httpx>=0.27.0
|
httpx>=0.27.0
|
||||||
|
respx>=0.21
|
||||||
|
freezegun>=1.4
|
||||||
|
google-api-python-client>=2.100
|
||||||
|
google-auth-oauthlib>=1.2
|
||||||
|
google-auth-httplib2>=0.2
|
||||||
|
|||||||
@@ -1,7 +1,36 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tmp_db(tmp_path, monkeypatch):
|
def tmp_db(tmp_path, monkeypatch):
|
||||||
db_path = str(tmp_path / "test_music.db")
|
db_path = str(tmp_path / "test_music.db")
|
||||||
monkeypatch.setattr("app.db.DB_PATH", db_path)
|
monkeypatch.setattr("app.db.DB_PATH", db_path)
|
||||||
return db_path
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def freezer():
|
||||||
|
"""Minimal freezegun-based fixture providing `move_to(time)` to mimic
|
||||||
|
pytest-freezer's `freezer` fixture using only the `freezegun` package."""
|
||||||
|
from freezegun import freeze_time
|
||||||
|
|
||||||
|
class _Freezer:
|
||||||
|
def __init__(self):
|
||||||
|
self._ctx = None
|
||||||
|
|
||||||
|
def move_to(self, target):
|
||||||
|
if self._ctx is not None:
|
||||||
|
self._ctx.stop()
|
||||||
|
self._ctx = freeze_time(target)
|
||||||
|
self._ctx.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if self._ctx is not None:
|
||||||
|
self._ctx.stop()
|
||||||
|
self._ctx = None
|
||||||
|
|
||||||
|
f = _Freezer()
|
||||||
|
try:
|
||||||
|
yield f
|
||||||
|
finally:
|
||||||
|
f.stop()
|
||||||
|
|||||||
93
music-lab/tests/test_cover_generation.py
Normal file
93
music-lab/tests/test_cover_generation.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
from httpx import Response
|
||||||
|
from app.pipeline import cover, storage
|
||||||
|
|
||||||
|
|
||||||
|
# Real PNG bytes (1x1 red pixel) so PIL can open
|
||||||
|
_TINY_PNG = bytes.fromhex(
|
||||||
|
"89504e470d0a1a0a0000000d49484452000000010000000108020000009077"
|
||||||
|
"53de0000000c4944415478da6300010000050001"
|
||||||
|
"0d0a2db40000000049454e44ae426082"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_storage(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.setattr(storage, "VIDEO_DATA_DIR", str(tmp_path))
|
||||||
|
return tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_dalle_success_saves_jpg(tmp_storage, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||||
|
image_url = "https://oaidalleapiprodscus.blob.core.windows.net/x.png"
|
||||||
|
respx.post("https://api.openai.com/v1/images/generations").mock(
|
||||||
|
return_value=Response(200, json={"data": [{"url": image_url}]})
|
||||||
|
)
|
||||||
|
respx.get(image_url).mock(return_value=Response(200, content=_TINY_PNG))
|
||||||
|
|
||||||
|
out = await cover.generate(pipeline_id=42, genre="lo-fi",
|
||||||
|
prompt_template="moody anime", mood="chill",
|
||||||
|
track_title="Test")
|
||||||
|
assert out["used_fallback"] is False
|
||||||
|
assert out["url"].startswith("/media/videos/42/cover")
|
||||||
|
assert (tmp_storage / "42" / "cover.jpg").exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_dalle_http_error_falls_back_to_gradient(tmp_storage, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||||
|
respx.post("https://api.openai.com/v1/images/generations").mock(
|
||||||
|
return_value=Response(504)
|
||||||
|
)
|
||||||
|
out = await cover.generate(pipeline_id=43, genre="phonk",
|
||||||
|
prompt_template="dark drift", mood="aggressive",
|
||||||
|
track_title="Midnight Drive")
|
||||||
|
assert out["used_fallback"] is True
|
||||||
|
assert (tmp_storage / "43" / "cover.jpg").exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_api_key_falls_back(tmp_storage, monkeypatch):
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
|
out = await cover.generate(pipeline_id=44, genre="ambient",
|
||||||
|
prompt_template="x", mood="calm",
|
||||||
|
track_title="Calm")
|
||||||
|
assert out["used_fallback"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_dalle_with_feedback_appends_to_prompt(tmp_storage, monkeypatch):
|
||||||
|
import json as _json
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||||
|
captured = {}
|
||||||
|
def hook(req):
|
||||||
|
captured["body"] = _json.loads(req.content)
|
||||||
|
return Response(200, json={"data": [{"url": "https://x"}]})
|
||||||
|
respx.post("https://api.openai.com/v1/images/generations").mock(side_effect=hook)
|
||||||
|
respx.get("https://x").mock(return_value=Response(200, content=_TINY_PNG))
|
||||||
|
out = await cover.generate(pipeline_id=45, genre="lo-fi",
|
||||||
|
prompt_template="moody anime", mood="chill",
|
||||||
|
track_title="X", feedback="더 어둡게")
|
||||||
|
assert "더 어둡게" in captured["body"]["prompt"]
|
||||||
|
assert out["used_fallback"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_dalle_b64_response_handled(tmp_storage, monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||||
|
b64 = base64.b64encode(_TINY_PNG).decode()
|
||||||
|
respx.post("https://api.openai.com/v1/images/generations").mock(
|
||||||
|
return_value=Response(200, json={"data": [{"b64_json": b64}]})
|
||||||
|
)
|
||||||
|
out = await cover.generate(pipeline_id=46, genre="lo-fi",
|
||||||
|
prompt_template="x", mood="", track_title="X")
|
||||||
|
assert out["used_fallback"] is False
|
||||||
|
assert (tmp_storage / "46" / "cover.jpg").exists()
|
||||||
82
music-lab/tests/test_metadata_generation.py
Normal file
82
music-lab/tests/test_metadata_generation.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
from httpx import Response
|
||||||
|
from app.pipeline import metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_metadata_calls_claude_and_parses_json(monkeypatch):
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||||
|
payload = {
|
||||||
|
"content": [{"type": "text", "text": '{"title":"[Lo-fi] Drive | 85BPM",'
|
||||||
|
'"description":"chill","tags":["lofi","85bpm"],'
|
||||||
|
'"category_id":10}'}]
|
||||||
|
}
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
return_value=Response(200, json=payload)
|
||||||
|
)
|
||||||
|
result = await metadata.generate(
|
||||||
|
track={"title": "Drive", "genre": "lo-fi", "bpm": 85, "key": "C", "scale": "minor",
|
||||||
|
"moods": ["chill"], "instruments": ["piano"]},
|
||||||
|
template={"title": "[{genre}] {title} | {bpm}BPM",
|
||||||
|
"description": "{title}\n", "tags": [], "category_id": 10},
|
||||||
|
trend_keywords=["lofi", "study"],
|
||||||
|
feedback="",
|
||||||
|
)
|
||||||
|
assert result["title"].startswith("[Lo-fi]")
|
||||||
|
assert "lofi" in result["tags"]
|
||||||
|
assert result["used_fallback"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_metadata_fallback_when_no_api_key(monkeypatch):
|
||||||
|
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||||
|
result = await metadata.generate(
|
||||||
|
track={"title": "Drive", "genre": "lo-fi", "bpm": 85, "key": "C", "scale": "minor",
|
||||||
|
"moods": [], "instruments": []},
|
||||||
|
template={"title": "[{genre}] {title} | {bpm}BPM",
|
||||||
|
"description": "{title}", "tags": ["lofi"], "category_id": 10},
|
||||||
|
trend_keywords=[],
|
||||||
|
)
|
||||||
|
# 템플릿 변수 그대로 치환된 폴백
|
||||||
|
assert result["title"] == "[lo-fi] Drive | 85BPM"
|
||||||
|
assert result["used_fallback"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_metadata_includes_feedback_in_prompt(monkeypatch):
|
||||||
|
import json
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||||
|
captured = {}
|
||||||
|
def hook(req):
|
||||||
|
captured["body"] = json.loads(req.content)
|
||||||
|
return Response(200, json={"content": [{"type": "text",
|
||||||
|
"text": '{"title":"x","description":"y","tags":[],"category_id":10}'}]})
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(side_effect=hook)
|
||||||
|
await metadata.generate(
|
||||||
|
track={"title": "X", "genre": "lo-fi", "bpm": 85, "key": "C", "scale": "minor",
|
||||||
|
"moods": [], "instruments": []},
|
||||||
|
template={"title": "{title}", "description": "{title}", "tags": [], "category_id": 10},
|
||||||
|
trend_keywords=[],
|
||||||
|
feedback="제목을 짧게",
|
||||||
|
)
|
||||||
|
assert "제목을 짧게" in str(captured["body"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_metadata_falls_back_on_api_error(monkeypatch):
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
return_value=Response(500)
|
||||||
|
)
|
||||||
|
result = await metadata.generate(
|
||||||
|
track={"title": "Drive", "genre": "lo-fi", "bpm": 85, "key": "C", "scale": "minor",
|
||||||
|
"moods": [], "instruments": []},
|
||||||
|
template={"title": "[{genre}] {title}", "description": "x", "tags": ["lofi"], "category_id": 10},
|
||||||
|
trend_keywords=[],
|
||||||
|
)
|
||||||
|
assert result["used_fallback"] is True
|
||||||
|
assert "Drive" in result["title"]
|
||||||
96
music-lab/tests/test_pipeline_db.py
Normal file
96
music-lab/tests/test_pipeline_db.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import pytest
|
||||||
|
from app import db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fresh_db(monkeypatch, tmp_path):
|
||||||
|
db_path = tmp_path / "music.db"
|
||||||
|
monkeypatch.setattr(db, "DB_PATH", str(db_path))
|
||||||
|
db.init_db()
|
||||||
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_pipeline_inserts_row(fresh_db):
|
||||||
|
pid = db.create_pipeline(track_id=1)
|
||||||
|
row = db.get_pipeline(pid)
|
||||||
|
assert row["id"] == pid
|
||||||
|
assert row["state"] == "created"
|
||||||
|
assert row["track_id"] == 1
|
||||||
|
assert row["feedback_count_per_step"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_pipeline_state_records_started_at(fresh_db, freezer):
|
||||||
|
pid = db.create_pipeline(track_id=1)
|
||||||
|
freezer.move_to("2026-05-07T08:00:00")
|
||||||
|
db.update_pipeline_state(pid, "cover_pending")
|
||||||
|
row = db.get_pipeline(pid)
|
||||||
|
assert row["state"] == "cover_pending"
|
||||||
|
assert row["state_started_at"] == "2026-05-07T08:00:00"
|
||||||
|
|
||||||
|
|
||||||
|
def test_increment_feedback_count(fresh_db):
|
||||||
|
pid = db.create_pipeline(track_id=1)
|
||||||
|
db.increment_feedback_count(pid, "cover")
|
||||||
|
db.increment_feedback_count(pid, "cover")
|
||||||
|
row = db.get_pipeline(pid)
|
||||||
|
assert row["feedback_count_per_step"] == {"cover": 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_feedback(fresh_db):
|
||||||
|
pid = db.create_pipeline(track_id=1)
|
||||||
|
db.record_feedback(pid, "cover", "더 어둡게")
|
||||||
|
rows = db.get_feedback_history(pid)
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0]["feedback_text"] == "더 어둡게"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_pipeline_job_lifecycle(fresh_db):
|
||||||
|
pid = db.create_pipeline(track_id=1)
|
||||||
|
job_id = db.create_pipeline_job(pid, "cover")
|
||||||
|
db.update_pipeline_job(job_id, status="running")
|
||||||
|
db.update_pipeline_job(job_id, status="succeeded", duration_ms=1234)
|
||||||
|
jobs = db.list_pipeline_jobs(pid)
|
||||||
|
assert jobs[0]["status"] == "succeeded"
|
||||||
|
assert jobs[0]["duration_ms"] == 1234
|
||||||
|
|
||||||
|
|
||||||
|
def test_youtube_setup_default_row_created_on_init(fresh_db):
|
||||||
|
setup = db.get_youtube_setup()
|
||||||
|
assert setup["review_threshold"] == 60
|
||||||
|
assert "metadata_template_json" in setup
|
||||||
|
|
||||||
|
|
||||||
|
def test_youtube_oauth_token_upsert(fresh_db):
|
||||||
|
db.upsert_oauth_token(
|
||||||
|
channel_id="UC123",
|
||||||
|
channel_title="My Channel",
|
||||||
|
avatar_url="https://...",
|
||||||
|
refresh_token="r1",
|
||||||
|
access_token="a1",
|
||||||
|
expires_at="2026-05-07T09:00:00",
|
||||||
|
)
|
||||||
|
tok = db.get_oauth_token()
|
||||||
|
assert tok["channel_id"] == "UC123"
|
||||||
|
assert tok["refresh_token"] == "r1"
|
||||||
|
db.upsert_oauth_token(
|
||||||
|
channel_id="UC123", channel_title="My Channel",
|
||||||
|
avatar_url=None, refresh_token="r2",
|
||||||
|
access_token="a2", expires_at="2026-05-07T10:00:00",
|
||||||
|
)
|
||||||
|
tok = db.get_oauth_token()
|
||||||
|
assert tok["refresh_token"] == "r2" # upsert
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_pipeline_state_rejects_unknown_column(fresh_db):
|
||||||
|
pid = db.create_pipeline(track_id=1)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
db.update_pipeline_state(pid, "cover_pending", evil_col="x; DROP TABLE")
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_pipeline_job_rejects_unknown_column(fresh_db):
|
||||||
|
pid = db.create_pipeline(track_id=1)
|
||||||
|
job_id = db.create_pipeline_job(pid, "cover")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
db.update_pipeline_job(job_id, evil_col="x")
|
||||||
110
music-lab/tests/test_pipeline_endpoints.py
Normal file
110
music-lab/tests/test_pipeline_endpoints.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import sqlite3
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
from app import db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
|
||||||
|
db.init_db()
|
||||||
|
# 최소 트랙 1개 — music_library 테이블에 직접 삽입
|
||||||
|
conn = sqlite3.connect(db.DB_PATH)
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""INSERT INTO music_library
|
||||||
|
(id, title, genre, moods, instruments, duration_sec, bpm, key, scale,
|
||||||
|
prompt, audio_url, file_path, task_id, tags)
|
||||||
|
VALUES (1, 'T', 'lo-fi', '["chill"]', '["piano"]', 120, 85, 'C', 'maj',
|
||||||
|
'p', '/media/music/x.mp3', '/app/data/music/x.mp3', NULL, '[]')""",
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_pipeline(client):
|
||||||
|
r = client.post("/api/music/pipeline", json={"track_id": 1})
|
||||||
|
assert r.status_code == 201
|
||||||
|
assert r.json()["state"] == "created"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_duplicate_pipeline_returns_409(client):
|
||||||
|
client.post("/api/music/pipeline", json={"track_id": 1})
|
||||||
|
r = client.post("/api/music/pipeline", json={"track_id": 1})
|
||||||
|
assert r.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_pipeline_returns_jobs_and_feedback(client):
|
||||||
|
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||||
|
r = client.get(f"/api/music/pipeline/{pid}")
|
||||||
|
assert "jobs" in r.json()
|
||||||
|
assert "feedback" in r.json()
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_pipelines_active_filter(client):
|
||||||
|
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||||
|
db.update_pipeline_state(pid, "published")
|
||||||
|
r = client.get("/api/music/pipeline?status=active")
|
||||||
|
assert all(p["state"] != "published" for p in r.json()["pipelines"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_feedback_reject_records_feedback_and_increments_count(client):
|
||||||
|
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||||
|
db.update_pipeline_state(pid, "cover_pending")
|
||||||
|
# orchestrator.run_step를 mock해서 백그라운드 작업이 cover_pending을 변경하지 않도록
|
||||||
|
with patch("app.main.orchestrator.run_step", new=AsyncMock()):
|
||||||
|
r = client.post(
|
||||||
|
f"/api/music/pipeline/{pid}/feedback",
|
||||||
|
json={"step": "cover", "intent": "reject", "feedback_text": "더 어둡게"},
|
||||||
|
)
|
||||||
|
assert r.status_code == 202
|
||||||
|
p = db.get_pipeline(pid)
|
||||||
|
assert p["feedback_count_per_step"]["cover"] == 1
|
||||||
|
history = db.get_feedback_history(pid)
|
||||||
|
assert history[0]["feedback_text"] == "더 어둡게"
|
||||||
|
|
||||||
|
|
||||||
|
def test_feedback_after_5_rejects_marks_awaiting_manual(client):
|
||||||
|
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||||
|
db.update_pipeline_state(pid, "cover_pending")
|
||||||
|
with patch("app.main.orchestrator.run_step", new=AsyncMock()):
|
||||||
|
for i in range(5):
|
||||||
|
client.post(
|
||||||
|
f"/api/music/pipeline/{pid}/feedback",
|
||||||
|
json={"step": "cover", "intent": "reject", "feedback_text": f"again {i}"},
|
||||||
|
)
|
||||||
|
r = client.post(
|
||||||
|
f"/api/music/pipeline/{pid}/feedback",
|
||||||
|
json={"step": "cover", "intent": "reject", "feedback_text": "6th"},
|
||||||
|
)
|
||||||
|
assert r.status_code == 409
|
||||||
|
assert db.get_pipeline(pid)["state"] == "awaiting_manual"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_pipeline(client):
|
||||||
|
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||||
|
r = client.post(f"/api/music/pipeline/{pid}/cancel")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert db.get_pipeline(pid)["state"] == "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_get_returns_defaults(client):
|
||||||
|
r = client.get("/api/music/setup")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json()["review_threshold"] == 60
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_put_updates(client):
|
||||||
|
r = client.put("/api/music/setup", json={"review_threshold": 70})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json()["review_threshold"] == 70
|
||||||
|
|
||||||
|
|
||||||
|
def test_youtube_status_when_disconnected(client):
|
||||||
|
r = client.get("/api/music/youtube/status")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == {"connected": False}
|
||||||
113
music-lab/tests/test_pipeline_flow.py
Normal file
113
music-lab/tests/test_pipeline_flow.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""풀 파이프라인 happy-path 통합 테스트 — 모든 외부 호출 mock."""
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from app.main import app
|
||||||
|
from app import db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
|
||||||
|
db.init_db()
|
||||||
|
# 기본 트랙 1개 등록 (music_library 테이블)
|
||||||
|
conn = sqlite3.connect(db.DB_PATH)
|
||||||
|
cur = conn.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""INSERT INTO music_library
|
||||||
|
(id, title, genre, moods, instruments, duration_sec, bpm, key, scale,
|
||||||
|
prompt, audio_url, file_path, task_id, tags)
|
||||||
|
VALUES (1, 'Integration Test', 'lo-fi', '["chill"]', '["piano"]',
|
||||||
|
120, 85, 'C', 'maj', 'p',
|
||||||
|
'/media/music/x.mp3', '/app/data/music/x.mp3', NULL, '[]')""",
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.OperationalError as e:
|
||||||
|
# 스키마가 다르면 테스트 스킵 가이드 표시
|
||||||
|
pytest.skip(f"music_library schema mismatch: {e}")
|
||||||
|
conn.close()
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
# `new=...`로 patch한 항목은 mock 인자를 함수에 주입하지 않음.
|
||||||
|
# return_value로 patch한 두 개(cover/video는 new=AsyncMock 사용 → 미주입,
|
||||||
|
# thumb/video는 return_value → 주입, youtube.upload_video도 return_value → 주입).
|
||||||
|
# 데코레이터는 아래에서 위 순서로 인자에 주입됨:
|
||||||
|
# 1) cover (new=AsyncMock) → 미주입
|
||||||
|
# 2) video (return_value) → 주입 → mock_video
|
||||||
|
# 3) thumb (return_value) → 주입 → mock_thumb
|
||||||
|
# 4) metadata (new=AsyncMock) → 미주입
|
||||||
|
# 5) review (new=AsyncMock) → 미주입
|
||||||
|
# 6) youtube.upload_video (return_value) → 주입 → mock_yt
|
||||||
|
@patch("app.pipeline.youtube.upload_video", return_value={"video_id": "VID999"})
|
||||||
|
@patch("app.pipeline.review.run_4_axis", new=AsyncMock(return_value={
|
||||||
|
"metadata_quality": {"score": 80, "notes": "ok"},
|
||||||
|
"policy_compliance": {"score": 90, "issues": []},
|
||||||
|
"viewer_experience": {"score": 80, "notes": "ok"},
|
||||||
|
"trend_alignment": {"score": 70, "matched_keywords": ["lofi"]},
|
||||||
|
"weighted_total": 80.0, "verdict": "pass", "summary": "good", "used_fallback": False,
|
||||||
|
}))
|
||||||
|
@patch("app.pipeline.metadata.generate", new=AsyncMock(return_value={
|
||||||
|
"title": "Integration Test", "description": "Test desc",
|
||||||
|
"tags": ["lofi"], "category_id": 10, "used_fallback": False, "error": None,
|
||||||
|
}))
|
||||||
|
@patch("app.pipeline.thumb.generate", return_value={
|
||||||
|
"url": "/media/videos/1/thumbnail.jpg", "used_fallback": False,
|
||||||
|
})
|
||||||
|
@patch("app.pipeline.video.generate", return_value={
|
||||||
|
"url": "/media/videos/1/video.mp4", "used_fallback": False, "duration_sec": 120,
|
||||||
|
})
|
||||||
|
@patch("app.pipeline.cover.generate", new=AsyncMock(return_value={
|
||||||
|
"url": "/media/videos/1/cover.jpg", "used_fallback": False, "error": None,
|
||||||
|
}))
|
||||||
|
def test_full_pipeline_happy_path(mock_video, mock_thumb, mock_yt, client):
|
||||||
|
# 1. 파이프라인 생성
|
||||||
|
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||||
|
assert db.get_pipeline(pid)["state"] == "created"
|
||||||
|
|
||||||
|
# 2. 시작 → cover 자동 생성 → cover_pending
|
||||||
|
r = client.post(f"/api/music/pipeline/{pid}/start")
|
||||||
|
assert r.status_code == 202
|
||||||
|
assert db.get_pipeline(pid)["state"] == "cover_pending"
|
||||||
|
|
||||||
|
# 3. 4단계 사용자 승인 (cover, video, thumb, meta)
|
||||||
|
for step in ["cover", "video", "thumb", "meta"]:
|
||||||
|
r = client.post(f"/api/music/pipeline/{pid}/feedback",
|
||||||
|
json={"step": step, "intent": "approve"})
|
||||||
|
assert r.status_code == 202
|
||||||
|
|
||||||
|
# 4. ai_review 자동 진행 후 publish_pending
|
||||||
|
p = db.get_pipeline(pid)
|
||||||
|
assert p["state"] == "publish_pending"
|
||||||
|
assert p["review"]["verdict"] == "pass"
|
||||||
|
|
||||||
|
# 5. 발행 트리거
|
||||||
|
r = client.post(f"/api/music/pipeline/{pid}/publish")
|
||||||
|
assert r.status_code == 202
|
||||||
|
|
||||||
|
# 6. 최종 published
|
||||||
|
p = db.get_pipeline(pid)
|
||||||
|
assert p["state"] == "published"
|
||||||
|
assert p["youtube_video_id"] == "VID999"
|
||||||
|
|
||||||
|
|
||||||
|
# cover.generate는 new=AsyncMock이므로 함수 인자에 주입되지 않음.
|
||||||
|
@patch("app.pipeline.cover.generate", new=AsyncMock(return_value={
|
||||||
|
"url": "/media/videos/2/cover.jpg", "used_fallback": False, "error": None,
|
||||||
|
}))
|
||||||
|
def test_pipeline_reject_and_regenerate(client):
|
||||||
|
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||||
|
client.post(f"/api/music/pipeline/{pid}/start")
|
||||||
|
assert db.get_pipeline(pid)["state"] == "cover_pending"
|
||||||
|
|
||||||
|
# 반려 + 피드백 → 같은 단계 재진입
|
||||||
|
r = client.post(f"/api/music/pipeline/{pid}/feedback",
|
||||||
|
json={"step": "cover", "intent": "reject", "feedback_text": "더 어둡게"})
|
||||||
|
assert r.status_code == 202
|
||||||
|
p = db.get_pipeline(pid)
|
||||||
|
assert p["state"] == "cover_pending" # 같은 단계 유지
|
||||||
|
assert p["feedback_count_per_step"]["cover"] == 1
|
||||||
|
history = db.get_feedback_history(pid)
|
||||||
|
assert history[0]["feedback_text"] == "더 어둡게"
|
||||||
84
music-lab/tests/test_review.py
Normal file
84
music-lab/tests/test_review.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
from httpx import Response
|
||||||
|
from app.pipeline import review
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_review_returns_pass_when_above_threshold(monkeypatch):
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "k")
|
||||||
|
body = {"content": [{"type": "text", "text":
|
||||||
|
'{"metadata_quality":{"score":80,"notes":"x"},'
|
||||||
|
'"policy_compliance":{"score":90,"issues":[]},'
|
||||||
|
'"viewer_experience":{"score":75,"notes":"y"},'
|
||||||
|
'"trend_alignment":{"score":70,"matched_keywords":["lofi"]},'
|
||||||
|
'"summary":"good"}'}]}
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(return_value=Response(200, json=body))
|
||||||
|
result = await review.run_4_axis(
|
||||||
|
pipeline={"id": 1}, track={"title": "x", "genre": "lo-fi", "bpm": 85},
|
||||||
|
video_meta={"length_sec": 120, "resolution": "1920x1080"},
|
||||||
|
metadata={"title": "Y", "description": "Z", "tags": ["lofi"], "category_id": 10},
|
||||||
|
thumbnail_url="/m/x.jpg", trend_top=["lofi"],
|
||||||
|
weights={"meta": 25, "policy": 30, "viewer": 25, "trend": 20},
|
||||||
|
threshold=60,
|
||||||
|
)
|
||||||
|
assert result["verdict"] == "pass"
|
||||||
|
expected_total = 0.25 * 80 + 0.30 * 90 + 0.25 * 75 + 0.20 * 70
|
||||||
|
assert result["weighted_total"] == pytest.approx(expected_total, abs=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_review_fail_below_threshold(monkeypatch):
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "k")
|
||||||
|
body = {"content": [{"type": "text", "text":
|
||||||
|
'{"metadata_quality":{"score":40,"notes":"x"},'
|
||||||
|
'"policy_compliance":{"score":50,"issues":[]},'
|
||||||
|
'"viewer_experience":{"score":30,"notes":"y"},'
|
||||||
|
'"trend_alignment":{"score":20,"matched_keywords":[]},'
|
||||||
|
'"summary":"bad"}'}]}
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(return_value=Response(200, json=body))
|
||||||
|
result = await review.run_4_axis(
|
||||||
|
pipeline={"id": 2}, track={"title": "x", "genre": "lo-fi", "bpm": 85},
|
||||||
|
video_meta={"length_sec": 120, "resolution": "1920x1080"},
|
||||||
|
metadata={"title": "Y", "description": "Z", "tags": [], "category_id": 10},
|
||||||
|
thumbnail_url="/m/x.jpg", trend_top=[],
|
||||||
|
weights={"meta": 25, "policy": 30, "viewer": 25, "trend": 20},
|
||||||
|
threshold=60,
|
||||||
|
)
|
||||||
|
assert result["verdict"] == "fail"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_review_heuristic_fallback_on_llm_error(monkeypatch):
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "k")
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(return_value=Response(500))
|
||||||
|
result = await review.run_4_axis(
|
||||||
|
pipeline={"id": 3}, track={"title": "x", "genre": "lo-fi", "bpm": 85, "duration_sec": 120},
|
||||||
|
video_meta={"length_sec": 120, "resolution": "1920x1080"},
|
||||||
|
metadata={"title": "Y" * 30, "description": "Z" * 200, "tags": ["a", "b", "c", "d", "e"], "category_id": 10},
|
||||||
|
thumbnail_url="/m/x.jpg", trend_top=["lofi"],
|
||||||
|
weights={"meta": 25, "policy": 30, "viewer": 25, "trend": 20},
|
||||||
|
threshold=60,
|
||||||
|
)
|
||||||
|
assert result["used_fallback"] is True
|
||||||
|
assert "weighted_total" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_review_heuristic_when_no_api_key(monkeypatch):
|
||||||
|
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||||
|
result = await review.run_4_axis(
|
||||||
|
pipeline={"id": 4}, track={"title": "x", "genre": "lo-fi", "bpm": 85, "duration_sec": 120},
|
||||||
|
video_meta={"length_sec": 120, "resolution": "1920x1080"},
|
||||||
|
metadata={"title": "Test Title", "description": "Description here, more text " * 5,
|
||||||
|
"tags": ["lofi", "study", "chill", "ambient", "instrumental"], "category_id": 10},
|
||||||
|
thumbnail_url="/m/x.jpg", trend_top=["lofi"],
|
||||||
|
weights={"meta": 25, "policy": 30, "viewer": 25, "trend": 20},
|
||||||
|
threshold=60,
|
||||||
|
)
|
||||||
|
assert result["used_fallback"] is True
|
||||||
|
# 휴리스틱: 좋은 메타+영상길이 일치+태그 트렌드 겹침 → pass 기대
|
||||||
|
assert result["verdict"] == "pass"
|
||||||
49
music-lab/tests/test_state_machine.py
Normal file
49
music-lab/tests/test_state_machine.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import pytest
|
||||||
|
from app.pipeline.state_machine import (
|
||||||
|
next_state_on_approve, next_state_on_reject, can_transition, STEPS, USER_GATES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_steps_sequence():
|
||||||
|
assert STEPS == ["cover", "video", "thumb", "meta", "review", "publish"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_gates_excludes_review():
|
||||||
|
assert "review" not in USER_GATES
|
||||||
|
assert "publish" in USER_GATES
|
||||||
|
assert "cover" in USER_GATES
|
||||||
|
|
||||||
|
|
||||||
|
def test_approve_progression():
|
||||||
|
assert next_state_on_approve("cover_pending") == "video_pending"
|
||||||
|
assert next_state_on_approve("video_pending") == "thumb_pending"
|
||||||
|
assert next_state_on_approve("thumb_pending") == "meta_pending"
|
||||||
|
assert next_state_on_approve("meta_pending") == "ai_review"
|
||||||
|
assert next_state_on_approve("publish_pending") == "publishing"
|
||||||
|
|
||||||
|
|
||||||
|
def test_approve_invalid_state_raises():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
next_state_on_approve("ai_review") # 자동 전이 — approve 호출 자체가 무효
|
||||||
|
|
||||||
|
|
||||||
|
def test_reject_keeps_same_state():
|
||||||
|
# 반려는 같은 *_pending 상태를 유지(재생성 트리거)
|
||||||
|
assert next_state_on_reject("cover_pending") == "cover_pending"
|
||||||
|
assert next_state_on_reject("publish_pending") == "publish_pending"
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_transition_blocks_terminal_states():
|
||||||
|
assert not can_transition("published", "cover_pending")
|
||||||
|
assert not can_transition("cancelled", "cover_pending")
|
||||||
|
assert not can_transition("failed", "cover_pending")
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_transition_allows_cancel_from_anywhere():
|
||||||
|
assert can_transition("cover_pending", "cancelled")
|
||||||
|
assert can_transition("publishing", "cancelled")
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_transition_allows_failed_from_pending():
|
||||||
|
assert can_transition("video_pending", "failed")
|
||||||
|
assert can_transition("publishing", "failed")
|
||||||
66
music-lab/tests/test_video_thumb.py
Normal file
66
music-lab/tests/test_video_thumb.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from app.pipeline import video, thumb, storage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_storage(monkeypatch, tmp_path):
|
||||||
|
monkeypatch.setattr(storage, "VIDEO_DATA_DIR", str(tmp_path))
|
||||||
|
# 더미 입력 파일들
|
||||||
|
audio = tmp_path / "audio.mp3"
|
||||||
|
audio.write_bytes(b"\x00" * 100)
|
||||||
|
cover_dir = tmp_path / "50"
|
||||||
|
cover_dir.mkdir()
|
||||||
|
cover = cover_dir / "cover.jpg"
|
||||||
|
cover.write_bytes(b"\x00" * 100)
|
||||||
|
return tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
@patch("subprocess.run")
|
||||||
|
def test_generate_video_calls_ffmpeg(mock_run, tmp_storage):
|
||||||
|
mock_run.return_value = MagicMock(returncode=0, stderr="")
|
||||||
|
out = video.generate(pipeline_id=50, audio_path=str(tmp_storage / "audio.mp3"),
|
||||||
|
cover_path=str(tmp_storage / "50" / "cover.jpg"),
|
||||||
|
genre="lo-fi", duration_sec=120, resolution="1920x1080",
|
||||||
|
style="visualizer")
|
||||||
|
assert out["url"].endswith("/50/video.mp4")
|
||||||
|
assert out["used_fallback"] is False
|
||||||
|
args = mock_run.call_args[0][0]
|
||||||
|
assert args[0] == "ffmpeg"
|
||||||
|
assert "-i" in args
|
||||||
|
assert "showwaves" in " ".join(args)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("subprocess.run")
|
||||||
|
def test_generate_video_failure_marks_failed(mock_run, tmp_storage):
|
||||||
|
mock_run.return_value = MagicMock(returncode=1, stderr="bad codec")
|
||||||
|
with pytest.raises(video.VideoGenerationError):
|
||||||
|
video.generate(pipeline_id=51, audio_path=str(tmp_storage / "audio.mp3"),
|
||||||
|
cover_path=str(tmp_storage / "50" / "cover.jpg"),
|
||||||
|
genre="lo-fi", duration_sec=120, resolution="1920x1080",
|
||||||
|
style="visualizer")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("subprocess.run")
|
||||||
|
def test_thumb_extracts_frame(mock_run, tmp_storage):
|
||||||
|
mock_run.return_value = MagicMock(returncode=0, stderr="")
|
||||||
|
video_path = tmp_storage / "60" / "video.mp4"
|
||||||
|
video_path.parent.mkdir()
|
||||||
|
video_path.write_bytes(b"\x00" * 100)
|
||||||
|
out = thumb.generate(pipeline_id=60, video_path=str(video_path),
|
||||||
|
track_title="Midnight Drive", overlay_text=False)
|
||||||
|
assert out["url"].endswith("/60/thumbnail.jpg")
|
||||||
|
args = mock_run.call_args[0][0]
|
||||||
|
assert args[0] == "ffmpeg"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("subprocess.run")
|
||||||
|
def test_thumb_failure_raises(mock_run, tmp_storage):
|
||||||
|
mock_run.return_value = MagicMock(returncode=1, stderr="bad input")
|
||||||
|
video_path = tmp_storage / "61" / "video.mp4"
|
||||||
|
video_path.parent.mkdir()
|
||||||
|
video_path.write_bytes(b"\x00" * 100)
|
||||||
|
with pytest.raises(thumb.ThumbGenerationError):
|
||||||
|
thumb.generate(pipeline_id=61, video_path=str(video_path),
|
||||||
|
track_title="X", overlay_text=False)
|
||||||
88
music-lab/tests/test_youtube_upload.py
Normal file
88
music-lab/tests/test_youtube_upload.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from app.pipeline import youtube
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fresh_db(monkeypatch, tmp_path):
|
||||||
|
from app import db
|
||||||
|
monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
|
||||||
|
db.init_db()
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_token(db_module):
|
||||||
|
db_module.upsert_oauth_token(
|
||||||
|
channel_id="UC1", channel_title="t", avatar_url=None,
|
||||||
|
refresh_token="r1", access_token="a1", expires_at="2099-01-01T00:00:00",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("app.pipeline.youtube._build_youtube_client")
|
||||||
|
def test_upload_succeeds_after_resumable(mock_client, fresh_db, tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("YOUTUBE_OAUTH_CLIENT_ID", "cid")
|
||||||
|
monkeypatch.setenv("YOUTUBE_OAUTH_CLIENT_SECRET", "sec")
|
||||||
|
_setup_token(fresh_db)
|
||||||
|
|
||||||
|
yt = MagicMock()
|
||||||
|
insert = MagicMock()
|
||||||
|
# next_chunk: first call returns (None, None), second returns (None, response with id)
|
||||||
|
insert.next_chunk.side_effect = [(None, None), (None, {"id": "VID123"})]
|
||||||
|
yt.videos().insert.return_value = insert
|
||||||
|
mock_client.return_value = yt
|
||||||
|
|
||||||
|
video_path = tmp_path / "v.mp4"
|
||||||
|
video_path.write_bytes(b"\x00" * 100)
|
||||||
|
out = youtube.upload_video(
|
||||||
|
video_path=str(video_path),
|
||||||
|
thumbnail_path=None,
|
||||||
|
metadata={"title": "T", "description": "D", "tags": ["x"], "category_id": 10},
|
||||||
|
privacy="private",
|
||||||
|
)
|
||||||
|
assert out["video_id"] == "VID123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_no_token_raises(fresh_db, tmp_path):
|
||||||
|
video_path = tmp_path / "v.mp4"
|
||||||
|
video_path.write_bytes(b"\x00")
|
||||||
|
with pytest.raises(youtube.NotAuthenticatedError):
|
||||||
|
youtube.upload_video(
|
||||||
|
video_path=str(video_path), thumbnail_path=None,
|
||||||
|
metadata={"title":"T","description":"D","tags":[],"category_id":10},
|
||||||
|
privacy="private",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("app.pipeline.youtube._build_youtube_client")
|
||||||
|
def test_upload_quota_exceeded_marks_quota(mock_client, fresh_db, tmp_path, monkeypatch):
|
||||||
|
from googleapiclient.errors import HttpError
|
||||||
|
monkeypatch.setenv("YOUTUBE_OAUTH_CLIENT_ID", "cid")
|
||||||
|
monkeypatch.setenv("YOUTUBE_OAUTH_CLIENT_SECRET", "sec")
|
||||||
|
_setup_token(fresh_db)
|
||||||
|
|
||||||
|
yt = MagicMock()
|
||||||
|
err = HttpError(MagicMock(status=403), b'{"error":{"errors":[{"reason":"quotaExceeded"}]}}')
|
||||||
|
insert_call = MagicMock()
|
||||||
|
insert_call.next_chunk.side_effect = err
|
||||||
|
yt.videos().insert.return_value = insert_call
|
||||||
|
mock_client.return_value = yt
|
||||||
|
|
||||||
|
video_path = tmp_path / "v.mp4"
|
||||||
|
video_path.write_bytes(b"\x00")
|
||||||
|
with pytest.raises(youtube.QuotaExceededError):
|
||||||
|
youtube.upload_video(
|
||||||
|
video_path=str(video_path), thumbnail_path=None,
|
||||||
|
metadata={"title":"T","description":"D","tags":[],"category_id":10},
|
||||||
|
privacy="private",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_status_returns_none_when_not_connected(fresh_db):
|
||||||
|
assert youtube.get_status() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_status_returns_channel_info(fresh_db):
|
||||||
|
_setup_token(fresh_db)
|
||||||
|
s = youtube.get_status()
|
||||||
|
assert s["channel_id"] == "UC1"
|
||||||
|
assert s["channel_title"] == "t"
|
||||||
Reference in New Issue
Block a user