import os import uuid from typing import List, Optional from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from .db import ( init_db, create_task, get_task, get_all_tracks, add_track, delete_track, get_track_file_path, get_track_by_task_id, ) from .local_provider import run_local_generation from .suno_provider import run_suno_generation, generate_lyrics, SUNO_API_KEY app = FastAPI() _cors_origins = os.getenv("CORS_ALLOW_ORIGINS", "http://localhost:3007,http://localhost:8080").split(",") app.add_middleware( CORSMiddleware, allow_origins=[o.strip() for o in _cors_origins], allow_credentials=False, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type"], ) MUSIC_DATA_DIR = "/app/data" @app.on_event("startup") def on_startup(): init_db() os.makedirs(MUSIC_DATA_DIR, exist_ok=True) @app.get("/health") def health(): return {"ok": True} @app.get("/api/music/providers") def get_providers(): """사용 가능한 음악 생성 프로바이더 목록 반환.""" providers = [] if os.getenv("MUSIC_AI_SERVER_URL"): providers.append({ "id": "local", "name": "MusicGen", "description": "로컬 AI 서버 (인스트루멘탈 전용)", "features": ["instrumental"], }) if SUNO_API_KEY: providers.append({ "id": "suno", "name": "Suno", "description": "Suno AI (보컬·가사·인스트루멘탈)", "features": ["vocals", "lyrics", "instrumental"], }) return {"providers": providers} # ── 음악 생성 API ───────────────────────────────────────────────────────────── class GenerateRequest(BaseModel): provider: str = "suno" # "suno" | "local" title: str = "" genre: str = "" moods: List[str] = [] instruments: List[str] = [] duration_sec: Optional[int] = None bpm: Optional[int] = None key: str = "" scale: str = "" prompt: str = "" # Suno 전용 lyrics: str = "" # 커스텀 가사 ([Verse], [Chorus] 등) instrumental: bool = False # True면 보컬 없이 인스트루멘탈만 @app.post("/api/music/generate") def generate_music(req: GenerateRequest, background_tasks: BackgroundTasks): """ 음악 생성 작업 시작. task_id 즉시 반환 후 백그라운드에서 AI 서버 호출. provider: "suno" (Suno API) 또는 "local" (MusicGen) """ provider = req.provider if provider == "suno" and not SUNO_API_KEY: raise HTTPException(status_code=400, detail="Suno API 키가 설정되지 않았습니다") if provider == "local" and not os.getenv("MUSIC_AI_SERVER_URL"): raise HTTPException(status_code=400, detail="로컬 AI 서버 URL이 설정되지 않았습니다") if provider not in ("suno", "local"): raise HTTPException(status_code=400, detail=f"지원하지 않는 provider: {provider}") task_id = str(uuid.uuid4()) params = req.model_dump() create_task(task_id, params, provider=provider) if provider == "suno": background_tasks.add_task(run_suno_generation, task_id, params) else: background_tasks.add_task(run_local_generation, task_id, params) return {"task_id": task_id, "provider": provider} @app.get("/api/music/status/{task_id}") def get_status(task_id: str): """ 생성 작업 상태 조회. 프론트는 succeeded 또는 failed가 될 때까지 폴링. status: queued | processing | succeeded | failed succeeded 시 track 메타데이터 포함. """ task = get_task(task_id) if not task: raise HTTPException(status_code=404, detail="Task not found") resp = { "status": task["status"], "progress": task["progress"], "message": task["message"], "audio_url": task["audio_url"], "error": task["error"], "provider": task["provider"], } if task["status"] == "succeeded": track = get_track_by_task_id(task_id) resp["track"] = track return resp # ── 가사 생성 API (Suno 전용) ──────────────────────────────────────────────── class LyricsRequest(BaseModel): prompt: str @app.post("/api/music/lyrics") def gen_lyrics(req: LyricsRequest): """Suno AI로 가사를 생성합니다. 곡 생성 전 가사 미리보기용.""" if not SUNO_API_KEY: raise HTTPException(status_code=400, detail="Suno API 키가 설정되지 않았습니다") result = generate_lyrics(req.prompt) if not result: raise HTTPException(status_code=502, detail="가사 생성에 실패했습니다") return result # ── 라이브러리 API ──────────────────────────────────────────────────────────── class TrackCreate(BaseModel): title: str = "" genre: str = "" moods: List[str] = [] instruments: List[str] = [] duration_sec: Optional[int] = None bpm: Optional[int] = None key: str = "" scale: str = "" prompt: str = "" audio_url: str = "" file_path: str = "" task_id: Optional[str] = None tags: List[str] = [] provider: str = "local" lyrics: str = "" image_url: str = "" suno_id: str = "" @app.get("/api/music/library") def list_library(): """저장된 트랙 목록 전체 조회 (생성일 내림차순). 파일시스템과 자동 동기화.""" _sync_library_with_disk() return {"tracks": get_all_tracks()} def _sync_library_with_disk(): """파일시스템의 .mp3 파일과 DB를 동기화. - 디스크에 없는 트랙 → DB에서 삭제 - DB에 없는 .mp3 파일 → 새 트랙으로 추가 """ tracks = get_all_tracks() media_base = os.getenv("MUSIC_MEDIA_BASE", "/media/music") # 디스크의 .mp3 파일 목록 disk_files = set() try: for f in os.listdir(MUSIC_DATA_DIR): if f.lower().endswith(".mp3"): disk_files.add(f) except OSError: return # 디렉토리 접근 불가 시 동기화 스킵 # DB 트랙의 파일명 매핑 db_filenames = {} # filename → track for t in tracks: if t.get("audio_url"): fname = t["audio_url"].split("/")[-1] db_filenames[fname] = t # DB에는 있지만 디스크에 없는 → 삭제 for fname, t in db_filenames.items(): if fname not in disk_files: delete_track(t["id"]) # 디스크에는 있지만 DB에 없는 → 추가 for f in disk_files: if f not in db_filenames: title = os.path.splitext(f)[0].replace("-", " ").replace("_", " ") add_track({ "title": title, "audio_url": f"{media_base}/{f}", "file_path": os.path.join(MUSIC_DATA_DIR, f), "provider": "suno", }) @app.post("/api/music/library", status_code=201) def save_to_library(req: TrackCreate): """트랙 수동 추가 (외부 파일 등록 또는 프론트 직접 저장용)""" track = add_track(req.model_dump()) return track @app.delete("/api/music/library/{track_id}") def remove_from_library(track_id: int): """라이브러리에서 트랙 삭제. 로컬 파일도 함께 삭제.""" file_path = get_track_file_path(track_id) ok = delete_track(track_id) if not ok: raise HTTPException(status_code=404, detail="Track not found") if file_path and os.path.isfile(file_path): try: os.remove(file_path) except OSError: pass return {"ok": True}