music-lab: Suno API + MusicGen 듀얼 프로바이더 구조 구현
- suno_provider.py: Suno REST API 클라이언트 (곡 생성, 가사, 2변형 저장) - local_provider.py: 기존 MusicGen 로직 분리 - main.py: provider 라우팅, /providers·/lyrics 엔드포인트 추가 - db.py: provider, lyrics, image_url, suno_id 컬럼 마이그레이션 - docker-compose.yml: SUNO_API_KEY 환경변수 추가 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,18 +1,17 @@
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import requests
|
||||
from typing import List, Optional
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .db import (
|
||||
init_db,
|
||||
create_task, update_task, get_task,
|
||||
create_task, get_task,
|
||||
get_all_tracks, add_track, delete_track, get_track_file_path, get_track_by_task_id,
|
||||
)
|
||||
from .local_provider import run_local_generation
|
||||
from .suno_provider import run_suno_generation, generate_lyrics, SUNO_API_KEY
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -25,9 +24,7 @@ app.add_middleware(
|
||||
allow_headers=["Content-Type"],
|
||||
)
|
||||
|
||||
MUSIC_AI_SERVER_URL = os.getenv("MUSIC_AI_SERVER_URL", "")
|
||||
MUSIC_DATA_DIR = "/app/data/music"
|
||||
MUSIC_MEDIA_BASE = os.getenv("MUSIC_MEDIA_BASE", "/media/music")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@@ -41,106 +38,31 @@ def health():
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── 음악 생성 워커 ────────────────────────────────────────────────────────────
|
||||
|
||||
def _run_generation(task_id: str, params: dict) -> None:
|
||||
"""BackgroundTask: AI 서버에 생성 요청 → 파일 저장 → 라이브러리 등록"""
|
||||
try:
|
||||
update_task(task_id, "processing", 10, "AI 서버에 연결 중...")
|
||||
|
||||
if not MUSIC_AI_SERVER_URL:
|
||||
update_task(task_id, "failed", 0, "", error="MUSIC_AI_SERVER_URL이 설정되지 않았습니다")
|
||||
return
|
||||
|
||||
update_task(task_id, "processing", 30, "음악 생성 중... (수 분 소요될 수 있습니다)")
|
||||
|
||||
# 1단계: 생성 요청 → ai_task_id 반환
|
||||
resp = requests.post(
|
||||
f"{MUSIC_AI_SERVER_URL}/generate",
|
||||
json=params,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
update_task(task_id, "failed", 0, "", error=f"AI 서버 오류: {resp.status_code} {resp.text[:200]}")
|
||||
return
|
||||
|
||||
ai_task_id = resp.json().get("task_id")
|
||||
if not ai_task_id:
|
||||
update_task(task_id, "failed", 0, "", error="AI 서버 응답에 task_id가 없습니다")
|
||||
return
|
||||
|
||||
# 2단계: 상태 폴링 (최대 10분, 5초 간격) — AI 서버 progress/message 그대로 반영
|
||||
remote_url = None
|
||||
for _ in range(120):
|
||||
time.sleep(5)
|
||||
status_resp = requests.get(f"{MUSIC_AI_SERVER_URL}/status/{ai_task_id}", timeout=10)
|
||||
status_data = status_resp.json()
|
||||
ai_status = status_data.get("status")
|
||||
|
||||
# AI 서버의 progress/message를 로컬 task에 전달 (30~79 범위로 스케일)
|
||||
ai_progress = status_data.get("progress", 0)
|
||||
ai_message = status_data.get("message", "음악 생성 중...")
|
||||
scaled = 30 + int(ai_progress * 0.49) # 30% ~ 79%
|
||||
update_task(task_id, "processing", scaled, ai_message)
|
||||
|
||||
if ai_status == "succeeded":
|
||||
remote_url = status_data.get("audio_url")
|
||||
break
|
||||
elif ai_status == "failed":
|
||||
update_task(task_id, "failed", 0, "", error=status_data.get("error", "AI 서버 생성 실패"))
|
||||
return
|
||||
|
||||
if not remote_url:
|
||||
update_task(task_id, "failed", 0, "", error="AI 서버 타임아웃 (10분 초과)")
|
||||
return
|
||||
|
||||
update_task(task_id, "processing", 80, "파일 저장 중...")
|
||||
|
||||
filename = f"{task_id}.mp3"
|
||||
file_path = os.path.join(MUSIC_DATA_DIR, filename)
|
||||
|
||||
# 3단계: 오디오 파일 다운로드
|
||||
dl = requests.get(remote_url, timeout=120, stream=True)
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in dl.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
# audio_url은 항상 Nginx 상대경로 (Mixed Content 방지)
|
||||
audio_url = f"/media/music/{filename}"
|
||||
|
||||
# 라이브러리 자동 등록 — payload title 우선, 없으면 자동 생성
|
||||
genre = params.get("genre", "")
|
||||
moods = params.get("moods", [])
|
||||
mood_str = moods[0] if moods else "Original"
|
||||
title = params.get("title") or (f"{genre} — {mood_str} Mix" if genre else f"{mood_str} Mix")
|
||||
|
||||
add_track({
|
||||
"title": title,
|
||||
"genre": genre,
|
||||
"moods": params.get("moods", []),
|
||||
"instruments": params.get("instruments", []),
|
||||
"duration_sec": params.get("duration_sec"),
|
||||
"bpm": params.get("bpm"),
|
||||
"key": params.get("key", ""),
|
||||
"scale": params.get("scale", ""),
|
||||
"prompt": params.get("prompt", ""),
|
||||
"audio_url": audio_url,
|
||||
"file_path": file_path,
|
||||
"task_id": task_id,
|
||||
@app.get("/api/music/providers")
|
||||
def get_providers():
|
||||
"""사용 가능한 음악 생성 프로바이더 목록 반환."""
|
||||
providers = []
|
||||
if os.getenv("MUSIC_AI_SERVER_URL"):
|
||||
providers.append({
|
||||
"id": "local",
|
||||
"name": "MusicGen",
|
||||
"description": "로컬 AI 서버 (인스트루멘탈 전용)",
|
||||
"features": ["instrumental"],
|
||||
})
|
||||
|
||||
update_task(task_id, "succeeded", 100, "생성 완료", audio_url=audio_url)
|
||||
|
||||
except requests.Timeout:
|
||||
update_task(task_id, "failed", 0, "", error="AI 서버 타임아웃 (10분 초과)")
|
||||
except Exception as e:
|
||||
update_task(task_id, "failed", 0, "", error=str(e))
|
||||
if SUNO_API_KEY:
|
||||
providers.append({
|
||||
"id": "suno",
|
||||
"name": "Suno",
|
||||
"description": "Suno AI (보컬·가사·인스트루멘탈)",
|
||||
"features": ["vocals", "lyrics", "instrumental"],
|
||||
})
|
||||
return {"providers": providers}
|
||||
|
||||
|
||||
# ── 음악 생성 API ─────────────────────────────────────────────────────────────
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
provider: str = "suno" # "suno" | "local"
|
||||
title: str = ""
|
||||
genre: str = ""
|
||||
moods: List[str] = []
|
||||
@@ -150,19 +72,35 @@ class GenerateRequest(BaseModel):
|
||||
key: str = ""
|
||||
scale: str = ""
|
||||
prompt: str = ""
|
||||
# Suno 전용
|
||||
lyrics: str = "" # 커스텀 가사 ([Verse], [Chorus] 등)
|
||||
instrumental: bool = False # True면 보컬 없이 인스트루멘탈만
|
||||
|
||||
|
||||
@app.post("/api/music/generate")
|
||||
def generate_music(req: GenerateRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
음악 생성 작업 시작. task_id 즉시 반환 후 백그라운드에서 AI 서버 호출.
|
||||
생성 완료 시 music_library에 자동 등록됨.
|
||||
provider: "suno" (Suno API) 또는 "local" (MusicGen)
|
||||
"""
|
||||
provider = req.provider
|
||||
if provider == "suno" and not SUNO_API_KEY:
|
||||
raise HTTPException(status_code=400, detail="Suno API 키가 설정되지 않았습니다")
|
||||
if provider == "local" and not os.getenv("MUSIC_AI_SERVER_URL"):
|
||||
raise HTTPException(status_code=400, detail="로컬 AI 서버 URL이 설정되지 않았습니다")
|
||||
if provider not in ("suno", "local"):
|
||||
raise HTTPException(status_code=400, detail=f"지원하지 않는 provider: {provider}")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
params = req.model_dump()
|
||||
create_task(task_id, params)
|
||||
background_tasks.add_task(_run_generation, task_id, params)
|
||||
return {"task_id": task_id}
|
||||
create_task(task_id, params, provider=provider)
|
||||
|
||||
if provider == "suno":
|
||||
background_tasks.add_task(run_suno_generation, task_id, params)
|
||||
else:
|
||||
background_tasks.add_task(run_local_generation, task_id, params)
|
||||
|
||||
return {"task_id": task_id, "provider": provider}
|
||||
|
||||
|
||||
@app.get("/api/music/status/{task_id}")
|
||||
@@ -170,7 +108,7 @@ def get_status(task_id: str):
|
||||
"""
|
||||
생성 작업 상태 조회. 프론트는 succeeded 또는 failed가 될 때까지 폴링.
|
||||
status: queued | processing | succeeded | failed
|
||||
succeeded 시 track 메타데이터 포함 (라이브러리 별도 저장 불필요).
|
||||
succeeded 시 track 메타데이터 포함.
|
||||
"""
|
||||
task = get_task(task_id)
|
||||
if not task:
|
||||
@@ -182,10 +120,9 @@ def get_status(task_id: str):
|
||||
"message": task["message"],
|
||||
"audio_url": task["audio_url"],
|
||||
"error": task["error"],
|
||||
"provider": task["provider"],
|
||||
}
|
||||
|
||||
# succeeded 시 라이브러리에 저장된 트랙 메타데이터 포함
|
||||
# 프론트는 이 track 객체로 UI를 바로 업데이트하면 됨 (Save 버튼 불필요)
|
||||
if task["status"] == "succeeded":
|
||||
track = get_track_by_task_id(task_id)
|
||||
resp["track"] = track
|
||||
@@ -193,6 +130,23 @@ def get_status(task_id: str):
|
||||
return resp
|
||||
|
||||
|
||||
# ── 가사 생성 API (Suno 전용) ────────────────────────────────────────────────
|
||||
|
||||
class LyricsRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
@app.post("/api/music/lyrics")
|
||||
def gen_lyrics(req: LyricsRequest):
|
||||
"""Suno AI로 가사를 생성합니다. 곡 생성 전 가사 미리보기용."""
|
||||
if not SUNO_API_KEY:
|
||||
raise HTTPException(status_code=400, detail="Suno API 키가 설정되지 않았습니다")
|
||||
result = generate_lyrics(req.prompt)
|
||||
if not result:
|
||||
raise HTTPException(status_code=502, detail="가사 생성에 실패했습니다")
|
||||
return result
|
||||
|
||||
|
||||
# ── 라이브러리 API ────────────────────────────────────────────────────────────
|
||||
|
||||
class TrackCreate(BaseModel):
|
||||
@@ -209,6 +163,10 @@ class TrackCreate(BaseModel):
|
||||
file_path: str = ""
|
||||
task_id: Optional[str] = None
|
||||
tags: List[str] = []
|
||||
provider: str = "local"
|
||||
lyrics: str = ""
|
||||
image_url: str = ""
|
||||
suno_id: str = ""
|
||||
|
||||
|
||||
@app.get("/api/music/library")
|
||||
@@ -226,19 +184,16 @@ def save_to_library(req: TrackCreate):
|
||||
|
||||
@app.delete("/api/music/library/{track_id}")
|
||||
def remove_from_library(track_id: int):
|
||||
"""
|
||||
라이브러리에서 트랙 삭제. 로컬 파일도 함께 삭제.
|
||||
"""
|
||||
"""라이브러리에서 트랙 삭제. 로컬 파일도 함께 삭제."""
|
||||
file_path = get_track_file_path(track_id)
|
||||
ok = delete_track(track_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Track not found")
|
||||
|
||||
# 생성된 파일이 있으면 함께 삭제
|
||||
if file_path and os.path.isfile(file_path):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except OSError:
|
||||
pass # 파일 삭제 실패해도 DB에서는 이미 삭제됨
|
||||
pass
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
Reference in New Issue
Block a user