Compare commits
13 Commits
e03d074222
...
4f67cd02fa
| Author | SHA1 | Date | |
|---|---|---|---|
| 4f67cd02fa | |||
| 868906b8c6 | |||
| bd97cc1e97 | |||
| 7552ce4263 | |||
| 17034ea6ea | |||
| fe60c8d330 | |||
| 4755e34c14 | |||
| ad1c721ba8 | |||
| 1c705b0ef3 | |||
| 68dec2e53d | |||
| e33a2310af | |||
| fceca88db4 | |||
| d66a321982 |
@@ -4,6 +4,7 @@ from .blog import BlogAgent
|
||||
from .realestate import RealestateAgent
|
||||
from .lotto import LottoAgent
|
||||
from .youtube import YouTubeResearchAgent
|
||||
from .youtube_publisher import YoutubePublisherAgent
|
||||
|
||||
AGENT_REGISTRY = {}
|
||||
|
||||
@@ -14,6 +15,7 @@ def init_agents():
|
||||
AGENT_REGISTRY["realestate"] = RealestateAgent()
|
||||
AGENT_REGISTRY["lotto"] = LottoAgent()
|
||||
AGENT_REGISTRY["youtube"] = YouTubeResearchAgent()
|
||||
AGENT_REGISTRY["youtube_publisher"] = YoutubePublisherAgent()
|
||||
|
||||
def get_agent(agent_id: str):
|
||||
return AGENT_REGISTRY.get(agent_id)
|
||||
|
||||
75
agent-office/app/agents/classify_intent.py
Normal file
75
agent-office/app/agents/classify_intent.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""텔레그램 사용자 응답 자연어 분류 — 화이트리스트 우선, 모호 시 LLM."""
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger("agent-office.classify_intent")
|
||||
|
||||
CLAUDE_HAIKU_DEFAULT = "claude-haiku-4-5-20251001"
|
||||
|
||||
APPROVE_WORDS = {
|
||||
"승인", "시작", "진행", "ok", "okay", "agree",
|
||||
"네", "예", "좋아", "좋아요", "go", "yes", "y",
|
||||
}
|
||||
REJECT_WORDS = {"반려", "거절", "취소", "no", "nope", "n"}
|
||||
|
||||
|
||||
def _get_api_key() -> str:
|
||||
return os.getenv("ANTHROPIC_API_KEY", "")
|
||||
|
||||
|
||||
def _get_model() -> str:
|
||||
return os.getenv("CLAUDE_HAIKU_MODEL", CLAUDE_HAIKU_DEFAULT)
|
||||
|
||||
|
||||
def classify(text: str) -> tuple[str, str | None]:
|
||||
"""returns (intent, feedback) — intent ∈ {approve, reject, unclear}"""
|
||||
if not text:
|
||||
return ("unclear", None)
|
||||
t = text.strip().lower()
|
||||
if t in APPROVE_WORDS:
|
||||
return ("approve", None)
|
||||
if t in REJECT_WORDS:
|
||||
return ("reject", None)
|
||||
# 반려 단어로 시작 + 추가 텍스트
|
||||
for w in REJECT_WORDS:
|
||||
if t.startswith(w):
|
||||
rest = text.strip()[len(w):].lstrip(" ,.-:").strip()
|
||||
if rest:
|
||||
return ("reject", rest)
|
||||
# 승인 단어로 시작 (긍정 의도면 추가 텍스트 무시)
|
||||
for w in APPROVE_WORDS:
|
||||
if t.startswith(w + " ") or t == w:
|
||||
return ("approve", None)
|
||||
return _llm_classify(text)
|
||||
|
||||
|
||||
def _llm_classify(text: str) -> tuple[str, str | None]:
|
||||
api_key = _get_api_key()
|
||||
if not api_key:
|
||||
return ("unclear", None)
|
||||
prompt = (
|
||||
"사용자 응답을 분류하세요. JSON으로만 응답.\n"
|
||||
f'응답: "{text}"\n\n'
|
||||
'출력: {"intent":"approve|reject|unclear","feedback":"반려면 수정 방향, 아니면 빈 문자열"}'
|
||||
)
|
||||
try:
|
||||
resp = httpx.post(
|
||||
"https://api.anthropic.com/v1/messages",
|
||||
headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"},
|
||||
json={"model": _get_model(), "max_tokens": 200,
|
||||
"messages": [{"role": "user", "content": prompt}]},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
text_out = resp.json()["content"][0]["text"]
|
||||
start = text_out.find("{")
|
||||
end = text_out.rfind("}") + 1
|
||||
if start < 0 or end <= start:
|
||||
return ("unclear", None)
|
||||
data = json.loads(text_out[start:end])
|
||||
return (data.get("intent", "unclear"), data.get("feedback") or None)
|
||||
except (httpx.HTTPError, httpx.TimeoutException, KeyError, ValueError, json.JSONDecodeError) as e:
|
||||
logger.warning("LLM 분류 실패: %s", e)
|
||||
return ("unclear", None)
|
||||
108
agent-office/app/agents/youtube_publisher.py
Normal file
108
agent-office/app/agents/youtube_publisher.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""텔레그램 단일 채널로 단계별 승인 인터랙션 오케스트레이션."""
|
||||
import logging
|
||||
|
||||
from .base import BaseAgent
|
||||
from . import classify_intent
|
||||
from .. import service_proxy
|
||||
from ..db import add_log
|
||||
from ..telegram.messaging import send_raw
|
||||
|
||||
logger = logging.getLogger("agent-office.youtube_publisher")
|
||||
|
||||
|
||||
# Maps a pipeline *_pending state to (human-readable step label for the
# Telegram message, step key used by the music-lab feedback API).
_STEP_TITLES = {
    "cover_pending": ("커버 아트", "cover"),
    "video_pending": ("영상 비주얼", "video"),
    "thumb_pending": ("썸네일", "thumb"),
    "meta_pending": ("메타데이터", "meta"),
    "publish_pending": ("최종 검토 + 발행", "publish"),
}
|
||||
|
||||
|
||||
class YoutubePublisherAgent(BaseAgent):
    """Drives the step-by-step YouTube publishing approval flow over Telegram.

    Polls music-lab for pipelines sitting in a *_pending state, sends one
    notification per (pipeline, state) pair, and forwards the user's reply
    (approve/reject + optional feedback) back to music-lab.
    """

    agent_id = "youtube_publisher"
    display_name = "YouTube 퍼블리셔"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # pipeline id -> last state we already notified about (dedupes sends).
        # NOTE(review): entries are never pruned for finished pipelines, so
        # this grows for the process lifetime — confirm acceptable.
        self._notified_state_per_pipeline: dict[int, str] = {}

    async def poll_state_changes(self) -> None:
        """Called periodically; sends a Telegram message when a pipeline newly
        enters one of the *_pending states."""
        try:
            pipelines = await service_proxy.list_active_pipelines()
        except Exception as e:
            # Best-effort polling: log and retry on the next tick.
            logger.warning("폴링 실패: %s", e)
            return

        for p in pipelines:
            state = p.get("state")
            pid = p.get("id")
            if pid is None:
                continue
            # Only notify for states we haven't announced for this pipeline.
            if state in _STEP_TITLES and self._notified_state_per_pipeline.get(pid) != state:
                await self._notify_step(p)
                self._notified_state_per_pipeline[pid] = state

    async def _notify_step(self, pipeline: dict) -> None:
        """Send the review prompt for the pipeline's current step and persist
        the Telegram message id so replies can be routed back here."""
        state = pipeline["state"]
        title_name, step = _STEP_TITLES[state]
        body = self._format_body(pipeline, step)
        track_title = pipeline.get("track_title") or f"Pipeline #{pipeline['id']}"
        text = (
            f"🎵 [{track_title}] {title_name} 검토\n\n"
            f"{body}\n\n"
            f"➡️ 답장으로 알려주세요: '승인' 또는 '반려 + 수정 방향'"
        )
        sent = await send_raw(text=text)
        if sent.get("ok"):
            msg_id = sent.get("message_id")
            try:
                # Save message id -> (pipeline, step) so a reply can be routed.
                await service_proxy.save_pipeline_telegram_msg(pipeline["id"], step, msg_id)
            except Exception as e:
                # Non-fatal: the notification itself already went out.
                logger.warning("telegram-msg 저장 실패: %s", e)
        add_log(self.agent_id, f"pipeline {pipeline['id']} {step} 알림 전송", "info")

    def _format_body(self, p: dict, step: str) -> str:
        """Render the step-specific summary section of the Telegram message.

        Returns "" for unknown step keys.
        """
        if step == "cover":
            return f"🖼️ 커버: {p.get('cover_url', '-')}"
        if step == "video":
            return f"🎬 영상: {p.get('video_url', '-')}"
        if step == "thumb":
            return f"🎴 썸네일: {p.get('thumbnail_url', '-')}"
        if step == "meta":
            m = p.get("metadata", {}) or {}
            tags = m.get("tags", []) or []
            description = (m.get("description", "") or "")
            return (
                f"📝 제목: {m.get('title', '')}\n"
                f"🏷️ 태그: {', '.join(tags[:8])}\n"
                f"📄 설명(앞부분): {description[:200]}"
            )
        if step == "publish":
            r = p.get("review", {}) or {}
            return (
                f"AI 검토 결과: {r.get('verdict', '?')} "
                f"(가중 {r.get('weighted_total', '?')}/100)\n"
                f"{r.get('summary', '')}"
            )
        return ""

    async def on_telegram_reply(self, pipeline_id: int, step: str, user_text: str) -> None:
        """Classify the user's reply and forward the decision to music-lab;
        re-prompts the user when the reply is unclear."""
        intent, feedback = classify_intent.classify(user_text)
        if intent == "unclear":
            await send_raw("다시 입력해주세요. 예: '승인' 또는 '반려, 제목 짧게'")
            return
        try:
            await service_proxy.post_pipeline_feedback(pipeline_id, step, intent, feedback)
        except Exception as e:
            await send_raw(f"⚠️ 처리 실패: {e}")

    async def on_schedule(self) -> None:
        """Scheduler hook — delegates to the polling loop."""
        await self.poll_state_changes()

    async def on_command(self, command: str, params: dict) -> dict:
        """No explicit commands are supported by this agent."""
        return {"ok": False, "message": f"Unknown command: {command}"}

    async def on_approval(self, task_id: str, approved: bool, feedback: str = "") -> None:
        """Approvals arrive via on_telegram_reply; the base hook is a no-op."""
        pass
|
||||
@@ -34,6 +34,11 @@ async def _send_youtube_weekly_report():
|
||||
if agent:
|
||||
await agent.send_weekly_report()
|
||||
|
||||
async def _poll_pipelines():
    """Scheduler tick: let the YouTube publisher agent poll pipeline states."""
    publisher = AGENT_REGISTRY.get("youtube_publisher")
    if not publisher:
        return
    await publisher.poll_state_changes()
|
||||
|
||||
def init_scheduler():
|
||||
scheduler.add_job(_run_stock_schedule, "cron", hour=7, minute=30, id="stock_news")
|
||||
scheduler.add_job(_run_blog_schedule, "cron", hour=10, minute=0, id="blog_pipeline")
|
||||
@@ -41,4 +46,5 @@ def init_scheduler():
|
||||
scheduler.add_job(_run_youtube_research, "cron", hour=9, minute=0, id="youtube_research")
|
||||
scheduler.add_job(_send_youtube_weekly_report, "cron", day_of_week="mon", hour=8, minute=0, id="youtube_weekly_report")
|
||||
scheduler.add_job(_check_idle_breaks, "interval", seconds=60, id="idle_check")
|
||||
scheduler.add_job(_poll_pipelines, "interval", seconds=30, id="pipeline_poll")
|
||||
scheduler.start()
|
||||
|
||||
@@ -178,3 +178,46 @@ async def lotto_save_briefing(payload: dict) -> Dict[str, Any]:
|
||||
resp = await _client.post(f"{LOTTO_BACKEND_URL}/api/lotto/briefing", json=payload)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
# --- music-lab pipeline (YouTube publisher orchestration) ---
|
||||
|
||||
async def list_active_pipelines() -> list[dict]:
    """Return music-lab pipelines that are still in an active state."""
    async with httpx.AsyncClient(timeout=15) as client:
        response = await client.get(f"{MUSIC_LAB_URL}/api/music/pipeline?status=active")
        response.raise_for_status()
        payload = response.json()
    return payload.get("pipelines", [])
|
||||
|
||||
|
||||
async def get_pipeline(pid: int) -> dict:
    """Fetch a single pipeline by id from music-lab."""
    async with httpx.AsyncClient(timeout=15) as client:
        response = await client.get(f"{MUSIC_LAB_URL}/api/music/pipeline/{pid}")
        response.raise_for_status()
        payload = response.json()
    return payload
|
||||
|
||||
|
||||
async def post_pipeline_feedback(pid: int, step: str, intent: str,
                                 feedback_text: Optional[str] = None) -> dict:
    """Forward a user's approve/reject decision for one pipeline step."""
    payload = {"step": step, "intent": intent, "feedback_text": feedback_text}
    async with httpx.AsyncClient(timeout=15) as client:
        response = await client.post(
            f"{MUSIC_LAB_URL}/api/music/pipeline/{pid}/feedback",
            json=payload,
        )
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
async def save_pipeline_telegram_msg(pid: int, step: str, msg_id: int) -> None:
    """Persist the Telegram message id for a pipeline step in music-lab.

    Raises httpx.HTTPStatusError on a non-2xx response. Previously the
    response status was silently ignored, so a failed save looked successful;
    every sibling helper in this module calls raise_for_status, and the caller
    (YoutubePublisherAgent._notify_step) already wraps this in try/except.
    """
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.patch(
            f"{MUSIC_LAB_URL}/api/music/pipeline/{pid}/telegram-msg",
            json={"step": step, "message_id": msg_id},
        )
        resp.raise_for_status()
|
||||
|
||||
|
||||
async def lookup_pipeline_by_msg(msg_id: int) -> Optional[dict]:
    """Resolve a Telegram message id to its {pipeline_id, step} link.

    Returns None when the message is not linked to a pipeline (404).
    Other error statuses now raise instead of being silently treated as
    "not found" — previously a 5xx from music-lab was indistinguishable
    from a genuine miss, hiding outages from the caller's error handling.
    """
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.get(f"{MUSIC_LAB_URL}/api/music/pipeline/lookup-by-msg/{msg_id}")
    if resp.status_code == 404:
        return None
    resp.raise_for_status()
    return resp.json()
|
||||
|
||||
@@ -103,6 +103,34 @@ def _build_messages(history: list, user_text: str) -> list:
|
||||
return msgs
|
||||
|
||||
|
||||
async def maybe_route_to_pipeline(message: dict) -> bool:
    """Route a Telegram message to the youtube_publisher agent when it is a
    reply to one of the pipeline notification messages.

    Returns True if message was routed (caller should stop further processing).
    """
    reply_to = message.get("reply_to_message") or {}
    msg_id = reply_to.get("message_id")
    if not msg_id:
        return False
    # Function-scope import — presumably avoids a circular import; confirm.
    from .. import service_proxy
    try:
        link = await service_proxy.lookup_pipeline_by_msg(msg_id)
    except Exception:
        # Best-effort: lookup failures fall through to normal handling.
        return False
    if not link:
        return False
    from ..agents import AGENT_REGISTRY
    agent = AGENT_REGISTRY.get("youtube_publisher")
    if not agent:
        return False
    pipeline_id = link.get("pipeline_id")
    step = link.get("step")
    if pipeline_id is None or not step:
        return False
    await agent.on_telegram_reply(pipeline_id, step, message.get("text", ""))
    return True
|
||||
|
||||
|
||||
async def respond_to_message(chat_id: str, user_text: str) -> Optional[str]:
|
||||
"""자연어 메시지에 응답. 실패 시 사용자에게 돌려줄 문자열 반환(또는 None = 무시)."""
|
||||
if not ANTHROPIC_API_KEY:
|
||||
|
||||
@@ -102,6 +102,11 @@ async def _handle_message(message: dict, agent_dispatcher) -> Optional[dict]:
|
||||
from .router import parse_command, resolve_agent_command, HELP_TEXT
|
||||
from .messaging import send_raw, send_agent_message
|
||||
from .agent_registry import AGENT_META
|
||||
from .conversational import maybe_route_to_pipeline
|
||||
|
||||
# 파이프라인 메시지에 대한 reply라면 youtube_publisher 로 라우팅
|
||||
if await maybe_route_to_pipeline(message):
|
||||
return {"handled": "pipeline_reply"}
|
||||
|
||||
text = message.get("text", "")
|
||||
parsed = parse_command(text)
|
||||
|
||||
@@ -3,5 +3,6 @@ uvicorn[standard]==0.30.6
|
||||
apscheduler==3.10.4
|
||||
websockets>=12.0
|
||||
httpx>=0.27
|
||||
respx>=0.21
|
||||
google-api-python-client>=2.100.0
|
||||
pytrends>=4.9.2
|
||||
|
||||
48
agent-office/tests/test_classify_intent.py
Normal file
48
agent-office/tests/test_classify_intent.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
import respx
|
||||
from httpx import Response
|
||||
from app.agents import classify_intent as ci
|
||||
|
||||
|
||||
def test_clear_approve_no_llm(monkeypatch):
    """Whitelisted approval words short-circuit before the LLM fallback."""
    calls = []

    def _spy(text):
        calls.append(text)
        return ("unclear", None)

    monkeypatch.setattr(ci, "_llm_classify", _spy)
    for word in ("승인", "OK", "진행", "agree"):
        assert ci.classify(word) == ("approve", None)
    assert calls == []


def test_clear_reject_only_no_llm(monkeypatch):
    """Bare rejection words resolve locally with no feedback text."""
    monkeypatch.setattr(ci, "_llm_classify", lambda t: ("unclear", None))
    for word in ("반려", "거절"):
        assert ci.classify(word) == ("reject", None)


def test_reject_with_text_split(monkeypatch):
    """A rejection word plus extra text yields that text as feedback."""
    monkeypatch.setattr(ci, "_llm_classify", lambda t: ("unclear", None))
    intent, feedback = ci.classify("반려, 제목 짧게")
    assert intent == "reject"
    assert "제목 짧게" in feedback


@respx.mock
def test_ambiguous_calls_llm(monkeypatch):
    """Ambiguous input falls through to the Anthropic API and uses its verdict."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "k")
    payload = {"content": [{"type": "text",
                            "text": '{"intent":"reject","feedback":"좀 더 화려하게"}'}]}
    respx.post("https://api.anthropic.com/v1/messages").mock(
        return_value=Response(200, json=payload)
    )
    intent, feedback = ci.classify("음... 좀 더 화려한 분위기가 좋겠어")
    assert intent == "reject"
    assert "화려하게" in feedback


def test_empty_text_returns_unclear():
    """Empty or missing input never leaves the process."""
    assert ci.classify("") == ("unclear", None)
    assert ci.classify(None) == ("unclear", None)
|
||||
110
agent-office/tests/test_pipeline_polling.py
Normal file
110
agent-office/tests/test_pipeline_polling.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import os
import sys
import tempfile

# Reserve a unique temp path for the SQLite file, then delete it so the app
# creates it fresh. The env var is set before any app import — presumably
# app.db reads AGENT_OFFICE_DB_PATH at import time; keep this ordering.
_fd, _TMP = tempfile.mkstemp(suffix=".db")
os.close(_fd)
os.unlink(_TMP)
os.environ["AGENT_OFFICE_DB_PATH"] = _TMP

# Make the project package importable regardless of pytest's working dir.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _init_db():
    """Give every test a freshly initialized database file."""
    import gc
    # gc.collect() — presumably to finalize lingering sqlite connections
    # before deleting the file (matters on Windows); confirm.
    gc.collect()
    if os.path.exists(_TMP):
        os.remove(_TMP)
    from app.db import init_db
    init_db()
    yield
    gc.collect()
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_poll_notifies_once_per_state():
    """A pipeline in a *_pending state triggers exactly one notification."""
    from app.agents.youtube_publisher import YoutubePublisherAgent

    pipelines = [{
        "id": 1,
        "state": "cover_pending",
        "cover_url": "/x.jpg",
        "track_title": "Test",
    }]
    with patch(
        "app.agents.youtube_publisher.service_proxy.list_active_pipelines",
        new=AsyncMock(return_value=pipelines),
    ), patch(
        "app.agents.youtube_publisher.send_raw",
        new=AsyncMock(return_value={"ok": True, "message_id": 99}),
    ) as mock_send, patch(
        "app.agents.youtube_publisher.service_proxy.save_pipeline_telegram_msg",
        new=AsyncMock(),
    ):
        a = YoutubePublisherAgent()
        await a.poll_state_changes()
        await a.poll_state_changes()  # same state — second poll must not re-notify
        assert mock_send.call_count == 1


@pytest.mark.asyncio
async def test_on_telegram_reply_approve_calls_feedback():
    """An approval reply is forwarded to music-lab with no feedback text."""
    from app.agents.youtube_publisher import YoutubePublisherAgent

    with patch(
        "app.agents.youtube_publisher.service_proxy.post_pipeline_feedback",
        new=AsyncMock(),
    ) as mock_fb, patch(
        "app.agents.youtube_publisher.send_raw",
        new=AsyncMock(),
    ):
        a = YoutubePublisherAgent()
        await a.on_telegram_reply(pipeline_id=42, step="cover", user_text="승인")
        mock_fb.assert_called_once_with(42, "cover", "approve", None)


@pytest.mark.asyncio
async def test_on_telegram_reply_reject_with_feedback():
    """A rejection reply forwards the trailing text as feedback."""
    from app.agents.youtube_publisher import YoutubePublisherAgent

    with patch(
        "app.agents.youtube_publisher.service_proxy.post_pipeline_feedback",
        new=AsyncMock(),
    ) as mock_fb, patch(
        "app.agents.youtube_publisher.send_raw",
        new=AsyncMock(),
    ):
        a = YoutubePublisherAgent()
        await a.on_telegram_reply(pipeline_id=43, step="meta", user_text="반려, 제목 짧게")
        args = mock_fb.call_args[0]
        assert args[0] == 43
        assert args[1] == "meta"
        assert args[2] == "reject"
        assert "제목 짧게" in (args[3] or "")


@pytest.mark.asyncio
async def test_on_telegram_reply_unclear_asks_again():
    """An unclassifiable reply makes the agent re-prompt instead of posting."""
    from app.agents.youtube_publisher import YoutubePublisherAgent

    sent = []

    async def mock_send(text=None, **kw):
        sent.append(text)
        return {"ok": True, "message_id": 1}

    with patch(
        "app.agents.youtube_publisher.send_raw",
        new=mock_send,
    ), patch(
        "app.agents.youtube_publisher.classify_intent.classify",
        return_value=("unclear", None),
    ):
        a = YoutubePublisherAgent()
        await a.on_telegram_reply(pipeline_id=44, step="cover", user_text="huh?")
        assert any("다시 입력" in (s or "") for s in sent)
|
||||
@@ -66,6 +66,12 @@ services:
|
||||
- CORS_ALLOW_ORIGINS=${CORS_ALLOW_ORIGINS:-http://localhost:3007,http://localhost:8080}
|
||||
- PEXELS_API_KEY=${PEXELS_API_KEY:-}
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- YOUTUBE_OAUTH_CLIENT_ID=${YOUTUBE_OAUTH_CLIENT_ID:-}
|
||||
- YOUTUBE_OAUTH_CLIENT_SECRET=${YOUTUBE_OAUTH_CLIENT_SECRET:-}
|
||||
- YOUTUBE_OAUTH_REDIRECT_URI=${YOUTUBE_OAUTH_REDIRECT_URI:-}
|
||||
- CLAUDE_HAIKU_MODEL=${CLAUDE_HAIKU_MODEL:-claude-haiku-4-5-20251001}
|
||||
- CLAUDE_SONNET_MODEL=${CLAUDE_SONNET_MODEL:-claude-sonnet-4-6}
|
||||
- VIDEO_DATA_DIR=${VIDEO_DATA_DIR:-/app/data/videos}
|
||||
volumes:
|
||||
- ${RUNTIME_PATH}/data/music:/app/data
|
||||
@@ -137,6 +143,8 @@ services:
|
||||
- TELEGRAM_WEBHOOK_URL=${TELEGRAM_WEBHOOK_URL:-}
|
||||
- TELEGRAM_WIFE_CHAT_ID=${TELEGRAM_WIFE_CHAT_ID:-}
|
||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-}
|
||||
- CLAUDE_HAIKU_MODEL=${CLAUDE_HAIKU_MODEL:-claude-haiku-4-5-20251001}
|
||||
- CLAUDE_SONNET_MODEL=${CLAUDE_SONNET_MODEL:-claude-sonnet-4-6}
|
||||
- LOTTO_BACKEND_URL=${LOTTO_BACKEND_URL:-http://lotto:8000}
|
||||
- LOTTO_CURATOR_MODEL=${LOTTO_CURATOR_MODEL:-claude-sonnet-4-5}
|
||||
- CONVERSATION_MODEL=${CONVERSATION_MODEL:-claude-haiku-4-5-20251001}
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
FROM python:3.12-alpine
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
RUN apk add --no-cache ffmpeg
|
||||
# ffmpeg for audio/video processing, ttf-dejavu + fontconfig for PIL overlays.
|
||||
# Alpine installs DejaVu fonts to /usr/share/fonts/dejavu/, but app code
|
||||
# references the Debian-style path; symlink for compatibility.
|
||||
RUN apk add --no-cache ffmpeg ttf-dejavu fontconfig \
|
||||
&& mkdir -p /usr/share/fonts/truetype \
|
||||
&& ln -sf /usr/share/fonts/dejavu /usr/share/fonts/truetype/dejavu \
|
||||
&& fc-cache -f
|
||||
|
||||
WORKDIR /app
|
||||
COPY requirements.txt .
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sqlite3
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
DB_PATH = "/app/data/music.db"
|
||||
@@ -184,6 +185,112 @@ def init_db() -> None:
|
||||
)
|
||||
""")
|
||||
|
||||
# ── YouTube pipeline 테이블 (5개) ─────────────────────────────────
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS video_pipelines (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
track_id INTEGER NOT NULL,
|
||||
state TEXT NOT NULL DEFAULT 'created',
|
||||
state_started_at TEXT NOT NULL,
|
||||
cover_url TEXT,
|
||||
video_url TEXT,
|
||||
thumbnail_url TEXT,
|
||||
metadata_json TEXT,
|
||||
review_json TEXT,
|
||||
youtube_video_id TEXT,
|
||||
feedback_count_per_step TEXT NOT NULL DEFAULT '{}',
|
||||
last_telegram_msg_ids TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
cancelled_at TEXT,
|
||||
failed_reason TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS pipeline_jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
pipeline_id INTEGER NOT NULL,
|
||||
step TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
error TEXT,
|
||||
started_at TEXT,
|
||||
finished_at TEXT,
|
||||
duration_ms INTEGER,
|
||||
FOREIGN KEY (pipeline_id) REFERENCES video_pipelines(id)
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS pipeline_feedback (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
pipeline_id INTEGER NOT NULL,
|
||||
step TEXT NOT NULL,
|
||||
feedback_text TEXT NOT NULL,
|
||||
received_at TEXT NOT NULL,
|
||||
FOREIGN KEY (pipeline_id) REFERENCES video_pipelines(id)
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS youtube_oauth_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id TEXT NOT NULL,
|
||||
channel_title TEXT,
|
||||
avatar_url TEXT,
|
||||
refresh_token TEXT NOT NULL,
|
||||
access_token TEXT,
|
||||
expires_at TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS youtube_setup (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT CHECK (id = 1),
|
||||
metadata_template_json TEXT NOT NULL,
|
||||
cover_prompts_json TEXT NOT NULL,
|
||||
review_weights_json TEXT NOT NULL,
|
||||
review_threshold INTEGER NOT NULL DEFAULT 60,
|
||||
visual_defaults_json TEXT NOT NULL,
|
||||
publish_policy_json TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# 기본 setup 1행 보장
|
||||
cnt_row = conn.execute("SELECT COUNT(*) AS c FROM youtube_setup").fetchone()
|
||||
if cnt_row["c"] == 0:
|
||||
_seed_default_youtube_setup(conn)
|
||||
|
||||
|
||||
def _seed_default_youtube_setup(conn: sqlite3.Connection) -> None:
|
||||
"""youtube_setup 테이블에 기본 1행을 삽입한다.
|
||||
|
||||
init_db()와 get_youtube_setup() (행이 사라진 경우 self-heal)가 공유한다.
|
||||
"""
|
||||
defaults = (
|
||||
json.dumps({
|
||||
"title": "[{genre}] {title} | {bpm}BPM",
|
||||
"description": "{title}\n\n장르: {genre}\nBPM: {bpm}\nKey: {key}\n",
|
||||
"tags": ["lofi", "chill", "instrumental"],
|
||||
"category_id": 10,
|
||||
}),
|
||||
json.dumps({
|
||||
"lo-fi": "moody anime cityscape at dusk, lofi aesthetic",
|
||||
"phonk": "dark drift car aesthetic, neon, phonk vibe",
|
||||
"ambient": "ethereal mountain landscape, ambient mood",
|
||||
"default": "abstract music album cover art",
|
||||
}),
|
||||
json.dumps({"meta": 25, "policy": 30, "viewer": 25, "trend": 20}),
|
||||
60,
|
||||
json.dumps({"resolution": "1920x1080", "style": "visualizer", "background": "ai_cover"}),
|
||||
json.dumps({"mode": "manual", "privacy": "private", "schedule_time": None}),
|
||||
datetime.utcnow().isoformat(timespec="seconds"),
|
||||
)
|
||||
conn.execute("""
|
||||
INSERT INTO youtube_setup
|
||||
(metadata_template_json, cover_prompts_json, review_weights_json,
|
||||
review_threshold, visual_defaults_json, publish_policy_json, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", defaults)
|
||||
|
||||
|
||||
# ── music_tasks CRUD ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -791,3 +898,253 @@ def update_compile_job(job_id: int, **kwargs) -> None:
|
||||
def delete_compile_job(job_id: int) -> None:
    """Delete a compile_jobs row by id (no-op when the id does not exist)."""
    with _conn() as conn:
        conn.execute("DELETE FROM compile_jobs WHERE id = ?", (job_id,))
|
||||
|
||||
|
||||
# ── YouTube pipeline helpers ──────────────────────────────────────────────────

# update_pipeline_state: state/state_started_at/updated_at are set
# automatically; any other column passed as a kwarg must be in this whitelist.
_PIPELINE_STATE_EXTRA_COLS = frozenset({
    "cover_url",
    "video_url",
    "thumbnail_url",
    "metadata_json",
    "review_json",
    "youtube_video_id",
    "cancelled_at",
    "failed_reason",
    "last_telegram_msg_ids",
    "feedback_count_per_step",
})

# Whitelist of columns update_pipeline_job may modify.
_PIPELINE_JOB_COLS = frozenset({
    "status",
    "error",
    "duration_ms",
    "started_at",
    "finished_at",
})
|
||||
|
||||
|
||||
def _now() -> str:
|
||||
return datetime.utcnow().isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def _parse_pipeline_row(row: sqlite3.Row) -> Dict[str, Any]:
|
||||
"""video_pipelines의 sqlite3.Row를 dict로 파싱.
|
||||
|
||||
JSON 컬럼을 디코드하고, metadata/review를 호환을 위해 추가로 노출한다.
|
||||
get_pipeline / list_pipelines가 공유.
|
||||
"""
|
||||
d = dict(row)
|
||||
d["feedback_count_per_step"] = json.loads(d["feedback_count_per_step"] or "{}")
|
||||
d["last_telegram_msg_ids"] = json.loads(d["last_telegram_msg_ids"] or "{}")
|
||||
if d.get("metadata_json"):
|
||||
d["metadata"] = json.loads(d["metadata_json"])
|
||||
if d.get("review_json"):
|
||||
d["review"] = json.loads(d["review_json"])
|
||||
return d
|
||||
|
||||
|
||||
def create_pipeline(track_id: int) -> int:
    """Insert a new pipeline row in state 'created' and return its id."""
    ts = _now()
    with _conn() as conn:
        cursor = conn.execute("""
            INSERT INTO video_pipelines (track_id, state, state_started_at, created_at, updated_at)
            VALUES (?, 'created', ?, ?, ?)
        """, (track_id, ts, ts, ts))
        return cursor.lastrowid
|
||||
|
||||
|
||||
def get_pipeline(pid: int) -> Optional[Dict[str, Any]]:
    """Load one pipeline (joined with its track title); None when absent."""
    with _conn() as conn:
        row = conn.execute("""
            SELECT vp.*, ml.title AS track_title
            FROM video_pipelines vp
            LEFT JOIN music_library ml ON ml.id = vp.track_id
            WHERE vp.id = ?
        """, (pid,)).fetchone()
        return _parse_pipeline_row(row) if row else None
|
||||
|
||||
|
||||
def update_pipeline_state(pid: int, state: str, **fields) -> None:
    """Set a pipeline's state and optionally extra columns in one UPDATE.

    state/state_started_at/updated_at are stamped automatically; any kwarg
    not in _PIPELINE_STATE_EXTRA_COLS raises ValueError.
    """
    unknown = set(fields) - _PIPELINE_STATE_EXTRA_COLS
    if unknown:
        raise ValueError(f"unknown columns for update_pipeline_state: {sorted(unknown)}")

    now = _now()
    assignments = ["state = ?", "state_started_at = ?", "updated_at = ?"]
    assignments.extend(f"{name} = ?" for name in fields)
    params: List[Any] = [state, now, now, *fields.values(), pid]
    with _conn() as conn:
        conn.execute(f"UPDATE video_pipelines SET {', '.join(assignments)} WHERE id = ?", params)
|
||||
|
||||
|
||||
def list_pipelines(active_only: bool = False) -> List[Dict[str, Any]]:
    """List pipelines newest-first; active_only filters out terminal states."""
    query = """
    SELECT vp.*, ml.title AS track_title
    FROM video_pipelines vp
    LEFT JOIN music_library ml ON ml.id = vp.track_id
    """
    if active_only:
        query += " WHERE vp.state NOT IN ('published','cancelled','failed','awaiting_manual')"
    query += " ORDER BY vp.created_at DESC"
    with _conn() as conn:
        return [_parse_pipeline_row(r) for r in conn.execute(query).fetchall()]
|
||||
|
||||
|
||||
def increment_feedback_count(pid: int, step: str) -> int:
    """Atomically increment feedback_count_per_step.<step> and return the new value.

    Uses the json1 extension (SQLite 3.38+) so the read-modify-write happens
    inside a single UPDATE, removing the race of a Python-side RMW.
    """
    now = _now()
    with _conn() as conn:
        conn.execute(
            """
            UPDATE video_pipelines
            SET feedback_count_per_step = json_set(
                feedback_count_per_step,
                '$.' || ?,
                COALESCE(json_extract(feedback_count_per_step, '$.' || ?), 0) + 1
            ),
            updated_at = ?
            WHERE id = ?
            """,
            (step, step, now, pid),
        )
        # Read back the counter; 0 when the pipeline row does not exist.
        row = conn.execute(
            "SELECT json_extract(feedback_count_per_step, '$.' || ?) AS c "
            "FROM video_pipelines WHERE id = ?",
            (step, pid),
        ).fetchone()
        return int(row["c"]) if row and row["c"] is not None else 0
|
||||
|
||||
|
||||
def record_feedback(pid: int, step: str, feedback_text: str) -> None:
    """Append one feedback entry for a pipeline step."""
    entry = (pid, step, feedback_text, _now())
    with _conn() as conn:
        conn.execute("""
            INSERT INTO pipeline_feedback (pipeline_id, step, feedback_text, received_at)
            VALUES (?, ?, ?, ?)
        """, entry)


def get_feedback_history(pid: int) -> List[Dict[str, Any]]:
    """All feedback rows for a pipeline, newest first."""
    with _conn() as conn:
        rows = conn.execute("""
            SELECT * FROM pipeline_feedback
            WHERE pipeline_id = ? ORDER BY id DESC
        """, (pid,)).fetchall()
    return [dict(r) for r in rows]


def create_pipeline_job(pid: int, step: str) -> int:
    """Insert a queued job row for a pipeline step and return its id."""
    with _conn() as conn:
        inserted = conn.execute("""
            INSERT INTO pipeline_jobs (pipeline_id, step, status, started_at)
            VALUES (?, ?, 'queued', ?)
        """, (pid, step, _now()))
        return inserted.lastrowid
|
||||
|
||||
|
||||
def update_pipeline_job(job_id: int, **fields) -> None:
    """Update a pipeline_jobs row; keys outside _PIPELINE_JOB_COLS raise.

    When status becomes succeeded/failed and the caller did not supply
    finished_at, it is stamped automatically.
    """
    unknown = set(fields) - _PIPELINE_JOB_COLS
    if unknown:
        raise ValueError(f"unknown columns for update_pipeline_job: {sorted(unknown)}")
    if not fields:
        return

    reached_terminal = fields.get("status") in ("succeeded", "failed")
    if reached_terminal and "finished_at" not in fields:
        fields["finished_at"] = _now()

    assignments = ", ".join(f"{name} = ?" for name in fields)
    params = list(fields.values()) + [job_id]
    with _conn() as conn:
        conn.execute(f"UPDATE pipeline_jobs SET {assignments} WHERE id = ?", params)
|
||||
|
||||
|
||||
def list_pipeline_jobs(pid: int) -> List[Dict[str, Any]]:
    """All job rows for a pipeline in creation order."""
    with _conn() as conn:
        rows = conn.execute("""
            SELECT * FROM pipeline_jobs WHERE pipeline_id = ? ORDER BY id ASC
        """, (pid,)).fetchall()
    return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def get_youtube_setup() -> Dict[str, Any]:
    """Return the single youtube_setup row, seeding the default row if missing."""
    with _conn() as conn:
        row = conn.execute("SELECT * FROM youtube_setup WHERE id = 1").fetchone()
        if row is None:
            # Self-heal: the row should always exist after init_db().
            _seed_default_youtube_setup(conn)
            row = conn.execute("SELECT * FROM youtube_setup WHERE id = 1").fetchone()
        setup = dict(row)
        json_cols = ("metadata_template_json", "cover_prompts_json",
                     "review_weights_json", "visual_defaults_json", "publish_policy_json")
        for col in json_cols:
            setup[col.replace("_json", "")] = json.loads(setup[col])
        return setup
|
||||
|
||||
|
||||
def update_youtube_setup(**kwargs) -> None:
    """Persist setup changes to the single youtube_setup row.

    JSON-typed fields are serialized; review_threshold is coerced to int.
    Unknown keys are silently ignored; a no-op call does not touch the row.
    """
    json_cols = {
        "metadata_template": "metadata_template_json",
        "cover_prompts": "cover_prompts_json",
        "review_weights": "review_weights_json",
        "visual_defaults": "visual_defaults_json",
        "publish_policy": "publish_policy_json",
    }
    assignments: List[str] = []
    params: List[Any] = []
    for key, value in kwargs.items():
        if key in json_cols:
            assignments.append(f"{json_cols[key]} = ?")
            params.append(json.dumps(value))
        elif key == "review_threshold":
            assignments.append("review_threshold = ?")
            params.append(int(value))
    if not assignments:
        return
    assignments.append("updated_at = ?")
    params.append(_now())
    with _conn() as conn:
        conn.execute(f"UPDATE youtube_setup SET {', '.join(assignments)} WHERE id = 1", params)
|
||||
|
||||
|
||||
def upsert_oauth_token(channel_id: str, channel_title: Optional[str],
                       avatar_url: Optional[str], refresh_token: str,
                       access_token: Optional[str], expires_at: Optional[str]) -> None:
    """Store the YouTube OAuth token for the (single) connected channel.

    The table is cleared first, so at most one row ever exists.
    """
    insert_sql = """
        INSERT INTO youtube_oauth_tokens
        (channel_id, channel_title, avatar_url, refresh_token, access_token, expires_at, created_at)
        VALUES (?, ?, ?, ?, ?, ?, ?)
    """
    with _conn() as conn:
        conn.execute("DELETE FROM youtube_oauth_tokens")
        conn.execute(insert_sql, (channel_id, channel_title, avatar_url,
                                  refresh_token, access_token, expires_at, _now()))
|
||||
|
||||
|
||||
def get_oauth_token() -> Optional[Dict[str, Any]]:
    """Return the most recent OAuth token row, or None when not connected."""
    query = "SELECT * FROM youtube_oauth_tokens ORDER BY id DESC LIMIT 1"
    with _conn() as conn:
        row = conn.execute(query).fetchone()
        return dict(row) if row else None
|
||||
|
||||
|
||||
def delete_oauth_token() -> None:
    """Disconnect: remove all stored YouTube OAuth tokens."""
    with _conn() as conn:
        conn.execute("DELETE FROM youtube_oauth_tokens")
|
||||
|
||||
@@ -22,9 +22,12 @@ from .db import (
|
||||
create_compile_job, get_compile_jobs, get_compile_job,
|
||||
update_compile_job, delete_compile_job,
|
||||
)
|
||||
from . import db as _db_module
|
||||
from .compiler import run_compile
|
||||
from .market import ingest_trends, get_suggestions
|
||||
from .local_provider import run_local_generation
|
||||
from .pipeline import orchestrator
|
||||
from .pipeline import youtube as yt_module
|
||||
from .suno_provider import (
|
||||
run_suno_generation, run_suno_extend, run_vocal_removal,
|
||||
run_cover_image, run_wav_convert, run_stem_split,
|
||||
@@ -921,3 +924,194 @@ def list_market_reports(limit: int = 10):
|
||||
@app.get("/api/music/market/suggest")
def market_suggest(limit: int = 5):
    """Return up to *limit* market-driven track suggestions."""
    return {"suggestions": get_suggestions(limit)}
|
||||
|
||||
|
||||
# ── Pipeline endpoints ────────────────────────────────────────────────────────
|
||||
|
||||
class PipelineCreate(BaseModel):
    """Request body for creating a publishing pipeline."""
    # id of the track to publish
    track_id: int
|
||||
|
||||
|
||||
class FeedbackRequest(BaseModel):
    """User decision on a pipeline gate step."""
    # the gate step this feedback applies to (cover/video/thumb/meta/publish)
    step: str
    intent: str  # approve | reject
    # optional free-text guidance used when regenerating after a reject
    feedback_text: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/api/music/pipeline", status_code=201)
def create_pipeline(req: PipelineCreate):
    """Create a pipeline for a track; 409 if one is already active for it.

    NOTE(review): the check-then-insert below is not atomic — two concurrent
    requests for the same track could both pass the check. Consider a DB-level
    uniqueness guard if concurrent creation is possible.
    """
    actives = _db_module.list_pipelines(active_only=True)
    if any(p["track_id"] == req.track_id for p in actives):
        raise HTTPException(409, "이미 진행 중인 파이프라인이 있습니다")
    pid = _db_module.create_pipeline(req.track_id)
    return _db_module.get_pipeline(pid)
|
||||
|
||||
|
||||
@app.get("/api/music/pipeline")
def list_pipelines_endpoint(status: str = "all"):
    """List pipelines; pass status=active to restrict to in-flight ones."""
    active_only = status == "active"
    return {"pipelines": _db_module.list_pipelines(active_only=active_only)}
|
||||
|
||||
|
||||
@app.get("/api/music/pipeline/lookup-by-msg/{msg_id}")
def lookup_by_msg(msg_id: int):
    """Resolve which active pipeline/step posted Telegram message *msg_id*."""
    for pipeline in _db_module.list_pipelines(active_only=True):
        hits = [s for s, mid in pipeline["last_telegram_msg_ids"].items() if mid == msg_id]
        if hits:
            return {"pipeline_id": pipeline["id"], "step": hits[0]}
    raise HTTPException(404)
|
||||
|
||||
|
||||
@app.get("/api/music/pipeline/{pid}")
def get_pipeline_endpoint(pid: int):
    """Pipeline detail including its job history and feedback log."""
    pipeline = _db_module.get_pipeline(pid)
    if not pipeline:
        raise HTTPException(404)
    pipeline["jobs"] = _db_module.list_pipeline_jobs(pid)
    pipeline["feedback"] = _db_module.get_feedback_history(pid)
    return pipeline
|
||||
|
||||
|
||||
@app.post("/api/music/pipeline/{pid}/start", status_code=202)
async def start_pipeline(pid: int, bg: BackgroundTasks):
    """Kick off the first step (cover); only valid from the 'created' state."""
    pipeline = _db_module.get_pipeline(pid)
    if not pipeline:
        raise HTTPException(404)
    if pipeline["state"] != "created":
        raise HTTPException(409, f"이미 시작됨 ({pipeline['state']})")
    bg.add_task(orchestrator.run_step, pid, "cover")
    return {"ok": True}
|
||||
|
||||
|
||||
def _state_to_step(state: str) -> Optional[str]:
|
||||
return {
|
||||
"video_pending": "video",
|
||||
"thumb_pending": "thumb",
|
||||
"meta_pending": "meta",
|
||||
"ai_review": "review",
|
||||
"publish_pending": None, # 사용자 명시 발행 호출 필요
|
||||
"publishing": "publish",
|
||||
}.get(state)
|
||||
|
||||
|
||||
@app.post("/api/music/pipeline/{pid}/feedback", status_code=202)
async def feedback(pid: int, req: FeedbackRequest, bg: BackgroundTasks):
    """Apply a user's approve/reject decision at a pipeline gate step.

    Approve advances the state machine and schedules the next step (if any);
    reject re-runs the same step with the feedback text, bounded to 5 retries
    per step before the pipeline is parked for manual intervention.
    """
    p = _db_module.get_pipeline(pid)
    if not p:
        raise HTTPException(404)
    if p["state"] == "awaiting_manual":
        raise HTTPException(409, "수동 개입 대기 중")
    state = p["state"]
    expected = f"{req.step}_pending"
    if state != expected:
        # Idempotent: the pipeline already moved past this gate — ignore.
        return {"ok": True, "skipped": True}

    if req.intent == "approve":
        from .pipeline.state_machine import next_state_on_approve
        next_st = next_state_on_approve(state)
        _db_module.update_pipeline_state(pid, next_st)
        next_step = _state_to_step(next_st)
        # publish_pending maps to None: publishing requires an explicit call.
        if next_step:
            bg.add_task(orchestrator.run_step, pid, next_step)
        return {"ok": True}

    elif req.intent == "reject":
        count = _db_module.increment_feedback_count(pid, req.step)
        if count > 5:
            # Regeneration budget exhausted — hand over to a human.
            _db_module.update_pipeline_state(pid, "awaiting_manual")
            raise HTTPException(409, "재생성 한도 초과")
        if req.feedback_text:
            _db_module.record_feedback(pid, req.step, req.feedback_text)
        bg.add_task(orchestrator.run_step, pid, req.step, req.feedback_text or "")
        return {"ok": True}

    else:
        raise HTTPException(400, f"unknown intent: {req.intent}")
|
||||
|
||||
|
||||
@app.post("/api/music/pipeline/{pid}/cancel")
def cancel_pipeline(pid: int):
    """Mark a pipeline cancelled and stamp the cancellation time."""
    pipeline = _db_module.get_pipeline(pid)
    if not pipeline:
        raise HTTPException(404)
    _db_module.update_pipeline_state(pid, "cancelled", cancelled_at=_db_module._now())
    return {"ok": True}
|
||||
|
||||
|
||||
@app.post("/api/music/pipeline/{pid}/publish", status_code=202)
async def publish_pipeline(pid: int, bg: BackgroundTasks):
    """Explicit user-triggered publish; only valid from publish_pending."""
    pipeline = _db_module.get_pipeline(pid)
    if not pipeline:
        raise HTTPException(404)
    if pipeline["state"] != "publish_pending":
        raise HTTPException(409, f"발행 단계 아님 ({pipeline['state']})")
    _db_module.update_pipeline_state(pid, "publishing")
    bg.add_task(orchestrator.run_step, pid, "publish")
    return {"ok": True}
|
||||
|
||||
|
||||
# Telegram 메시지 매칭용 엔드포인트 (agent-office용)
|
||||
|
||||
class TelegramMsgPatch(BaseModel):
    """Body for recording the Telegram message id posted for a step."""
    # step the message belongs to
    step: str
    # Telegram message id used later to match user replies to this step
    message_id: int
|
||||
|
||||
|
||||
@app.patch("/api/music/pipeline/{pid}/telegram-msg")
def save_telegram_msg(pid: int, req: TelegramMsgPatch):
    """Record the Telegram message id sent for a step (used by agent-office).

    NOTE(review): this read-modify-write of last_telegram_msg_ids is not
    atomic; concurrent patches could drop an entry — confirm callers
    serialize these updates.
    """
    p = _db_module.get_pipeline(pid)
    if not p:
        raise HTTPException(404)
    ids = p["last_telegram_msg_ids"]
    ids[req.step] = req.message_id
    # State is re-written unchanged; only the msg-id map is updated.
    _db_module.update_pipeline_state(
        pid, p["state"], last_telegram_msg_ids=json.dumps(ids)
    )
    return {"ok": True}
|
||||
|
||||
|
||||
# ── Setup endpoints ───────────────────────────────────────────────────────────
|
||||
|
||||
class SetupRequest(BaseModel):
    """Partial-update body for the YouTube publishing setup.

    Every field is optional; None means "leave unchanged".
    """
    metadata_template: Optional[Dict[str, Any]] = None
    cover_prompts: Optional[Dict[str, Any]] = None
    review_weights: Optional[Dict[str, Any]] = None
    review_threshold: Optional[int] = None
    visual_defaults: Optional[Dict[str, Any]] = None
    publish_policy: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@app.get("/api/music/setup")
def get_setup():
    """Return the current YouTube publishing setup (JSON fields parsed)."""
    return _db_module.get_youtube_setup()
|
||||
|
||||
|
||||
@app.put("/api/music/setup")
def put_setup(req: SetupRequest):
    """Partially update the setup: only non-None request fields are applied."""
    # NOTE(review): .dict() is the Pydantic v1 API (model_dump() in v2) —
    # keep consistent with the project's pinned Pydantic version.
    payload = {k: v for k, v in req.dict().items() if v is not None}
    _db_module.update_youtube_setup(**payload)
    return _db_module.get_youtube_setup()
|
||||
|
||||
|
||||
# ── YouTube OAuth endpoints ───────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/music/youtube/auth-url")
def youtube_auth_url():
    """Return the OAuth authorization URL used to connect a YouTube channel."""
    return {"url": yt_module.get_auth_url()}
|
||||
|
||||
|
||||
@app.get("/api/music/youtube/callback")
async def youtube_callback(code: str):
    """OAuth redirect target: exchange the authorization *code* for tokens."""
    return await yt_module.exchange_code(code)
|
||||
|
||||
|
||||
@app.post("/api/music/youtube/disconnect")
def youtube_disconnect():
    """Drop the stored YouTube connection."""
    yt_module.disconnect()
    return {"ok": True}
|
||||
|
||||
|
||||
@app.get("/api/music/youtube/status")
def youtube_status():
    """Connection status; {"connected": False} when no channel is linked."""
    return yt_module.get_status() or {"connected": False}
|
||||
|
||||
0
music-lab/app/pipeline/__init__.py
Normal file
0
music-lab/app/pipeline/__init__.py
Normal file
88
music-lab/app/pipeline/cover.py
Normal file
88
music-lab/app/pipeline/cover.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""AI 커버 아트 생성 — DALL·E 3 / gpt-image-1 + 그라데이션 폴백."""
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import httpx
|
||||
from PIL import Image
|
||||
|
||||
from . import storage
|
||||
from .gradient import make_gradient_with_title
|
||||
|
||||
logger = logging.getLogger("music-lab.cover")
|
||||
|
||||
DALLE_TIMEOUT_S = 90
|
||||
|
||||
|
||||
def _get_api_key() -> str:
|
||||
return os.getenv("OPENAI_API_KEY", "")
|
||||
|
||||
|
||||
def _get_model() -> str:
|
||||
return os.getenv("OPENAI_IMAGE_MODEL", "gpt-image-1")
|
||||
|
||||
|
||||
async def generate(*, pipeline_id: int, genre: str, prompt_template: str,
                   mood: str = "", track_title: str = "", feedback: str = "") -> dict:
    """Generate cover art, falling back to a gradient image on any failure.

    Returns {"url": str, "used_fallback": bool, "error": str | None}.
    """
    out_path = os.path.join(storage.pipeline_dir(pipeline_id), "cover.jpg")
    used_fallback = False
    error = None

    api_key = _get_api_key()
    if not api_key:
        used_fallback = True
        error = "OPENAI_API_KEY 미설정"
        make_gradient_with_title(genre, track_title, out_path)
    else:
        try:
            await _generate_with_dalle(prompt_template, mood, feedback, out_path,
                                       api_key=api_key, model=_get_model())
        except (httpx.HTTPError, httpx.TimeoutException, KeyError, ValueError, OSError) as e:
            logger.warning("DALL·E 실패 — 폴백: %s", e)
            used_fallback = True
            error = str(e)
            make_gradient_with_title(genre, track_title, out_path)

    return {
        "url": storage.media_url(pipeline_id, "cover.jpg"),
        "used_fallback": used_fallback,
        "error": error,
    }
|
||||
|
||||
|
||||
async def _generate_with_dalle(prompt_template: str, mood: str,
                               feedback: str, out_path: str,
                               *, api_key: str, model: str) -> None:
    """Call the OpenAI images API and write the result to *out_path* as JPEG.

    Handles both response shapes (hosted URL and inline base64).
    """
    prompt = prompt_template
    if mood:
        prompt += f", {mood} mood"
    if feedback:
        prompt += f". 추가 지시: {feedback}"
    prompt += ", no text, high quality"

    async with httpx.AsyncClient(timeout=DALLE_TIMEOUT_S) as client:
        resp = await client.post(
            "https://api.openai.com/v1/images/generations",
            headers={"Authorization": f"Bearer {api_key}"},
            json={"model": model, "prompt": prompt, "size": "1024x1024", "n": 1},
        )
        resp.raise_for_status()
        payload = resp.json()["data"][0]
        if "url" in payload:
            img_resp = await client.get(payload["url"])
            img_resp.raise_for_status()
            raw = img_resp.content
        elif "b64_json" in payload:
            raw = base64.b64decode(payload["b64_json"])
        else:
            raise ValueError("DALL·E response has neither url nor b64_json")
    # The API may return PNG; normalize to JPEG for the pipeline.
    with Image.open(BytesIO(raw)) as src:
        src.convert("RGB").save(out_path, "JPEG", quality=92)
|
||||
38
music-lab/app/pipeline/gradient.py
Normal file
38
music-lab/app/pipeline/gradient.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""장르별 그라데이션 배경 + 텍스트 오버레이 — cover/video 공용."""
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
# Genre -> (top RGB, bottom RGB) for the vertical gradient background.
GENRE_COLORS = {
    "lo-fi": ((26, 26, 46), (22, 33, 62)),
    "phonk": ((26, 10, 10), (45, 0, 0)),
    "ambient": ((13, 33, 55), (10, 22, 40)),
    "pop": ((26, 10, 46), (45, 27, 78)),
    "default": ((17, 24, 39), (31, 41, 55)),
}


def make_gradient_with_title(genre: str, title: str, out_path: str,
                             size: tuple[int, int] = (1024, 1024),
                             quality: int = 92) -> None:
    """Render a vertical genre-colored gradient, optionally overlay *title*,
    and save it to *out_path* as JPEG.

    Unknown genres fall back to GENRE_COLORS["default"]. Fix: the gradient is
    painted one ImageDraw.line per row instead of a per-pixel Python loop —
    identical pixels, ~width× fewer Python iterations (was >1M for 1024²).
    """
    w, h = size
    top, bot = GENRE_COLORS.get(genre.lower(), GENRE_COLORS["default"])
    with Image.new("RGB", (w, h)) as img:
        draw = ImageDraw.Draw(img)
        for y in range(h):
            t = y / h
            r = int(top[0] + (bot[0] - top[0]) * t)
            g = int(top[1] + (bot[1] - top[1]) * t)
            b = int(top[2] + (bot[2] - top[2]) * t)
            # Fill the whole row in one C-level call (endpoints inclusive).
            draw.line([(0, y), (w - 1, y)], fill=(r, g, b))

        if title:
            try:
                font = ImageFont.truetype(
                    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 64)
            except OSError:
                # Font file absent on this host — degrade to PIL's builtin.
                font = ImageFont.load_default()
            bbox = draw.textbbox((0, 0), title, font=font)
            tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
            draw.text(((w - tw) // 2, (h - th) // 2), title,
                      fill=(255, 255, 255), font=font)

        img.save(out_path, "JPEG", quality=quality)
|
||||
95
music-lab/app/pipeline/metadata.py
Normal file
95
music-lab/app/pipeline/metadata.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""메타데이터 생성 — Claude Haiku + 템플릿 폴백."""
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger("music-lab.metadata")
|
||||
|
||||
CLAUDE_HAIKU_MODEL_DEFAULT = "claude-haiku-4-5-20251001"
|
||||
TIMEOUT_S = 30
|
||||
|
||||
|
||||
def _get_api_key() -> str:
|
||||
return os.getenv("ANTHROPIC_API_KEY", "")
|
||||
|
||||
|
||||
def _get_model() -> str:
|
||||
return os.getenv("CLAUDE_HAIKU_MODEL", CLAUDE_HAIKU_MODEL_DEFAULT)
|
||||
|
||||
|
||||
async def generate(*, track: dict, template: dict, trend_keywords: list[str],
                   feedback: str = "") -> dict:
    """Produce YouTube metadata via the LLM, degrading to template substitution.

    Returns {"title", "description", "tags", "category_id", "used_fallback", "error"}.
    """
    api_key = _get_api_key()
    if not api_key:
        return {**_fallback_template(track, template),
                "used_fallback": True, "error": "no api key"}

    try:
        llm_out = await _call_claude(track, template, trend_keywords, feedback,
                                     api_key=api_key, model=_get_model())
    except (httpx.HTTPError, httpx.TimeoutException, KeyError, ValueError, json.JSONDecodeError) as e:
        logger.warning("메타데이터 LLM 실패 — 폴백: %s", e)
        return {**_fallback_template(track, template),
                "used_fallback": True, "error": str(e)}
    return {**llm_out, "used_fallback": False, "error": None}
|
||||
|
||||
|
||||
def _fallback_template(track: dict, template: dict) -> dict:
|
||||
fmt_vars = {
|
||||
"title": track.get("title", ""),
|
||||
"genre": track.get("genre", ""),
|
||||
"bpm": track.get("bpm", ""),
|
||||
"key": track.get("key", ""),
|
||||
"scale": track.get("scale", ""),
|
||||
}
|
||||
title = template.get("title", "{title}").format(**fmt_vars)
|
||||
description = template.get("description", "{title}").format(**fmt_vars)
|
||||
return {
|
||||
"title": title[:100],
|
||||
"description": description[:5000],
|
||||
"tags": (template.get("tags") or [])[:15],
|
||||
"category_id": template.get("category_id", 10),
|
||||
}
|
||||
|
||||
|
||||
async def _call_claude(track: dict, template: dict, trend_keywords: list[str],
                       feedback: str, *, api_key: str, model: str) -> dict:
    """Ask Claude Haiku for YouTube metadata and return the parsed JSON dict.

    Raises httpx errors on transport/HTTP failure, and ValueError /
    json.JSONDecodeError when the response contains no parseable JSON.
    """
    user_prompt = (
        "다음 트랙의 YouTube 메타데이터를 생성하세요. JSON으로만 응답.\n\n"
        f"트랙: {json.dumps(track, ensure_ascii=False)}\n"
        f"템플릿: {json.dumps(template, ensure_ascii=False)}\n"
        f"트렌드 키워드: {', '.join(trend_keywords)}\n"
    )
    if feedback:
        user_prompt += f"\n사용자 피드백: {feedback}\n"
    user_prompt += (
        '\n출력 JSON: {"title": "60자 이내", "description": "1000자 이내, 3-5문단",'
        ' "tags": ["15개 이내"], "category_id": 10}'
    )

    async with httpx.AsyncClient(timeout=TIMEOUT_S) as client:
        resp = await client.post(
            "https://api.anthropic.com/v1/messages",
            headers={
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01",
                "content-type": "application/json",
            },
            json={
                "model": model,
                "max_tokens": 1024,
                "messages": [{"role": "user", "content": user_prompt}],
            },
        )
        resp.raise_for_status()
        text = resp.json()["content"][0]["text"]
        # Extract the first JSON object from the response text.
        start = text.find("{")
        end = text.rfind("}") + 1
        if start < 0 or end <= start:
            raise ValueError("Claude 응답에 JSON 블록 없음")
        return json.loads(text[start:end])
|
||||
183
music-lab/app/pipeline/orchestrator.py
Normal file
183
music-lab/app/pipeline/orchestrator.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""파이프라인 오케스트레이터 — 단계별 BackgroundTask 등록 및 산출물 → DB 반영."""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
from app import db
|
||||
from . import cover, video, thumb, metadata, review, youtube
|
||||
|
||||
logger = logging.getLogger("music-lab.orchestrator")
|
||||
|
||||
|
||||
async def run_step(pipeline_id: int, step: str, feedback: str = "") -> None:
    """Run one pipeline step and persist the outcome.

    Creates a job row, runs the step handler, then advances the pipeline to
    the handler's next_state. Fix: the pipeline/track lookups now happen
    inside the try block — previously a lookup failure escaped the handler,
    leaving the job stuck in 'running' and the pipeline state untouched.

    Any failure marks both the job (status=failed + error) and the pipeline
    (state=failed + failed_reason).
    """
    job_id = db.create_pipeline_job(pipeline_id, step)
    db.update_pipeline_job(job_id, status="running")

    try:
        p = db.get_pipeline(pipeline_id)
        track = _get_track(p["track_id"])
        if step == "cover":
            result = await _run_cover(p, track, feedback)
        elif step == "video":
            result = await _run_video(p, track)
        elif step == "thumb":
            result = await _run_thumb(p, track, feedback)
        elif step == "meta":
            result = await _run_meta(p, track, feedback)
        elif step == "review":
            result = await _run_review(p, track)
        elif step == "publish":
            result = await _run_publish(p, track)
        else:
            raise ValueError(f"unknown step: {step}")
        db.update_pipeline_job(job_id, status="succeeded")
        db.update_pipeline_state(pipeline_id, result["next_state"], **result.get("fields", {}))
    except Exception as e:
        logger.exception("step %s failed for pipeline %s", step, pipeline_id)
        db.update_pipeline_job(job_id, status="failed", error=str(e))
        db.update_pipeline_state(pipeline_id, "failed", failed_reason=f"{step}: {e}")
|
||||
|
||||
|
||||
def _get_track(track_id: int) -> dict:
    """Resolve a track row via whichever getter the db module exposes,
    falling back to a direct table query. Raises ValueError when not found."""
    fetch = getattr(db, "get_track_by_id", None) or getattr(db, "get_track", None)
    track = fetch(track_id) if fetch else None
    if not track:
        # Last resort: query the table directly (schema probed at runtime).
        track = _fetch_track_fallback(track_id)
    if not track:
        raise ValueError(f"트랙 {track_id} 없음")
    return track
|
||||
|
||||
|
||||
def _fetch_track_fallback(track_id: int) -> dict | None:
    """Direct sqlite lookup used when the db module lacks a track getter.

    Probes the music_library table first, then tracks. Returns None when the
    row is missing, neither table exists, or anything else goes wrong.
    Fix: the connection is now closed in a finally block — previously an
    unexpected exception reached the outer handler with the connection open.
    """
    try:
        conn = sqlite3.connect(db.DB_PATH)
    except Exception as e:
        logger.warning("track fallback fetch 실패: %s", e)
        return None
    try:
        conn.row_factory = sqlite3.Row
        for table in ("music_library", "tracks"):
            try:
                row = conn.execute(
                    f"SELECT * FROM {table} WHERE id = ?", (track_id,)
                ).fetchone()
            except sqlite3.OperationalError:
                continue  # table absent in this schema — try the next one
            if not row:
                continue
            d = dict(row)
            # Parse JSON-encoded list columns when present.
            for k in ("moods", "instruments"):
                if k in d and isinstance(d[k], str):
                    try:
                        d[k] = json.loads(d[k])
                    except (json.JSONDecodeError, TypeError):
                        d[k] = []
            return d
    except Exception as e:
        logger.warning("track fallback fetch 실패: %s", e)
    finally:
        conn.close()
    return None
|
||||
|
||||
|
||||
async def _run_cover(p, track, feedback):
    """Generate cover art for the pipeline's track; gates on cover_pending."""
    prompts = db.get_youtube_setup()["cover_prompts"]
    genre = track.get("genre", "default")
    template = prompts.get(genre.lower(), prompts.get("default", ""))
    result = await cover.generate(
        pipeline_id=p["id"],
        genre=genre,
        prompt_template=template,
        mood=", ".join(track.get("moods", []) or []),
        track_title=track.get("title", ""),
        feedback=feedback,
    )
    return {"next_state": "cover_pending", "fields": {"cover_url": result["url"]}}
|
||||
|
||||
|
||||
async def _run_video(p, track):
    """Render the audio + cover into a video; gates on video_pending."""
    defaults = db.get_youtube_setup()["visual_defaults"]
    result = video.generate(
        pipeline_id=p["id"],
        audio_path=_local_path(track.get("audio_url", "")),
        cover_path=_local_path(p["cover_url"]),
        genre=track.get("genre", "default"),
        duration_sec=track.get("duration_sec", 120),
        resolution=defaults["resolution"],
        style=defaults["style"],
    )
    return {"next_state": "video_pending", "fields": {"video_url": result["url"]}}
|
||||
|
||||
|
||||
async def _run_thumb(p, track, feedback):
    """Extract a thumbnail from the rendered video; gates on thumb_pending.

    *feedback* is accepted for signature parity with the other step handlers;
    the thumbnail generator does not currently use it.
    """
    result = thumb.generate(
        pipeline_id=p["id"],
        video_path=_local_path(p["video_url"]),
        track_title=track.get("title", ""),
        overlay_text=True,
    )
    return {"next_state": "thumb_pending", "fields": {"thumbnail_url": result["url"]}}
|
||||
|
||||
|
||||
async def _run_meta(p, track, feedback):
    """Generate YouTube metadata for the track; gates on meta_pending."""
    setup = db.get_youtube_setup()
    result = await metadata.generate(
        track=track,
        template=setup["metadata_template"],
        trend_keywords=_get_trend_top(),
        feedback=feedback,
    )
    return {"next_state": "meta_pending",
            "fields": {"metadata_json": json.dumps(result, ensure_ascii=False)}}
|
||||
|
||||
|
||||
async def _run_review(p, track):
    """Run the automatic 4-axis AI review, then move to the publish gate.

    NOTE(review): the verdict is stored in review_json but does not gate the
    transition — the pipeline advances to publish_pending regardless of
    pass/fail, leaving the decision to the user at the publish gate. Confirm
    this is intended.
    """
    setup = db.get_youtube_setup()
    meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {}
    result = await review.run_4_axis(
        pipeline=p, track=track,
        video_meta={"length_sec": track.get("duration_sec", 120),
                    "resolution": setup["visual_defaults"]["resolution"]},
        metadata=meta, thumbnail_url=p.get("thumbnail_url", ""),
        trend_top=_get_trend_top(),
        weights=setup["review_weights"], threshold=setup["review_threshold"],
    )
    return {"next_state": "publish_pending",
            "fields": {"review_json": json.dumps(result, ensure_ascii=False)}}
|
||||
|
||||
|
||||
async def _run_publish(p, track):
    """Upload the finished video to YouTube and mark the pipeline published.

    Privacy comes from publish_policy (default 'private').
    """
    setup = db.get_youtube_setup()
    meta = json.loads(p["metadata_json"]) if p.get("metadata_json") else {}
    privacy = setup["publish_policy"].get("privacy", "private")
    result = youtube.upload_video(
        video_path=_local_path(p["video_url"]),
        thumbnail_path=_local_path(p["thumbnail_url"]) if p.get("thumbnail_url") else None,
        metadata=meta, privacy=privacy,
    )
    return {"next_state": "published",
            "fields": {"youtube_video_id": result["video_id"]}}
|
||||
|
||||
|
||||
def _local_path(media_url: str) -> str:
|
||||
""" /media/videos/123/cover.jpg → /app/data/videos/123/cover.jpg """
|
||||
if not media_url:
|
||||
return ""
|
||||
base_media = os.getenv("VIDEO_MEDIA_BASE", "/media/videos")
|
||||
base_data = os.getenv("VIDEO_DATA_DIR", "/app/data/videos")
|
||||
if media_url.startswith(base_media):
|
||||
return media_url.replace(base_media, base_data, 1)
|
||||
# /media/music/abc.mp3 → /app/data/music/abc.mp3
|
||||
return media_url.replace("/media/", "/app/data/", 1)
|
||||
|
||||
|
||||
def _get_trend_top(n: int = 10) -> list[str]:
    """Top-n trending genre keywords from the last 7 days; [] when unavailable."""
    getter = getattr(db, "get_market_trends", None)
    if getter is None:
        return []
    try:
        rows = getter(days=7)
        return [row.get("genre", "") for row in rows[:n] if row.get("genre")]
    except Exception:
        return []  # trends are optional context — best-effort only
|
||||
120
music-lab/app/pipeline/review.py
Normal file
120
music-lab/app/pipeline/review.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""AI 최종 검토 — 4축(메타/정책/시청/트렌드) 가중 평균."""
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger("music-lab.review")
|
||||
|
||||
CLAUDE_SONNET_MODEL_DEFAULT = "claude-sonnet-4-6"
|
||||
TIMEOUT_S = 60
|
||||
|
||||
POLICY_BANNED = {"f-word", "n-word"} # 운영 시 별도 파일로 — 데모용 자리
|
||||
|
||||
|
||||
def _get_api_key() -> str:
|
||||
return os.getenv("ANTHROPIC_API_KEY", "")
|
||||
|
||||
|
||||
def _get_model() -> str:
|
||||
return os.getenv("CLAUDE_SONNET_MODEL", CLAUDE_SONNET_MODEL_DEFAULT)
|
||||
|
||||
|
||||
async def run_4_axis(*, pipeline: dict, track: dict, video_meta: dict,
                     metadata: dict, thumbnail_url: str, trend_top: list[str],
                     weights: dict, threshold: int) -> dict:
    """Score the release on 4 axes (meta/policy/viewer/trend) via the LLM.

    Degrades to a local heuristic when no API key is set or the call fails.
    """
    api_key = _get_api_key()
    if not api_key:
        return _heuristic(metadata, video_meta, track, trend_top, weights, threshold,
                          fallback_reason="no api key")
    try:
        scores = await _call_claude(pipeline, track, video_meta, metadata,
                                    thumbnail_url, trend_top,
                                    api_key=api_key, model=_get_model())
    except (httpx.HTTPError, httpx.TimeoutException, KeyError, ValueError, json.JSONDecodeError) as e:
        logger.warning("검토 LLM 실패 — 휴리스틱: %s", e)
        return _heuristic(metadata, video_meta, track, trend_top, weights, threshold,
                          fallback_reason=str(e))
    return _weighted_verdict(scores, weights, threshold, used_fallback=False)
|
||||
|
||||
|
||||
def _weighted_verdict(scores: dict, weights: dict, threshold: int,
|
||||
used_fallback: bool) -> dict:
|
||||
total = (
|
||||
weights["meta"] / 100 * scores["metadata_quality"]["score"] +
|
||||
weights["policy"] / 100 * scores["policy_compliance"]["score"] +
|
||||
weights["viewer"] / 100 * scores["viewer_experience"]["score"] +
|
||||
weights["trend"] / 100 * scores["trend_alignment"]["score"]
|
||||
)
|
||||
return {
|
||||
**scores,
|
||||
"weighted_total": round(total, 2),
|
||||
"verdict": "pass" if total >= threshold else "fail",
|
||||
"used_fallback": used_fallback,
|
||||
}
|
||||
|
||||
|
||||
async def _call_claude(pipeline, track, video_meta, metadata, thumbnail_url, trend_top,
                       *, api_key: str, model: str):
    """Ask Claude Sonnet for the 4-axis review scores; returns the parsed dict.

    Raises httpx errors on transport/HTTP failure, and ValueError /
    json.JSONDecodeError when the response contains no parseable JSON.
    """
    user = (
        "트랙·영상·메타데이터를 4축으로 평가하고 JSON만 응답:\n"
        f"트랙: {json.dumps(track, ensure_ascii=False)}\n"
        f"영상: {json.dumps(video_meta)}\n"
        f"메타: {json.dumps(metadata, ensure_ascii=False)}\n"
        f"썸네일: {thumbnail_url}\n"
        f"트렌드: {trend_top}\n"
        '출력: {"metadata_quality":{"score":0-100,"notes":""},'
        '"policy_compliance":{"score":0-100,"issues":[]},'
        '"viewer_experience":{"score":0-100,"notes":""},'
        '"trend_alignment":{"score":0-100,"matched_keywords":[]},'
        '"summary":""}'
    )
    async with httpx.AsyncClient(timeout=TIMEOUT_S) as client:
        resp = await client.post(
            "https://api.anthropic.com/v1/messages",
            headers={
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01",
                "content-type": "application/json",
            },
            json={"model": model, "max_tokens": 1024,
                  "messages": [{"role": "user", "content": user}]},
        )
        resp.raise_for_status()
        text = resp.json()["content"][0]["text"]
        # Extract the first JSON object from the response text.
        start = text.find("{")
        end = text.rfind("}") + 1
        if start < 0 or end <= start:
            raise ValueError("Claude 응답 JSON 없음")
        return json.loads(text[start:end])
|
||||
|
||||
|
||||
def _heuristic(metadata, video_meta, track, trend_top, weights, threshold, fallback_reason):
    """Offline 4-axis scoring used when the LLM reviewer is unavailable."""
    # Metadata axis: sane title/description lengths and tag count.
    title_ok = 5 <= len(metadata.get("title", "")) <= 60
    desc_ok = 50 <= len(metadata.get("description", "")) <= 1000
    tags_ok = 5 <= len(metadata.get("tags", [])) <= 15
    meta_score = 100 if (title_ok and desc_ok and tags_ok) else 50

    # Policy axis: banned-word scan over title + description.
    blob = (metadata.get("title", "") + metadata.get("description", "")).lower()
    policy_score = 30 if any(word in blob for word in POLICY_BANNED) else 100

    # Viewer axis: video length should match the track duration (±5s).
    expected = track.get("duration_sec", video_meta.get("length_sec", 0))
    viewer_score = 90 if abs(video_meta.get("length_sec", 0) - expected) <= 5 else 60

    # Trend axis: any overlap between tags and trending keywords.
    matched = set(metadata.get("tags", [])) & set(trend_top)
    trend_score = 100 if matched else 40

    scores = {
        "metadata_quality": {"score": meta_score, "notes": "휴리스틱"},
        "policy_compliance": {"score": policy_score, "issues": []},
        "viewer_experience": {"score": viewer_score, "notes": "휴리스틱"},
        "trend_alignment": {"score": trend_score, "matched_keywords": list(matched)},
        "summary": f"휴리스틱 fallback: {fallback_reason}",
    }
    return _weighted_verdict(scores, weights, threshold, used_fallback=True)
|
||||
41
music-lab/app/pipeline/state_machine.py
Normal file
41
music-lab/app/pipeline/state_machine.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""파이프라인 상태 머신 — 전이 규칙 단일 소스."""
|
||||
|
||||
STEPS = ["cover", "video", "thumb", "meta", "review", "publish"]
USER_GATES = ["cover", "video", "thumb", "meta", "publish"]  # review runs automatically

# State reached after user approval at each gate.
_APPROVE_NEXT = {
    "cover_pending": "video_pending",
    "video_pending": "thumb_pending",
    "thumb_pending": "meta_pending",
    "meta_pending": "ai_review",  # hands off to the automatic review step
    "publish_pending": "publishing",
}

# Transitions the orchestrator performs without user input.
_AUTO_TRANSITIONS = {
    ("ai_review", "publish_pending"),
    ("publishing", "published"),
}

TERMINAL_STATES = {"published", "cancelled", "failed", "awaiting_manual"}


def next_state_on_approve(state: str) -> str:
    """State after user approval; raises ValueError for non-approvable states."""
    try:
        return _APPROVE_NEXT[state]
    except KeyError:
        raise ValueError(f"승인 불가 상태: {state}") from None


def next_state_on_reject(state: str) -> str:
    """On rejection the pipeline stays in the same *_pending state (regen loop)."""
    if state.endswith("_pending"):
        return state
    raise ValueError(f"반려 불가 상태: {state}")


def can_transition(from_state: str, to_state: str) -> bool:
    """Whether *from_state* -> *to_state* is a legal transition."""
    if from_state in TERMINAL_STATES:
        return False
    if to_state in {"cancelled", "failed", "awaiting_manual"}:
        return True  # abort-style transitions are allowed from any live state
    if to_state == _APPROVE_NEXT.get(from_state):
        return True
    return (from_state, to_state) in _AUTO_TRANSITIONS
|
||||
15
music-lab/app/pipeline/storage.py
Normal file
15
music-lab/app/pipeline/storage.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""파이프라인 산출물 디렉토리 관리."""
|
||||
import os
|
||||
|
||||
# Local directory for artifacts and the public URL prefix they are served under.
VIDEO_DATA_DIR = os.getenv("VIDEO_DATA_DIR", "/app/data/videos")
VIDEO_MEDIA_BASE = os.getenv("VIDEO_MEDIA_BASE", "/media/videos")


def pipeline_dir(pipeline_id: int) -> str:
    """Return (and create if needed) the artifact directory for a pipeline."""
    path = os.path.join(VIDEO_DATA_DIR, str(pipeline_id))
    os.makedirs(path, exist_ok=True)
    return path


def media_url(pipeline_id: int, filename: str) -> str:
    """Public URL for an artifact file of a pipeline.

    Fix: the URL now ends with the requested *filename*; previously the
    parameter was ignored and a literal placeholder was emitted, so every
    artifact URL pointed at a nonexistent path.
    """
    return f"{VIDEO_MEDIA_BASE}/{pipeline_id}/{filename}"
|
||||
51
music-lab/app/pipeline/thumb.py
Normal file
51
music-lab/app/pipeline/thumb.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""썸네일 생성 — 영상 5초 프레임 추출 + 텍스트 오버레이."""
|
||||
import os
|
||||
import subprocess
|
||||
import logging
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from . import storage
|
||||
|
||||
logger = logging.getLogger("music-lab.thumb")

# Hard limit (seconds) for the ffmpeg frame-extraction subprocess.
THUMB_TIMEOUT_S = 60


class ThumbGenerationError(Exception):
    """Raised when ffmpeg fails to extract the thumbnail frame."""
def generate(*, pipeline_id: int, video_path: str,
             track_title: str = "", overlay_text: bool = True) -> dict:
    """Extract the 5-second frame of *video_path* as the pipeline thumbnail.

    Optionally overlays *track_title* on the image (best effort). Returns a
    dict with the public URL; raises ThumbGenerationError when ffmpeg fails.
    """
    target = os.path.join(storage.pipeline_dir(pipeline_id), "thumbnail.jpg")
    extract_cmd = [
        "ffmpeg", "-y",
        "-i", video_path,
        "-ss", "00:00:05",      # seek to 5 seconds
        "-vframes", "1",        # grab a single frame
        "-q:v", "2",            # high JPEG quality
        target,
    ]
    proc = subprocess.run(extract_cmd, capture_output=True, text=True,
                          timeout=THUMB_TIMEOUT_S)
    if proc.returncode != 0:
        raise ThumbGenerationError(f"ffmpeg 썸네일 실패: {proc.stderr[:300]}")

    if overlay_text and track_title:
        _overlay_title(target, track_title)

    return {
        "url": storage.media_url(pipeline_id, "thumbnail.jpg"),
        "used_fallback": False,
    }
def _overlay_title(path: str, title: str) -> None:
    """Best effort: darken the bottom 30% of the image and center *title* on it.

    Any failure is logged and swallowed — the plain thumbnail is still usable.
    """
    try:
        with Image.open(path) as src:
            img = src.convert("RGB")
            try:
                font = ImageFont.truetype(
                    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 80)
            except OSError:
                # DejaVu not installed — fall back to PIL's built-in font.
                font = ImageFont.load_default()
            width, height = img.size
            band_h = int(height * 0.3)
            # Semi-transparent black band over the bottom 30% of the frame.
            with Image.new("RGBA", (width, band_h), (0, 0, 0, 160)) as band:
                img.paste(band, (0, height - band_h), band)
            drawer = ImageDraw.Draw(img)
            left, _top, right, _bottom = drawer.textbbox((0, 0), title, font=font)
            text_w = right - left
            drawer.text(((width - text_w) // 2, height - band_h + 30),
                        title, fill=(255, 255, 255), font=font)
            img.save(path, "JPEG", quality=92)
    except Exception as e:
        logger.warning("썸네일 오버레이 실패: %s", e)
55
music-lab/app/pipeline/video.py
Normal file
55
music-lab/app/pipeline/video.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""영상 비주얼 생성 — visualizer/슬라이드쇼 스타일."""
|
||||
import os
|
||||
import subprocess
|
||||
import logging
|
||||
|
||||
from . import storage
|
||||
|
||||
logger = logging.getLogger("music-lab.video")

# Hard cap (seconds) on a single ffmpeg render: 5 minutes.
VIDEO_TIMEOUT_S = 300


class VideoGenerationError(Exception):
    """Raised when the ffmpeg render exits with a non-zero status."""
def generate(*, pipeline_id: int, audio_path: str, cover_path: str,
             genre: str, duration_sec: int, resolution: str = "1920x1080",
             style: str = "visualizer") -> dict:
    """Render the pipeline video; save the mp4 and return its public URL.

    Raises VideoGenerationError (or subprocess.TimeoutExpired) on failure.
    Only the "visualizer" style is implemented — any other *style* value
    currently falls back to the same visualizer command.
    """
    width, height = resolution.split("x")
    target = os.path.join(storage.pipeline_dir(pipeline_id), "video.mp4")

    # Every style currently maps to the visualizer renderer; slideshow and
    # other styles are planned but not implemented yet.
    cmd = _build_visualizer_cmd(audio_path, cover_path, target, width, height)

    logger.info("ffmpeg 실행: %s", " ".join(cmd))
    proc = subprocess.run(cmd, capture_output=True, text=True,
                          timeout=VIDEO_TIMEOUT_S)
    if proc.returncode != 0:
        raise VideoGenerationError(f"ffmpeg 실패: {proc.stderr[:500]}")

    return {
        "url": storage.media_url(pipeline_id, "video.mp4"),
        "used_fallback": False,
        "duration_sec": duration_sec,
    }
def _build_visualizer_cmd(audio: str, bg: str, out: str, w: str, h: str) -> list:
|
||||
return [
|
||||
"ffmpeg", "-y",
|
||||
"-loop", "1", "-i", bg,
|
||||
"-i", audio,
|
||||
"-filter_complex",
|
||||
f"[0:v]scale={w}:{h}[bg];"
|
||||
f"[1:a]showwaves=s={w}x200:mode=cline:colors=0xFF4444@0.8[wave];"
|
||||
f"[bg][wave]overlay=0:({h}-200)[out]",
|
||||
"-map", "[out]", "-map", "1:a",
|
||||
"-c:v", "libx264", "-preset", "fast", "-crf", "23",
|
||||
"-c:a", "aac", "-b:a", "192k",
|
||||
"-shortest", out,
|
||||
]
|
||||
156
music-lab/app/pipeline/youtube.py
Normal file
156
music-lab/app/pipeline/youtube.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""YouTube OAuth flow + resumable 업로드."""
|
||||
import os
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.http import MediaFileUpload
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from app import db
|
||||
|
||||
logger = logging.getLogger("music-lab.youtube")

# OAuth scopes: upload videos + read channel info.
SCOPES = [
    "https://www.googleapis.com/auth/youtube.upload",
    "https://www.googleapis.com/auth/youtube.readonly",
]


class NotAuthenticatedError(Exception):
    """Raised when an upload is attempted without a stored OAuth token."""


class QuotaExceededError(Exception):
    """Raised when the YouTube API reports the daily quota is exhausted."""
def _client_id() -> str:
|
||||
return os.getenv("YOUTUBE_OAUTH_CLIENT_ID", "")
|
||||
|
||||
|
||||
def _client_secret() -> str:
|
||||
return os.getenv("YOUTUBE_OAUTH_CLIENT_SECRET", "")
|
||||
|
||||
|
||||
def _redirect_uri() -> str:
|
||||
return os.getenv("YOUTUBE_OAUTH_REDIRECT_URI", "")
|
||||
|
||||
|
||||
def get_auth_url() -> str:
    """Build the Google consent-screen URL for the configured OAuth client.

    Raises RuntimeError when the client id / redirect URI env vars are missing.
    """
    client_id = _client_id()
    redirect_uri = _redirect_uri()
    if not (client_id and redirect_uri):
        raise RuntimeError("OAuth 환경변수 미설정")
    query = urlencode({
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "scope": " ".join(SCOPES),
        "access_type": "offline",  # required so Google issues a refresh_token
        "prompt": "consent",       # force re-consent so refresh_token is always returned
    })
    return "https://accounts.google.com/o/oauth2/v2/auth?" + query
async def exchange_code(code: str) -> dict:
    """Exchange an OAuth authorization code for tokens and persist them.

    Flow: code -> access_token + refresh_token -> fetch the authenticated
    channel's snippet -> upsert everything into the DB.

    Returns {"channel_id": ..., "channel_title": ...}.
    Raises httpx.HTTPStatusError when the token exchange fails.
    """
    async with httpx.AsyncClient(timeout=30) as client:
        token_resp = await client.post(
            "https://oauth2.googleapis.com/token",
            data={
                "code": code,
                "client_id": _client_id(),
                "client_secret": _client_secret(),
                "redirect_uri": _redirect_uri(),
                "grant_type": "authorization_code",
            },
        )
        token_resp.raise_for_status()
        tok = token_resp.json()
        access = tok["access_token"]
        # NOTE(review): assumes Google always includes refresh_token. That holds
        # because get_auth_url() requests access_type=offline + prompt=consent,
        # but a KeyError here means the consent flow changed — confirm.
        refresh = tok["refresh_token"]
        expires_at = _expiry_from_seconds(tok["expires_in"])

    # Synchronous googleapiclient call to identify the connected channel.
    creds = _creds(access=access, refresh=refresh)
    yt = _build_youtube_client(creds)
    ch = yt.channels().list(part="snippet", mine=True).execute()
    item = ch["items"][0]
    db.upsert_oauth_token(
        channel_id=item["id"],
        channel_title=item["snippet"]["title"],
        avatar_url=item["snippet"]["thumbnails"]["default"]["url"],
        refresh_token=refresh, access_token=access, expires_at=expires_at,
    )
    return {"channel_id": item["id"], "channel_title": item["snippet"]["title"]}
def get_status() -> dict | None:
    """Return a summary of the connected channel, or None when not connected."""
    token_row = db.get_oauth_token()
    if not token_row:
        return None
    return {key: token_row[key]
            for key in ("channel_id", "channel_title", "avatar_url")}
def disconnect() -> None:
    """Drop the stored OAuth token, disconnecting the YouTube channel."""
    db.delete_oauth_token()
def upload_video(*, video_path: str, thumbnail_path: str | None,
                 metadata: dict, privacy: str) -> dict:
    """Upload *video_path* to YouTube using the resumable upload protocol.

    metadata: dict with "title", "description", optional "tags" and
        "category_id" (defaults to 10, YouTube's Music category).
    privacy: YouTube privacyStatus value ("public", "unlisted", "private").

    Returns {"video_id": ...}.
    Raises NotAuthenticatedError when no OAuth token is stored, and
    QuotaExceededError when the API response contains "quotaExceeded".
    """
    tok = db.get_oauth_token()
    if not tok:
        raise NotAuthenticatedError("YouTube 인증 없음")
    creds = _creds(access=tok["access_token"], refresh=tok["refresh_token"])
    yt = _build_youtube_client(creds)

    body = {
        "snippet": {
            "title": metadata["title"],
            "description": metadata["description"],
            "tags": metadata.get("tags", []),
            # YouTube expects categoryId as a string; 10 = Music.
            "categoryId": str(metadata.get("category_id", 10)),
        },
        "status": {"privacyStatus": privacy, "selfDeclaredMadeForKids": False},
    }
    # Resumable upload in 4 MiB chunks.
    media = MediaFileUpload(video_path, chunksize=4 * 1024 * 1024, resumable=True, mimetype="video/mp4")
    req = yt.videos().insert(part="snippet,status", body=body, media_body=media)

    try:
        response = None
        # next_chunk() returns (progress, None) until the final chunk,
        # then (status, response-dict) once the upload completes.
        while response is None:
            status, response = req.next_chunk()
        video_id = response["id"]
    except HttpError as e:
        if b"quotaExceeded" in (e.content or b""):
            raise QuotaExceededError(str(e))
        raise

    # Thumbnail upload is best-effort: a failure is logged, never raised.
    if thumbnail_path:
        try:
            yt.thumbnails().set(videoId=video_id, media_body=thumbnail_path).execute()
        except HttpError as e:
            logger.warning("썸네일 업로드 실패: %s", e)

    return {"video_id": video_id}
def _build_youtube_client(creds):  # patch point for tests
    """Construct the YouTube Data API v3 client (discovery cache disabled)."""
    return build("youtube", "v3", credentials=creds, cache_discovery=False)
def _creds(access: str, refresh: str) -> Credentials:
    """Wrap raw token strings in a google-auth Credentials object."""
    return Credentials(
        token=access,
        refresh_token=refresh,
        token_uri="https://oauth2.googleapis.com/token",
        client_id=_client_id(),
        client_secret=_client_secret(),
        scopes=SCOPES,
    )
def _expiry_from_seconds(secs: int) -> str:
|
||||
from datetime import datetime, timedelta
|
||||
return (datetime.utcnow() + timedelta(seconds=secs)).isoformat(timespec="seconds")
|
||||
@@ -1,3 +1,4 @@
|
||||
[pytest]
|
||||
testpaths = tests
|
||||
pythonpath = .
|
||||
asyncio_mode = auto
|
||||
|
||||
@@ -4,6 +4,13 @@ requests==2.32.3
|
||||
python-multipart==0.0.12
|
||||
mutagen==1.47.0
|
||||
anthropic>=0.40.0
|
||||
openai>=1.20.0
|
||||
Pillow>=11.0.0
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.21
|
||||
httpx>=0.27.0
|
||||
respx>=0.21
|
||||
freezegun>=1.4
|
||||
google-api-python-client>=2.100
|
||||
google-auth-oauthlib>=1.2
|
||||
google-auth-httplib2>=0.2
|
||||
|
||||
@@ -1,7 +1,36 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_db(tmp_path, monkeypatch):
    """Point app.db at a throwaway SQLite file; returns the db path string."""
    db_path = str(tmp_path / "test_music.db")
    monkeypatch.setattr("app.db.DB_PATH", db_path)
    return db_path
@pytest.fixture
def freezer():
    """Minimal freezegun-based fixture providing `move_to(time)` to mimic
    pytest-freezer's `freezer` fixture using only the `freezegun` package."""
    from freezegun import freeze_time

    class _Clock:
        """Holds at most one active freeze_time() context at a time."""

        def __init__(self):
            self._active = None

        def move_to(self, target):
            # Tear down any previous freeze before starting the new one.
            self.stop()
            self._active = freeze_time(target)
            self._active.start()

        def stop(self):
            if self._active is not None:
                self._active.stop()
                self._active = None

    clock = _Clock()
    try:
        yield clock
    finally:
        clock.stop()
93
music-lab/tests/test_cover_generation.py
Normal file
93
music-lab/tests/test_cover_generation.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Tests for cover-art generation: DALL-E success, fallback, feedback, b64."""
import base64

import pytest
import respx
from httpx import Response
from app.pipeline import cover, storage


# Real PNG bytes (1x1 red pixel) so PIL can open
_TINY_PNG = bytes.fromhex(
    "89504e470d0a1a0a0000000d49484452000000010000000108020000009077"
    "53de0000000c4944415478da6300010000050001"
    "0d0a2db40000000049454e44ae426082"
)


@pytest.fixture
def tmp_storage(monkeypatch, tmp_path):
    """Redirect storage.VIDEO_DATA_DIR into the test tmp dir."""
    monkeypatch.setattr(storage, "VIDEO_DATA_DIR", str(tmp_path))
    return tmp_path


@pytest.mark.asyncio
@respx.mock
async def test_dalle_success_saves_jpg(tmp_storage, monkeypatch):
    """Happy path: DALL-E returns a URL; image is downloaded and saved as JPG."""
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")
    image_url = "https://oaidalleapiprodscus.blob.core.windows.net/x.png"
    respx.post("https://api.openai.com/v1/images/generations").mock(
        return_value=Response(200, json={"data": [{"url": image_url}]})
    )
    respx.get(image_url).mock(return_value=Response(200, content=_TINY_PNG))

    out = await cover.generate(pipeline_id=42, genre="lo-fi",
                               prompt_template="moody anime", mood="chill",
                               track_title="Test")
    assert out["used_fallback"] is False
    assert out["url"].startswith("/media/videos/42/cover")
    assert (tmp_storage / "42" / "cover.jpg").exists()


@pytest.mark.asyncio
@respx.mock
async def test_dalle_http_error_falls_back_to_gradient(tmp_storage, monkeypatch):
    """HTTP 504 from DALL-E triggers the local gradient fallback cover."""
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")
    respx.post("https://api.openai.com/v1/images/generations").mock(
        return_value=Response(504)
    )
    out = await cover.generate(pipeline_id=43, genre="phonk",
                               prompt_template="dark drift", mood="aggressive",
                               track_title="Midnight Drive")
    assert out["used_fallback"] is True
    assert (tmp_storage / "43" / "cover.jpg").exists()


@pytest.mark.asyncio
async def test_no_api_key_falls_back(tmp_storage, monkeypatch):
    """With no OPENAI_API_KEY set, generation uses the fallback cover."""
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    out = await cover.generate(pipeline_id=44, genre="ambient",
                               prompt_template="x", mood="calm",
                               track_title="Calm")
    assert out["used_fallback"] is True


@pytest.mark.asyncio
@respx.mock
async def test_dalle_with_feedback_appends_to_prompt(tmp_storage, monkeypatch):
    """User feedback text must be appended to the outgoing DALL-E prompt."""
    import json as _json
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")
    captured = {}
    def hook(req):
        # Capture the request body so the prompt content can be asserted.
        captured["body"] = _json.loads(req.content)
        return Response(200, json={"data": [{"url": "https://x"}]})
    respx.post("https://api.openai.com/v1/images/generations").mock(side_effect=hook)
    respx.get("https://x").mock(return_value=Response(200, content=_TINY_PNG))
    out = await cover.generate(pipeline_id=45, genre="lo-fi",
                               prompt_template="moody anime", mood="chill",
                               track_title="X", feedback="더 어둡게")
    assert "더 어둡게" in captured["body"]["prompt"]
    assert out["used_fallback"] is False


@pytest.mark.asyncio
@respx.mock
async def test_dalle_b64_response_handled(tmp_storage, monkeypatch):
    """A b64_json response (no URL) must also be decoded and saved."""
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")
    b64 = base64.b64encode(_TINY_PNG).decode()
    respx.post("https://api.openai.com/v1/images/generations").mock(
        return_value=Response(200, json={"data": [{"b64_json": b64}]})
    )
    out = await cover.generate(pipeline_id=46, genre="lo-fi",
                               prompt_template="x", mood="", track_title="X")
    assert out["used_fallback"] is False
    assert (tmp_storage / "46" / "cover.jpg").exists()
82
music-lab/tests/test_metadata_generation.py
Normal file
82
music-lab/tests/test_metadata_generation.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Tests for metadata generation: Claude call, fallback, feedback plumbing."""
import pytest
import respx
from httpx import Response
from app.pipeline import metadata


@pytest.mark.asyncio
@respx.mock
async def test_metadata_calls_claude_and_parses_json(monkeypatch):
    """Claude's JSON text payload is parsed into the metadata dict."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
    payload = {
        "content": [{"type": "text", "text": '{"title":"[Lo-fi] Drive | 85BPM",'
                     '"description":"chill","tags":["lofi","85bpm"],'
                     '"category_id":10}'}]
    }
    respx.post("https://api.anthropic.com/v1/messages").mock(
        return_value=Response(200, json=payload)
    )
    result = await metadata.generate(
        track={"title": "Drive", "genre": "lo-fi", "bpm": 85, "key": "C", "scale": "minor",
               "moods": ["chill"], "instruments": ["piano"]},
        template={"title": "[{genre}] {title} | {bpm}BPM",
                  "description": "{title}\n", "tags": [], "category_id": 10},
        trend_keywords=["lofi", "study"],
        feedback="",
    )
    assert result["title"].startswith("[Lo-fi]")
    assert "lofi" in result["tags"]
    assert result["used_fallback"] is False


@pytest.mark.asyncio
async def test_metadata_fallback_when_no_api_key(monkeypatch):
    """Without an API key the template is filled in locally (fallback)."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    result = await metadata.generate(
        track={"title": "Drive", "genre": "lo-fi", "bpm": 85, "key": "C", "scale": "minor",
               "moods": [], "instruments": []},
        template={"title": "[{genre}] {title} | {bpm}BPM",
                  "description": "{title}", "tags": ["lofi"], "category_id": 10},
        trend_keywords=[],
    )
    # Fallback: template variables substituted verbatim.
    assert result["title"] == "[lo-fi] Drive | 85BPM"
    assert result["used_fallback"] is True


@pytest.mark.asyncio
@respx.mock
async def test_metadata_includes_feedback_in_prompt(monkeypatch):
    """User feedback text must appear somewhere in the Claude request body."""
    import json
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
    captured = {}
    def hook(req):
        captured["body"] = json.loads(req.content)
        return Response(200, json={"content": [{"type": "text",
            "text": '{"title":"x","description":"y","tags":[],"category_id":10}'}]})
    respx.post("https://api.anthropic.com/v1/messages").mock(side_effect=hook)
    await metadata.generate(
        track={"title": "X", "genre": "lo-fi", "bpm": 85, "key": "C", "scale": "minor",
               "moods": [], "instruments": []},
        template={"title": "{title}", "description": "{title}", "tags": [], "category_id": 10},
        trend_keywords=[],
        feedback="제목을 짧게",
    )
    assert "제목을 짧게" in str(captured["body"])


@pytest.mark.asyncio
@respx.mock
async def test_metadata_falls_back_on_api_error(monkeypatch):
    """A 500 from the Anthropic API falls back to local template filling."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
    respx.post("https://api.anthropic.com/v1/messages").mock(
        return_value=Response(500)
    )
    result = await metadata.generate(
        track={"title": "Drive", "genre": "lo-fi", "bpm": 85, "key": "C", "scale": "minor",
               "moods": [], "instruments": []},
        template={"title": "[{genre}] {title}", "description": "x", "tags": ["lofi"], "category_id": 10},
        trend_keywords=[],
    )
    assert result["used_fallback"] is True
    assert "Drive" in result["title"]
96
music-lab/tests/test_pipeline_db.py
Normal file
96
music-lab/tests/test_pipeline_db.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Tests for the pipeline DB layer: CRUD, feedback, jobs, OAuth token upsert."""
import os
import tempfile
import pytest
from app import db


@pytest.fixture
def fresh_db(monkeypatch, tmp_path):
    """Fresh initialized SQLite DB in the test tmp dir."""
    db_path = tmp_path / "music.db"
    monkeypatch.setattr(db, "DB_PATH", str(db_path))
    db.init_db()
    return db_path


def test_create_pipeline_inserts_row(fresh_db):
    """A new pipeline starts in state "created" with no feedback counters."""
    pid = db.create_pipeline(track_id=1)
    row = db.get_pipeline(pid)
    assert row["id"] == pid
    assert row["state"] == "created"
    assert row["track_id"] == 1
    assert row["feedback_count_per_step"] == {}


def test_update_pipeline_state_records_started_at(fresh_db, freezer):
    """State changes record the (frozen) timestamp in state_started_at."""
    pid = db.create_pipeline(track_id=1)
    freezer.move_to("2026-05-07T08:00:00")
    db.update_pipeline_state(pid, "cover_pending")
    row = db.get_pipeline(pid)
    assert row["state"] == "cover_pending"
    assert row["state_started_at"] == "2026-05-07T08:00:00"


def test_increment_feedback_count(fresh_db):
    """Per-step feedback counter accumulates across calls."""
    pid = db.create_pipeline(track_id=1)
    db.increment_feedback_count(pid, "cover")
    db.increment_feedback_count(pid, "cover")
    row = db.get_pipeline(pid)
    assert row["feedback_count_per_step"] == {"cover": 2}


def test_record_feedback(fresh_db):
    """Recorded feedback text shows up in the feedback history."""
    pid = db.create_pipeline(track_id=1)
    db.record_feedback(pid, "cover", "더 어둡게")
    rows = db.get_feedback_history(pid)
    assert len(rows) == 1
    assert rows[0]["feedback_text"] == "더 어둡게"


def test_create_pipeline_job_lifecycle(fresh_db):
    """Job status/duration updates are persisted and listed."""
    pid = db.create_pipeline(track_id=1)
    job_id = db.create_pipeline_job(pid, "cover")
    db.update_pipeline_job(job_id, status="running")
    db.update_pipeline_job(job_id, status="succeeded", duration_ms=1234)
    jobs = db.list_pipeline_jobs(pid)
    assert jobs[0]["status"] == "succeeded"
    assert jobs[0]["duration_ms"] == 1234


def test_youtube_setup_default_row_created_on_init(fresh_db):
    """init_db seeds a default youtube setup row (threshold 60)."""
    setup = db.get_youtube_setup()
    assert setup["review_threshold"] == 60
    assert "metadata_template_json" in setup


def test_youtube_oauth_token_upsert(fresh_db):
    """Upserting the same channel twice replaces the stored token."""
    db.upsert_oauth_token(
        channel_id="UC123",
        channel_title="My Channel",
        avatar_url="https://...",
        refresh_token="r1",
        access_token="a1",
        expires_at="2026-05-07T09:00:00",
    )
    tok = db.get_oauth_token()
    assert tok["channel_id"] == "UC123"
    assert tok["refresh_token"] == "r1"
    db.upsert_oauth_token(
        channel_id="UC123", channel_title="My Channel",
        avatar_url=None, refresh_token="r2",
        access_token="a2", expires_at="2026-05-07T10:00:00",
    )
    tok = db.get_oauth_token()
    assert tok["refresh_token"] == "r2"  # upsert


def test_update_pipeline_state_rejects_unknown_column(fresh_db):
    """Unknown keyword columns are rejected (SQL-injection guard)."""
    pid = db.create_pipeline(track_id=1)
    with pytest.raises(ValueError):
        db.update_pipeline_state(pid, "cover_pending", evil_col="x; DROP TABLE")


def test_update_pipeline_job_rejects_unknown_column(fresh_db):
    """Job updates likewise reject unknown column names."""
    pid = db.create_pipeline(track_id=1)
    job_id = db.create_pipeline_job(pid, "cover")
    with pytest.raises(ValueError):
        db.update_pipeline_job(job_id, evil_col="x")
110
music-lab/tests/test_pipeline_endpoints.py
Normal file
110
music-lab/tests/test_pipeline_endpoints.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""HTTP endpoint tests for the pipeline API (FastAPI TestClient)."""
import sqlite3
import pytest
from unittest.mock import AsyncMock, patch
from fastapi.testclient import TestClient

from app.main import app
from app import db


@pytest.fixture
def client(monkeypatch, tmp_path):
    """TestClient backed by a fresh DB seeded with one track (id=1)."""
    monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
    db.init_db()
    # Seed one minimal track — inserted directly into the music_library table.
    conn = sqlite3.connect(db.DB_PATH)
    cur = conn.cursor()
    cur.execute(
        """INSERT INTO music_library
        (id, title, genre, moods, instruments, duration_sec, bpm, key, scale,
         prompt, audio_url, file_path, task_id, tags)
        VALUES (1, 'T', 'lo-fi', '["chill"]', '["piano"]', 120, 85, 'C', 'maj',
         'p', '/media/music/x.mp3', '/app/data/music/x.mp3', NULL, '[]')""",
    )
    conn.commit()
    conn.close()
    return TestClient(app)


def test_create_pipeline(client):
    """POST /pipeline creates a pipeline in state "created"."""
    r = client.post("/api/music/pipeline", json={"track_id": 1})
    assert r.status_code == 201
    assert r.json()["state"] == "created"


def test_create_duplicate_pipeline_returns_409(client):
    """A second pipeline for the same track conflicts (409)."""
    client.post("/api/music/pipeline", json={"track_id": 1})
    r = client.post("/api/music/pipeline", json={"track_id": 1})
    assert r.status_code == 409


def test_get_pipeline_returns_jobs_and_feedback(client):
    """GET /pipeline/{id} embeds the jobs and feedback collections."""
    pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
    r = client.get(f"/api/music/pipeline/{pid}")
    assert "jobs" in r.json()
    assert "feedback" in r.json()


def test_list_pipelines_active_filter(client):
    """?status=active excludes terminal (published) pipelines."""
    pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
    db.update_pipeline_state(pid, "published")
    r = client.get("/api/music/pipeline?status=active")
    assert all(p["state"] != "published" for p in r.json()["pipelines"])


def test_feedback_reject_records_feedback_and_increments_count(client):
    """A reject records the feedback text and bumps the per-step counter."""
    pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
    db.update_pipeline_state(pid, "cover_pending")
    # Mock orchestrator.run_step so the background task can't move the
    # pipeline out of cover_pending while the test inspects it.
    with patch("app.main.orchestrator.run_step", new=AsyncMock()):
        r = client.post(
            f"/api/music/pipeline/{pid}/feedback",
            json={"step": "cover", "intent": "reject", "feedback_text": "더 어둡게"},
        )
    assert r.status_code == 202
    p = db.get_pipeline(pid)
    assert p["feedback_count_per_step"]["cover"] == 1
    history = db.get_feedback_history(pid)
    assert history[0]["feedback_text"] == "더 어둡게"


def test_feedback_after_5_rejects_marks_awaiting_manual(client):
    """The 6th reject on one step returns 409 and parks the pipeline."""
    pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
    db.update_pipeline_state(pid, "cover_pending")
    with patch("app.main.orchestrator.run_step", new=AsyncMock()):
        for i in range(5):
            client.post(
                f"/api/music/pipeline/{pid}/feedback",
                json={"step": "cover", "intent": "reject", "feedback_text": f"again {i}"},
            )
        r = client.post(
            f"/api/music/pipeline/{pid}/feedback",
            json={"step": "cover", "intent": "reject", "feedback_text": "6th"},
        )
    assert r.status_code == 409
    assert db.get_pipeline(pid)["state"] == "awaiting_manual"


def test_cancel_pipeline(client):
    """POST /cancel moves the pipeline into the terminal "cancelled" state."""
    pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
    r = client.post(f"/api/music/pipeline/{pid}/cancel")
    assert r.status_code == 200
    assert db.get_pipeline(pid)["state"] == "cancelled"


def test_setup_get_returns_defaults(client):
    """GET /setup exposes the seeded defaults (threshold 60)."""
    r = client.get("/api/music/setup")
    assert r.status_code == 200
    assert r.json()["review_threshold"] == 60


def test_setup_put_updates(client):
    """PUT /setup persists and echoes updated values."""
    r = client.put("/api/music/setup", json={"review_threshold": 70})
    assert r.status_code == 200
    assert r.json()["review_threshold"] == 70


def test_youtube_status_when_disconnected(client):
    """With no stored OAuth token, status reports connected=False."""
    r = client.get("/api/music/youtube/status")
    assert r.status_code == 200
    assert r.json() == {"connected": False}
113
music-lab/tests/test_pipeline_flow.py
Normal file
113
music-lab/tests/test_pipeline_flow.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""풀 파이프라인 happy-path 통합 테스트 — 모든 외부 호출 mock."""
|
||||
import pytest
|
||||
import sqlite3
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app
|
||||
from app import db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
|
||||
db.init_db()
|
||||
# 기본 트랙 1개 등록 (music_library 테이블)
|
||||
conn = sqlite3.connect(db.DB_PATH)
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"""INSERT INTO music_library
|
||||
(id, title, genre, moods, instruments, duration_sec, bpm, key, scale,
|
||||
prompt, audio_url, file_path, task_id, tags)
|
||||
VALUES (1, 'Integration Test', 'lo-fi', '["chill"]', '["piano"]',
|
||||
120, 85, 'C', 'maj', 'p',
|
||||
'/media/music/x.mp3', '/app/data/music/x.mp3', NULL, '[]')""",
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.OperationalError as e:
|
||||
# 스키마가 다르면 테스트 스킵 가이드 표시
|
||||
pytest.skip(f"music_library schema mismatch: {e}")
|
||||
conn.close()
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# `new=...`로 patch한 항목은 mock 인자를 함수에 주입하지 않음.
|
||||
# return_value로 patch한 두 개(cover/video는 new=AsyncMock 사용 → 미주입,
|
||||
# thumb/video는 return_value → 주입, youtube.upload_video도 return_value → 주입).
|
||||
# 데코레이터는 아래에서 위 순서로 인자에 주입됨:
|
||||
# 1) cover (new=AsyncMock) → 미주입
|
||||
# 2) video (return_value) → 주입 → mock_video
|
||||
# 3) thumb (return_value) → 주입 → mock_thumb
|
||||
# 4) metadata (new=AsyncMock) → 미주입
|
||||
# 5) review (new=AsyncMock) → 미주입
|
||||
# 6) youtube.upload_video (return_value) → 주입 → mock_yt
|
||||
@patch("app.pipeline.youtube.upload_video", return_value={"video_id": "VID999"})
|
||||
@patch("app.pipeline.review.run_4_axis", new=AsyncMock(return_value={
|
||||
"metadata_quality": {"score": 80, "notes": "ok"},
|
||||
"policy_compliance": {"score": 90, "issues": []},
|
||||
"viewer_experience": {"score": 80, "notes": "ok"},
|
||||
"trend_alignment": {"score": 70, "matched_keywords": ["lofi"]},
|
||||
"weighted_total": 80.0, "verdict": "pass", "summary": "good", "used_fallback": False,
|
||||
}))
|
||||
@patch("app.pipeline.metadata.generate", new=AsyncMock(return_value={
|
||||
"title": "Integration Test", "description": "Test desc",
|
||||
"tags": ["lofi"], "category_id": 10, "used_fallback": False, "error": None,
|
||||
}))
|
||||
@patch("app.pipeline.thumb.generate", return_value={
|
||||
"url": "/media/videos/1/thumbnail.jpg", "used_fallback": False,
|
||||
})
|
||||
@patch("app.pipeline.video.generate", return_value={
|
||||
"url": "/media/videos/1/video.mp4", "used_fallback": False, "duration_sec": 120,
|
||||
})
|
||||
@patch("app.pipeline.cover.generate", new=AsyncMock(return_value={
|
||||
"url": "/media/videos/1/cover.jpg", "used_fallback": False, "error": None,
|
||||
}))
|
||||
def test_full_pipeline_happy_path(mock_video, mock_thumb, mock_yt, client):
|
||||
# 1. 파이프라인 생성
|
||||
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||
assert db.get_pipeline(pid)["state"] == "created"
|
||||
|
||||
# 2. 시작 → cover 자동 생성 → cover_pending
|
||||
r = client.post(f"/api/music/pipeline/{pid}/start")
|
||||
assert r.status_code == 202
|
||||
assert db.get_pipeline(pid)["state"] == "cover_pending"
|
||||
|
||||
# 3. 4단계 사용자 승인 (cover, video, thumb, meta)
|
||||
for step in ["cover", "video", "thumb", "meta"]:
|
||||
r = client.post(f"/api/music/pipeline/{pid}/feedback",
|
||||
json={"step": step, "intent": "approve"})
|
||||
assert r.status_code == 202
|
||||
|
||||
# 4. ai_review 자동 진행 후 publish_pending
|
||||
p = db.get_pipeline(pid)
|
||||
assert p["state"] == "publish_pending"
|
||||
assert p["review"]["verdict"] == "pass"
|
||||
|
||||
# 5. 발행 트리거
|
||||
r = client.post(f"/api/music/pipeline/{pid}/publish")
|
||||
assert r.status_code == 202
|
||||
|
||||
# 6. 최종 published
|
||||
p = db.get_pipeline(pid)
|
||||
assert p["state"] == "published"
|
||||
assert p["youtube_video_id"] == "VID999"
|
||||
|
||||
|
||||
# cover.generate는 new=AsyncMock이므로 함수 인자에 주입되지 않음.
|
||||
@patch("app.pipeline.cover.generate", new=AsyncMock(return_value={
|
||||
"url": "/media/videos/2/cover.jpg", "used_fallback": False, "error": None,
|
||||
}))
|
||||
def test_pipeline_reject_and_regenerate(client):
|
||||
pid = client.post("/api/music/pipeline", json={"track_id": 1}).json()["id"]
|
||||
client.post(f"/api/music/pipeline/{pid}/start")
|
||||
assert db.get_pipeline(pid)["state"] == "cover_pending"
|
||||
|
||||
# 반려 + 피드백 → 같은 단계 재진입
|
||||
r = client.post(f"/api/music/pipeline/{pid}/feedback",
|
||||
json={"step": "cover", "intent": "reject", "feedback_text": "더 어둡게"})
|
||||
assert r.status_code == 202
|
||||
p = db.get_pipeline(pid)
|
||||
assert p["state"] == "cover_pending" # 같은 단계 유지
|
||||
assert p["feedback_count_per_step"]["cover"] == 1
|
||||
history = db.get_feedback_history(pid)
|
||||
assert history[0]["feedback_text"] == "더 어둡게"
|
||||
84
music-lab/tests/test_review.py
Normal file
84
music-lab/tests/test_review.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Tests for the 4-axis AI review (weighted scoring against a threshold)."""
import pytest
import respx
from httpx import Response
from app.pipeline import review


@pytest.mark.asyncio
@respx.mock
async def test_review_returns_pass_when_above_threshold(monkeypatch):
    """Weighted total above the threshold yields verdict "pass"."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "k")
    body = {"content": [{"type": "text", "text":
        '{"metadata_quality":{"score":80,"notes":"x"},'
        '"policy_compliance":{"score":90,"issues":[]},'
        '"viewer_experience":{"score":75,"notes":"y"},'
        '"trend_alignment":{"score":70,"matched_keywords":["lofi"]},'
        '"summary":"good"}'}]}
    respx.post("https://api.anthropic.com/v1/messages").mock(return_value=Response(200, json=body))
    result = await review.run_4_axis(
        pipeline={"id": 1}, track={"title": "x", "genre": "lo-fi", "bpm": 85},
        video_meta={"length_sec": 120, "resolution": "1920x1080"},
        metadata={"title": "Y", "description": "Z", "tags": ["lofi"], "category_id": 10},
        thumbnail_url="/m/x.jpg", trend_top=["lofi"],
        weights={"meta": 25, "policy": 30, "viewer": 25, "trend": 20},
        threshold=60,
    )
    assert result["verdict"] == "pass"
    # Expected total = weighted sum of the four axis scores.
    expected_total = 0.25 * 80 + 0.30 * 90 + 0.25 * 75 + 0.20 * 70
    assert result["weighted_total"] == pytest.approx(expected_total, abs=0.01)
@pytest.mark.asyncio
@respx.mock
async def test_review_fail_below_threshold(monkeypatch):
    """Low axis scores -> weighted total under the threshold -> verdict 'fail'."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "k")
    llm_json = (
        '{"metadata_quality":{"score":40,"notes":"x"},'
        '"policy_compliance":{"score":50,"issues":[]},'
        '"viewer_experience":{"score":30,"notes":"y"},'
        '"trend_alignment":{"score":20,"matched_keywords":[]},'
        '"summary":"bad"}'
    )
    api_body = {"content": [{"type": "text", "text": llm_json}]}
    respx.post("https://api.anthropic.com/v1/messages").mock(
        return_value=Response(200, json=api_body)
    )

    result = await review.run_4_axis(
        pipeline={"id": 2},
        track={"title": "x", "genre": "lo-fi", "bpm": 85},
        video_meta={"length_sec": 120, "resolution": "1920x1080"},
        metadata={"title": "Y", "description": "Z", "tags": [], "category_id": 10},
        thumbnail_url="/m/x.jpg",
        trend_top=[],
        weights={"meta": 25, "policy": 30, "viewer": 25, "trend": 20},
        threshold=60,
    )

    assert result["verdict"] == "fail"
||||
@pytest.mark.asyncio
@respx.mock
async def test_review_heuristic_fallback_on_llm_error(monkeypatch):
    """A 500 from the LLM API falls back to the local heuristic scorer."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "k")
    respx.post("https://api.anthropic.com/v1/messages").mock(return_value=Response(500))

    result = await review.run_4_axis(
        pipeline={"id": 3},
        track={"title": "x", "genre": "lo-fi", "bpm": 85, "duration_sec": 120},
        video_meta={"length_sec": 120, "resolution": "1920x1080"},
        metadata={
            "title": "Y" * 30,
            "description": "Z" * 200,
            "tags": ["a", "b", "c", "d", "e"],
            "category_id": 10,
        },
        thumbnail_url="/m/x.jpg",
        trend_top=["lofi"],
        weights={"meta": 25, "policy": 30, "viewer": 25, "trend": 20},
        threshold=60,
    )

    assert result["used_fallback"] is True
    assert "weighted_total" in result
||||
@pytest.mark.asyncio
async def test_review_heuristic_when_no_api_key(monkeypatch):
    """Without an API key the heuristic path is used directly (no HTTP call)."""
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

    result = await review.run_4_axis(
        pipeline={"id": 4},
        track={"title": "x", "genre": "lo-fi", "bpm": 85, "duration_sec": 120},
        video_meta={"length_sec": 120, "resolution": "1920x1080"},
        metadata={
            "title": "Test Title",
            "description": "Description here, more text " * 5,
            "tags": ["lofi", "study", "chill", "ambient", "instrumental"],
            "category_id": 10,
        },
        thumbnail_url="/m/x.jpg",
        trend_top=["lofi"],
        weights={"meta": 25, "policy": 30, "viewer": 25, "trend": 20},
        threshold=60,
    )

    assert result["used_fallback"] is True
    # Heuristic path: solid metadata + matching video length + trend-tag
    # overlap is expected to clear the threshold -> pass.
    assert result["verdict"] == "pass"
49
music-lab/tests/test_state_machine.py
Normal file
49
music-lab/tests/test_state_machine.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from app.pipeline.state_machine import (
|
||||
next_state_on_approve, next_state_on_reject, can_transition, STEPS, USER_GATES,
|
||||
)
|
||||
|
||||
|
||||
def test_steps_sequence():
    """The pipeline steps are fixed and ordered."""
    expected = ["cover", "video", "thumb", "meta", "review", "publish"]
    assert STEPS == expected
||||
def test_user_gates_excludes_review():
    """The AI review step is not user-gated; cover and publish are."""
    assert "review" not in USER_GATES
    for gated_step in ("publish", "cover"):
        assert gated_step in USER_GATES
||||
def test_approve_progression():
    """Approval advances each gated state to the next step's state."""
    transitions = {
        "cover_pending": "video_pending",
        "video_pending": "thumb_pending",
        "thumb_pending": "meta_pending",
        "meta_pending": "ai_review",
        "publish_pending": "publishing",
    }
    for current, expected in transitions.items():
        assert next_state_on_approve(current) == expected
||||
def test_approve_invalid_state_raises():
    """ai_review transitions automatically — approving it is a caller error."""
    with pytest.raises(ValueError):
        next_state_on_approve("ai_review")
||||
def test_reject_keeps_same_state():
    """A rejection re-enters the same *_pending state (regeneration trigger)."""
    for state in ("cover_pending", "publish_pending"):
        assert next_state_on_reject(state) == state
||||
def test_can_transition_blocks_terminal_states():
    """No transition may leave a terminal state."""
    for terminal in ("published", "cancelled", "failed"):
        assert not can_transition(terminal, "cover_pending")
||||
def test_can_transition_allows_cancel_from_anywhere():
    """Cancellation is reachable from pending and in-flight states alike."""
    for state in ("cover_pending", "publishing"):
        assert can_transition(state, "cancelled")
||||
def test_can_transition_allows_failed_from_pending():
    """Pending and publishing states may transition to failed."""
    for state in ("video_pending", "publishing"):
        assert can_transition(state, "failed")
||||
66
music-lab/tests/test_video_thumb.py
Normal file
66
music-lab/tests/test_video_thumb.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app.pipeline import video, thumb, storage
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_storage(monkeypatch, tmp_path):
    """Point video storage at tmp_path and seed dummy audio/cover input files."""
    monkeypatch.setattr(storage, "VIDEO_DATA_DIR", str(tmp_path))
    # Dummy input files the generators expect to find on disk.
    (tmp_path / "audio.mp3").write_bytes(b"\x00" * 100)
    pipeline_dir = tmp_path / "50"
    pipeline_dir.mkdir()
    (pipeline_dir / "cover.jpg").write_bytes(b"\x00" * 100)
    return tmp_path
||||
@patch("subprocess.run")
def test_generate_video_calls_ffmpeg(mock_run, tmp_storage):
    """A successful ffmpeg run yields the video URL via the non-fallback path."""
    mock_run.return_value = MagicMock(returncode=0, stderr="")

    result = video.generate(
        pipeline_id=50,
        audio_path=str(tmp_storage / "audio.mp3"),
        cover_path=str(tmp_storage / "50" / "cover.jpg"),
        genre="lo-fi",
        duration_sec=120,
        resolution="1920x1080",
        style="visualizer",
    )

    assert result["url"].endswith("/50/video.mp4")
    assert result["used_fallback"] is False

    # Inspect the command line handed to subprocess.run.
    cmd = mock_run.call_args[0][0]
    assert cmd[0] == "ffmpeg"
    assert "-i" in cmd
    assert "showwaves" in " ".join(cmd)
||||
@patch("subprocess.run")
def test_generate_video_failure_marks_failed(mock_run, tmp_storage):
    """A non-zero ffmpeg exit surfaces as VideoGenerationError."""
    mock_run.return_value = MagicMock(returncode=1, stderr="bad codec")
    with pytest.raises(video.VideoGenerationError):
        video.generate(
            pipeline_id=51,
            audio_path=str(tmp_storage / "audio.mp3"),
            cover_path=str(tmp_storage / "50" / "cover.jpg"),
            genre="lo-fi",
            duration_sec=120,
            resolution="1920x1080",
            style="visualizer",
        )
||||
@patch("subprocess.run")
def test_thumb_extracts_frame(mock_run, tmp_storage):
    """Thumbnail generation invokes ffmpeg and reports the expected URL."""
    mock_run.return_value = MagicMock(returncode=0, stderr="")
    src = tmp_storage / "60" / "video.mp4"
    src.parent.mkdir()
    src.write_bytes(b"\x00" * 100)

    result = thumb.generate(
        pipeline_id=60,
        video_path=str(src),
        track_title="Midnight Drive",
        overlay_text=False,
    )

    assert result["url"].endswith("/60/thumbnail.jpg")
    cmd = mock_run.call_args[0][0]
    assert cmd[0] == "ffmpeg"
||||
@patch("subprocess.run")
def test_thumb_failure_raises(mock_run, tmp_storage):
    """A failing ffmpeg invocation surfaces as ThumbGenerationError."""
    mock_run.return_value = MagicMock(returncode=1, stderr="bad input")
    src = tmp_storage / "61" / "video.mp4"
    src.parent.mkdir()
    src.write_bytes(b"\x00" * 100)
    with pytest.raises(thumb.ThumbGenerationError):
        thumb.generate(
            pipeline_id=61,
            video_path=str(src),
            track_title="X",
            overlay_text=False,
        )
||||
88
music-lab/tests/test_youtube_upload.py
Normal file
88
music-lab/tests/test_youtube_upload.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app.pipeline import youtube
|
||||
|
||||
|
||||
@pytest.fixture
def fresh_db(monkeypatch, tmp_path):
    """Initialize a throwaway database file under tmp_path and return the db module."""
    from app import db
    monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
    db.init_db()
    return db
||||
def _setup_token(db_module):
|
||||
db_module.upsert_oauth_token(
|
||||
channel_id="UC1", channel_title="t", avatar_url=None,
|
||||
refresh_token="r1", access_token="a1", expires_at="2099-01-01T00:00:00",
|
||||
)
|
||||
|
||||
|
||||
@patch("app.pipeline.youtube._build_youtube_client")
def test_upload_succeeds_after_resumable(mock_client, fresh_db, tmp_path, monkeypatch):
    """A resumable upload that finishes on the second chunk returns the video id."""
    monkeypatch.setenv("YOUTUBE_OAUTH_CLIENT_ID", "cid")
    monkeypatch.setenv("YOUTUBE_OAUTH_CLIENT_SECRET", "sec")
    _setup_token(fresh_db)

    insert_req = MagicMock()
    # First next_chunk call: still in progress (None response);
    # second call: finished, response carries the new video id.
    insert_req.next_chunk.side_effect = [(None, None), (None, {"id": "VID123"})]
    yt_stub = MagicMock()
    yt_stub.videos().insert.return_value = insert_req
    mock_client.return_value = yt_stub

    src = tmp_path / "v.mp4"
    src.write_bytes(b"\x00" * 100)
    result = youtube.upload_video(
        video_path=str(src),
        thumbnail_path=None,
        metadata={"title": "T", "description": "D", "tags": ["x"], "category_id": 10},
        privacy="private",
    )
    assert result["video_id"] == "VID123"
|
||||
def test_upload_no_token_raises(fresh_db, tmp_path):
    """With no stored OAuth token, uploading raises NotAuthenticatedError."""
    src = tmp_path / "v.mp4"
    src.write_bytes(b"\x00")
    with pytest.raises(youtube.NotAuthenticatedError):
        youtube.upload_video(
            video_path=str(src),
            thumbnail_path=None,
            metadata={"title": "T", "description": "D", "tags": [], "category_id": 10},
            privacy="private",
        )
||||
@patch("app.pipeline.youtube._build_youtube_client")
def test_upload_quota_exceeded_marks_quota(mock_client, fresh_db, tmp_path, monkeypatch):
    """A 403 quotaExceeded from the API surfaces as QuotaExceededError."""
    from googleapiclient.errors import HttpError

    monkeypatch.setenv("YOUTUBE_OAUTH_CLIENT_ID", "cid")
    monkeypatch.setenv("YOUTUBE_OAUTH_CLIENT_SECRET", "sec")
    _setup_token(fresh_db)

    quota_err = HttpError(
        MagicMock(status=403),
        b'{"error":{"errors":[{"reason":"quotaExceeded"}]}}',
    )
    insert_req = MagicMock()
    insert_req.next_chunk.side_effect = quota_err
    yt_stub = MagicMock()
    yt_stub.videos().insert.return_value = insert_req
    mock_client.return_value = yt_stub

    src = tmp_path / "v.mp4"
    src.write_bytes(b"\x00")
    with pytest.raises(youtube.QuotaExceededError):
        youtube.upload_video(
            video_path=str(src),
            thumbnail_path=None,
            metadata={"title": "T", "description": "D", "tags": [], "category_id": 10},
            privacy="private",
        )
||||
def test_get_status_returns_none_when_not_connected(fresh_db):
    """No stored token means no channel status."""
    assert youtube.get_status() is None
||||
def test_get_status_returns_channel_info(fresh_db):
    """With a stored token, get_status reports the channel's id and title."""
    _setup_token(fresh_db)
    status = youtube.get_status()
    assert status["channel_id"] == "UC1"
    assert status["channel_title"] == "t"
|
||||
Reference in New Issue
Block a user