import os import uuid from typing import List, Optional from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from .db import ( init_db, create_task, get_task, get_all_tracks, add_track, delete_track, get_track_file_path, get_track_by_task_id, ) from .local_provider import run_local_generation from .suno_provider import run_suno_generation, generate_lyrics, SUNO_API_KEY app = FastAPI() _cors_origins = os.getenv("CORS_ALLOW_ORIGINS", "http://localhost:3007,http://localhost:8080").split(",") app.add_middleware( CORSMiddleware, allow_origins=[o.strip() for o in _cors_origins], allow_credentials=False, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type"], ) MUSIC_DATA_DIR = "/app/data/music" @app.on_event("startup") def on_startup(): init_db() os.makedirs(MUSIC_DATA_DIR, exist_ok=True) @app.get("/health") def health(): return {"ok": True} @app.get("/api/music/providers") def get_providers(): """사용 가능한 음악 생성 프로바이더 목록 반환.""" providers = [] if os.getenv("MUSIC_AI_SERVER_URL"): providers.append({ "id": "local", "name": "MusicGen", "description": "로컬 AI 서버 (인스트루멘탈 전용)", "features": ["instrumental"], }) if SUNO_API_KEY: providers.append({ "id": "suno", "name": "Suno", "description": "Suno AI (보컬·가사·인스트루멘탈)", "features": ["vocals", "lyrics", "instrumental"], }) return {"providers": providers} # ── 음악 생성 API ───────────────────────────────────────────────────────────── class GenerateRequest(BaseModel): provider: str = "suno" # "suno" | "local" title: str = "" genre: str = "" moods: List[str] = [] instruments: List[str] = [] duration_sec: Optional[int] = None bpm: Optional[int] = None key: str = "" scale: str = "" prompt: str = "" # Suno 전용 lyrics: str = "" # 커스텀 가사 ([Verse], [Chorus] 등) instrumental: bool = False # True면 보컬 없이 인스트루멘탈만 @app.post("/api/music/generate") def generate_music(req: GenerateRequest, background_tasks: BackgroundTasks): """ 음악 생성 작업 시작. task_id 즉시 반환 후 백그라운드에서 AI 서버 호출. provider: "suno" (Suno API) 또는 "local" (MusicGen) """ provider = req.provider if provider == "suno" and not SUNO_API_KEY: raise HTTPException(status_code=400, detail="Suno API 키가 설정되지 않았습니다") if provider == "local" and not os.getenv("MUSIC_AI_SERVER_URL"): raise HTTPException(status_code=400, detail="로컬 AI 서버 URL이 설정되지 않았습니다") if provider not in ("suno", "local"): raise HTTPException(status_code=400, detail=f"지원하지 않는 provider: {provider}") task_id = str(uuid.uuid4()) params = req.model_dump() create_task(task_id, params, provider=provider) if provider == "suno": background_tasks.add_task(run_suno_generation, task_id, params) else: background_tasks.add_task(run_local_generation, task_id, params) return {"task_id": task_id, "provider": provider} @app.get("/api/music/status/{task_id}") def get_status(task_id: str): """ 생성 작업 상태 조회. 프론트는 succeeded 또는 failed가 될 때까지 폴링. status: queued | processing | succeeded | failed succeeded 시 track 메타데이터 포함. """ task = get_task(task_id) if not task: raise HTTPException(status_code=404, detail="Task not found") resp = { "status": task["status"], "progress": task["progress"], "message": task["message"], "audio_url": task["audio_url"], "error": task["error"], "provider": task["provider"], } if task["status"] == "succeeded": track = get_track_by_task_id(task_id) resp["track"] = track return resp # ── 가사 생성 API (Suno 전용) ──────────────────────────────────────────────── class LyricsRequest(BaseModel): prompt: str @app.post("/api/music/lyrics") def gen_lyrics(req: LyricsRequest): """Suno AI로 가사를 생성합니다. 곡 생성 전 가사 미리보기용.""" if not SUNO_API_KEY: raise HTTPException(status_code=400, detail="Suno API 키가 설정되지 않았습니다") result = generate_lyrics(req.prompt) if not result: raise HTTPException(status_code=502, detail="가사 생성에 실패했습니다") return result # ── 라이브러리 API ──────────────────────────────────────────────────────────── class TrackCreate(BaseModel): title: str = "" genre: str = "" moods: List[str] = [] instruments: List[str] = [] duration_sec: Optional[int] = None bpm: Optional[int] = None key: str = "" scale: str = "" prompt: str = "" audio_url: str = "" file_path: str = "" task_id: Optional[str] = None tags: List[str] = [] provider: str = "local" lyrics: str = "" image_url: str = "" suno_id: str = "" @app.get("/api/music/library") def list_library(): """저장된 트랙 목록 전체 조회 (생성일 내림차순)""" return {"tracks": get_all_tracks()} @app.post("/api/music/library", status_code=201) def save_to_library(req: TrackCreate): """트랙 수동 추가 (외부 파일 등록 또는 프론트 직접 저장용)""" track = add_track(req.model_dump()) return track @app.delete("/api/music/library/{track_id}") def remove_from_library(track_id: int): """라이브러리에서 트랙 삭제. 로컬 파일도 함께 삭제.""" file_path = get_track_file_path(track_id) ok = delete_track(track_id) if not ok: raise HTTPException(status_code=404, detail="Track not found") if file_path and os.path.isfile(file_path): try: os.remove(file_path) except OSError: pass return {"ok": True}