Compare commits
34 Commits
c8793cc3cf
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| cb70226f42 | |||
| de24bae984 | |||
| 0e6c893b4e | |||
| fb80973e38 | |||
| 31b0e7dbc4 | |||
| 6169f48eb8 | |||
| 27a6df6cff | |||
| 803fdb6278 | |||
| 77e21b54e6 | |||
| 4d0c89ce79 | |||
| 4b60ab34c3 | |||
| 53a0657027 | |||
| 91f01d126b | |||
| 0702cf052f | |||
| 8aa3f1c3b2 | |||
| 4db0551d33 | |||
| 4d837fdd31 | |||
| 2567a6f10b | |||
| 17ed1943f1 | |||
| 8d246b5b32 | |||
| b4bec9d51b | |||
| f32792e4a9 | |||
| f152545d3b | |||
| bf3d6ee694 | |||
| 44bc065796 | |||
| 9127616669 | |||
| 900f45c2ff | |||
| eb34cbc0f7 | |||
| 0de09613d2 | |||
| a5274a4fa7 | |||
| 4e72f8ca2e | |||
| 44c6811352 | |||
| 9eef2c5015 | |||
| b05e5714e3 |
@@ -23,3 +23,102 @@ services:
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
music-render:
|
||||
build:
|
||||
context: ./music-render
|
||||
container_name: music-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18711:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- NAS_BASE_URL=${NAS_BASE_URL:-http://192.168.45.54:18600}
|
||||
- INTERNAL_API_KEY=${INTERNAL_API_KEY:-}
|
||||
- SUNO_API_KEY=${SUNO_API_KEY:-}
|
||||
- MUSIC_AI_SERVER_URL=${MUSIC_AI_SERVER_URL:-http://host.docker.internal:8765}
|
||||
- MUSIC_MEDIA_ROOT=${MUSIC_MEDIA_ROOT:-/mnt/nas/webpage/data/music}
|
||||
- MUSIC_MEDIA_URL_PREFIX=${MUSIC_MEDIA_URL_PREFIX:-/media/music}
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- /mnt/nas/webpage/data/music:/mnt/nas/webpage/data/music
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
video-render:
|
||||
build:
|
||||
context: ./video-render
|
||||
container_name: video-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18712:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- NAS_BASE_URL=${NAS_BASE_URL:-http://192.168.45.54:18801}
|
||||
- INTERNAL_API_KEY=${INTERNAL_API_KEY:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
|
||||
- KLING_ACCESS_KEY=${KLING_ACCESS_KEY:-}
|
||||
- KLING_SECRET_KEY=${KLING_SECRET_KEY:-}
|
||||
- SEEDANCE_API_KEY=${SEEDANCE_API_KEY:-}
|
||||
- VIDEO_MEDIA_ROOT=${VIDEO_MEDIA_ROOT:-/mnt/nas/webpage/data/video}
|
||||
- VIDEO_MEDIA_URL_PREFIX=${VIDEO_MEDIA_URL_PREFIX:-/media/video}
|
||||
volumes:
|
||||
- /mnt/nas/webpage/data/video:/mnt/nas/webpage/data/video
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
task-watcher:
|
||||
build:
|
||||
context: ./task-watcher
|
||||
container_name: task-watcher
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18713:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- STOCK_BASE_URL=${STOCK_BASE_URL:-http://192.168.45.54:18500}
|
||||
- TRADING_START=${TRADING_START:-07:00}
|
||||
- TRADING_END=${TRADING_END:-16:30}
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
image-render:
|
||||
build:
|
||||
context: ./image-render
|
||||
container_name: image-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18714:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- NAS_BASE_URL=${NAS_BASE_URL:-http://192.168.45.54:18802}
|
||||
- INTERNAL_API_KEY=${INTERNAL_API_KEY:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
|
||||
- COMFYUI_URL=${COMFYUI_URL:-http://host.docker.internal:8188}
|
||||
- FLUX_BLOCK_TRADING_HOURS=${FLUX_BLOCK_TRADING_HOURS:-1}
|
||||
- IMAGE_MEDIA_ROOT=${IMAGE_MEDIA_ROOT:-/mnt/nas/webpage/data/image}
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- /mnt/nas/webpage/data/image:/mnt/nas/webpage/data/image
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
16
services/image-render/Dockerfile
Normal file
16
services/image-render/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.12-slim-bookworm
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
18
services/image-render/env.example
Normal file
18
services/image-render/env.example
Normal file
@@ -0,0 +1,18 @@
|
||||
# Redis (NAS)
|
||||
REDIS_URL=redis://192.168.45.54:6379
|
||||
|
||||
# NAS image-lab webhook
|
||||
NAS_BASE_URL=http://192.168.45.54:18802
|
||||
INTERNAL_API_KEY=replace-me
|
||||
|
||||
# API provider keys (worker reports failed if missing)
|
||||
OPENAI_API_KEY=
|
||||
GEMINI_API_KEY=
|
||||
# Seedance key not used by image-render
|
||||
|
||||
# FLUX local
|
||||
COMFYUI_URL=http://host.docker.internal:8188
|
||||
FLUX_BLOCK_TRADING_HOURS=1
|
||||
|
||||
# NAS SMB mount target (image-render writes to this, NAS reads via /media/image/)
|
||||
IMAGE_MEDIA_ROOT=/mnt/nas/webpage/data/image
|
||||
36
services/image-render/main.py
Normal file
36
services/image-render/main.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""image-render FastAPI entry — health + lifespan (worker loop spawn)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
import worker
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
worker_task = asyncio.create_task(worker.worker_loop())
|
||||
logger.info("image-render lifespan 시작")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
worker_task.cancel()
|
||||
try:
|
||||
await worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("image-render lifespan 종료")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"ok": True, "service": "image-render"}
|
||||
54
services/image-render/nas_client.py
Normal file
54
services/image-render/nas_client.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""NAS webhook 어댑터 — Windows worker → NAS image-lab HTTP 위임.
|
||||
|
||||
video-render nas_client 복제 (call-time os.getenv으로 테스트 격리).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TIMEOUT = 10.0
|
||||
|
||||
|
||||
def _post(payload: Dict[str, Any]) -> None:
|
||||
nas_base_url = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18802")
|
||||
internal_api_key = os.getenv("INTERNAL_API_KEY", "")
|
||||
url = f"{nas_base_url}/api/internal/image/update"
|
||||
try:
|
||||
r = httpx.post(
|
||||
url,
|
||||
headers={"X-Internal-Key": internal_api_key},
|
||||
json=payload,
|
||||
timeout=_TIMEOUT,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
logger.error("webhook %s returned %d: %s",
|
||||
payload.get("task_id"), r.status_code, r.text[:200])
|
||||
except Exception:
|
||||
logger.exception("webhook %s 호출 실패", payload.get("task_id"))
|
||||
|
||||
|
||||
def webhook_update_task(
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: int,
|
||||
message: str = "",
|
||||
image_url: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
payload: Dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
}
|
||||
if image_url is not None:
|
||||
payload["image_url"] = image_url
|
||||
if error is not None:
|
||||
payload["error"] = error
|
||||
_post(payload)
|
||||
0
services/image-render/providers/__init__.py
Normal file
0
services/image-render/providers/__init__.py
Normal file
18
services/image-render/providers/_media.py
Normal file
18
services/image-render/providers/_media.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""b64 이미지 → NAS SMB 경로 저장 → /media/image URL 반환."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
|
||||
IMAGE_MEDIA_ROOT = os.getenv("IMAGE_MEDIA_ROOT", "/mnt/nas/webpage/data/image")
|
||||
IMAGE_MEDIA_URL_PREFIX = os.getenv("IMAGE_MEDIA_URL_PREFIX", "/media/image")
|
||||
|
||||
|
||||
def save_b64_png(task_id: str, b64_data: str) -> str:
|
||||
os.makedirs(IMAGE_MEDIA_ROOT, exist_ok=True)
|
||||
fname = f"{task_id}-{uuid.uuid4().hex[:8]}.png"
|
||||
path = os.path.join(IMAGE_MEDIA_ROOT, fname)
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(b64_data))
|
||||
return f"{IMAGE_MEDIA_URL_PREFIX}/{fname}"
|
||||
79
services/image-render/providers/flux.py
Normal file
79
services/image-render/providers/flux.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""FLUX 로컬 — ComfyUI HTTP API.
|
||||
|
||||
POST {COMFYUI_URL}/prompt (workflow JSON) → prompt_id
|
||||
GET {COMFYUI_URL}/history/{prompt_id} → outputs → image filename
|
||||
GET {COMFYUI_URL}/view?filename=... → PNG bytes → b64
|
||||
|
||||
워크플로우 JSON은 `flux_workflow.json` (ComfyUI UI에서 "Save (API Format)"로 export, CLIPTextEncode 노드 text를 "%PROMPT%"로 수동 치환). 박재오 산출물.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64, json, logging, os, time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:8188")
|
||||
WORKFLOW_PATH = os.path.join(os.path.dirname(__file__), "flux_workflow.json")
|
||||
POLL_INTERVAL = 2
|
||||
POLL_MAX = 120
|
||||
|
||||
|
||||
def _is_trading_hours() -> bool:
|
||||
kst = timezone(timedelta(hours=9))
|
||||
now = datetime.now(kst)
|
||||
if now.weekday() >= 5:
|
||||
return False
|
||||
return (now.hour, now.minute) >= (9, 0) and (now.hour, now.minute) <= (15, 30)
|
||||
|
||||
|
||||
def _load_workflow(prompt: str, size: str) -> dict:
|
||||
with open(WORKFLOW_PATH, encoding="utf-8") as f:
|
||||
wf = json.load(f)
|
||||
# CLIPTextEncode 노드의 text를 prompt로 치환 (workflow에 "%PROMPT%" placeholder 사용)
|
||||
raw = json.dumps(wf).replace("%PROMPT%", prompt.replace('"', "'"))
|
||||
return json.loads(raw)
|
||||
|
||||
|
||||
def _submit_prompt(workflow: dict) -> str:
|
||||
r = requests.post(f"{COMFYUI_URL}/prompt", json={"prompt": workflow}, timeout=30)
|
||||
r.raise_for_status()
|
||||
return r.json()["prompt_id"]
|
||||
|
||||
|
||||
def _poll_image_b64(prompt_id: str):
|
||||
for _ in range(POLL_MAX):
|
||||
h = requests.get(f"{COMFYUI_URL}/history/{prompt_id}", timeout=10)
|
||||
data = h.json().get(prompt_id)
|
||||
if data and data.get("outputs"):
|
||||
for node_out in data["outputs"].values():
|
||||
for img in node_out.get("images", []):
|
||||
view = requests.get(f"{COMFYUI_URL}/view",
|
||||
params={"filename": img["filename"], "subfolder": img.get("subfolder", ""), "type": img.get("type", "output")},
|
||||
timeout=30)
|
||||
view.raise_for_status()
|
||||
return base64.b64encode(view.content).decode()
|
||||
time.sleep(POLL_INTERVAL)
|
||||
return None
|
||||
|
||||
|
||||
def run_flux_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if os.getenv("FLUX_BLOCK_TRADING_HOURS") == "1" and _is_trading_hours():
|
||||
webhook_update_task(task_id, "failed", 0, "", error="장중 GPU 보호 — FLUX 거부 (API provider 사용 권장)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "FLUX (ComfyUI) 생성 중...")
|
||||
wf = _load_workflow(params["prompt"], params.get("size") or "1024x1024")
|
||||
pid = _submit_prompt(wf)
|
||||
b64 = _poll_image_b64(pid)
|
||||
if not b64:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="ComfyUI 타임아웃 또는 출력 없음")
|
||||
return
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("flux task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
47
services/image-render/providers/gpt_image.py
Normal file
47
services/image-render/providers/gpt_image.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""GPT Image 2.0 — OpenAI Images API.
|
||||
|
||||
POST https://api.openai.com/v1/images/generations
|
||||
body {model:"gpt-image-1", prompt, size, n:1} → data[0].b64_json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
OPENAI_URL = "https://api.openai.com/v1/images/generations"
|
||||
DEFAULT_MODEL = "gpt-image-1"
|
||||
|
||||
|
||||
def run_gpt_image_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="OPENAI_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "GPT Image 호출 중...")
|
||||
body = {
|
||||
"model": params.get("model") or DEFAULT_MODEL,
|
||||
"prompt": params["prompt"],
|
||||
"size": params.get("size") or "1024x1024",
|
||||
"n": 1,
|
||||
}
|
||||
resp = requests.post(
|
||||
OPENAI_URL,
|
||||
headers={"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}", "Content-Type": "application/json"},
|
||||
json=body,
|
||||
timeout=120,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"OpenAI {resp.status_code}: {resp.text[:200]}")
|
||||
return
|
||||
b64 = resp.json()["data"][0]["b64_json"]
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("gpt_image task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
52
services/image-render/providers/nano_banana.py
Normal file
52
services/image-render/providers/nano_banana.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Nano Banana — Gemini 2.5 Flash Image (generativelanguage API).
|
||||
|
||||
POST /v1beta/models/{MODEL}:generateContent
|
||||
→ candidates[0].content.parts[*].inlineData.data (b64 png)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging, os
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta"
|
||||
DEFAULT_MODEL = "gemini-2.5-flash-image"
|
||||
|
||||
|
||||
def _extract_b64(data: dict):
|
||||
for cand in data.get("candidates", []):
|
||||
for part in cand.get("content", {}).get("parts", []):
|
||||
inline = part.get("inlineData") or part.get("inline_data")
|
||||
if inline and inline.get("data"):
|
||||
return inline["data"]
|
||||
return None
|
||||
|
||||
|
||||
def run_nano_banana_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not os.getenv("GEMINI_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="GEMINI_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "Nano Banana (Gemini) 호출 중...")
|
||||
model_id = params.get("model") or DEFAULT_MODEL
|
||||
body = {"contents": [{"parts": [{"text": params["prompt"]}]}]}
|
||||
resp = requests.post(
|
||||
f"{GEMINI_BASE}/models/{model_id}:generateContent",
|
||||
headers={"x-goog-api-key": os.getenv("GEMINI_API_KEY"), "Content-Type": "application/json"},
|
||||
json=body, timeout=120,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Gemini {resp.status_code}: {resp.text[:200]}")
|
||||
return
|
||||
b64 = _extract_b64(resp.json())
|
||||
if not b64:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Gemini 응답에 이미지 없음")
|
||||
return
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("nano_banana task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
9
services/image-render/requirements.txt
Normal file
9
services/image-render/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
fastapi==0.115.6
|
||||
uvicorn[standard]==0.34.0
|
||||
requests==2.32.3
|
||||
redis>=5.0
|
||||
httpx>=0.27
|
||||
openai>=1.50.0
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.24
|
||||
respx>=0.21
|
||||
0
services/image-render/tests/__init__.py
Normal file
0
services/image-render/tests/__init__.py
Normal file
21
services/image-render/tests/test_flux.py
Normal file
21
services/image-render/tests/test_flux.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import providers.flux as fx
|
||||
|
||||
def test_blocked_during_trading_hours(monkeypatch):
|
||||
monkeypatch.setenv("FLUX_BLOCK_TRADING_HOURS", "1")
|
||||
monkeypatch.setattr(fx, "_is_trading_hours", lambda: True)
|
||||
calls = []
|
||||
monkeypatch.setattr(fx, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
fx.run_flux_generation("t1", {"prompt": "a cat"})
|
||||
assert calls[-1][0][1] == "failed"
|
||||
assert "장중" in calls[-1][1]["error"]
|
||||
|
||||
def test_success_polls_history_and_saves(monkeypatch):
|
||||
monkeypatch.setattr(fx, "_is_trading_hours", lambda: False)
|
||||
calls = []
|
||||
monkeypatch.setattr(fx, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
monkeypatch.setattr(fx, "_load_workflow", lambda prompt, size: {"3": {}})
|
||||
monkeypatch.setattr(fx, "_submit_prompt", lambda wf: "pid-1")
|
||||
monkeypatch.setattr(fx, "_poll_image_b64", lambda pid: "ZmFrZQ==")
|
||||
monkeypatch.setattr(fx, "save_b64_png", lambda tid, b64: "/media/image/t1.png")
|
||||
fx.run_flux_generation("t1", {"prompt": "a cat"})
|
||||
assert [c for c in calls if c[0][1] == "succeeded"]
|
||||
32
services/image-render/tests/test_gpt_image.py
Normal file
32
services/image-render/tests/test_gpt_image.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import providers.gpt_image as gi
|
||||
|
||||
|
||||
def test_missing_key_reports_failed(monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
calls = []
|
||||
monkeypatch.setattr(gi, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
gi.run_gpt_image_generation("t1", {"prompt": "a cat"})
|
||||
# 마지막 호출이 failed
|
||||
assert calls[-1][0][1] == "failed"
|
||||
|
||||
|
||||
def test_success_saves_and_reports_url(monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||
calls = []
|
||||
monkeypatch.setattr(gi, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
monkeypatch.setattr(gi, "save_b64_png", lambda tid, b64: "/media/image/t1.png")
|
||||
|
||||
class FakeResp:
|
||||
status_code = 200
|
||||
|
||||
def json(self):
|
||||
return {"data": [{"b64_json": "ZmFrZQ=="}]}
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(gi.requests, "post", lambda *a, **k: FakeResp())
|
||||
|
||||
gi.run_gpt_image_generation("t1", {"prompt": "a cat"})
|
||||
succeeded = [c for c in calls if c[0][1] == "succeeded"]
|
||||
assert succeeded and succeeded[-1][1]["image_url"] == "/media/image/t1.png"
|
||||
25
services/image-render/tests/test_nano_banana.py
Normal file
25
services/image-render/tests/test_nano_banana.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import providers.nano_banana as nb
|
||||
|
||||
def test_missing_key_reports_failed(monkeypatch):
|
||||
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
|
||||
calls = []
|
||||
monkeypatch.setattr(nb, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
nb.run_nano_banana_generation("t1", {"prompt": "a cat"})
|
||||
assert calls[-1][0][1] == "failed"
|
||||
|
||||
def test_success_extracts_inline_data(monkeypatch):
|
||||
monkeypatch.setenv("GEMINI_API_KEY", "g-test")
|
||||
calls = []
|
||||
monkeypatch.setattr(nb, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
monkeypatch.setattr(nb, "save_b64_png", lambda tid, b64: "/media/image/t1.png")
|
||||
|
||||
class FakeResp:
|
||||
status_code = 200
|
||||
def json(self):
|
||||
return {"candidates": [{"content": {"parts": [
|
||||
{"inlineData": {"mimeType": "image/png", "data": "ZmFrZQ=="}}
|
||||
]}}]}
|
||||
monkeypatch.setattr(nb.requests, "post", lambda *a, **k: FakeResp())
|
||||
|
||||
nb.run_nano_banana_generation("t1", {"prompt": "a cat"})
|
||||
assert [c for c in calls if c[0][1] == "succeeded"]
|
||||
20
services/image-render/tests/test_nas_client.py
Normal file
20
services/image-render/tests/test_nas_client.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import nas_client
|
||||
|
||||
|
||||
def test_webhook_includes_image_url(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_post(payload):
|
||||
captured.update(payload)
|
||||
|
||||
monkeypatch.setattr(nas_client, "_post", fake_post)
|
||||
nas_client.webhook_update_task("t1", "succeeded", 100, "done", image_url="/media/image/t1.png")
|
||||
assert captured["task_id"] == "t1"
|
||||
assert captured["image_url"] == "/media/image/t1.png"
|
||||
|
||||
|
||||
def test_webhook_omits_none_fields(monkeypatch):
|
||||
captured = {}
|
||||
monkeypatch.setattr(nas_client, "_post", lambda p: captured.update(p))
|
||||
nas_client.webhook_update_task("t2", "processing", 10, "working")
|
||||
assert "image_url" not in captured and "error" not in captured
|
||||
15
services/image-render/tests/test_worker.py
Normal file
15
services/image-render/tests/test_worker.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import worker
|
||||
|
||||
|
||||
def test_dispatch_routes_to_provider(monkeypatch):
|
||||
called = {}
|
||||
monkeypatch.setattr(worker, "run_gpt_image_generation", lambda tid, p: called.setdefault("gpt", (tid, p)))
|
||||
worker._dispatch({"job_type": "gpt_image_generation", "task_id": "t1", "params": {"prompt": "x"}})
|
||||
assert called["gpt"][0] == "t1"
|
||||
|
||||
|
||||
def test_dispatch_unknown_job_type_reports_failed(monkeypatch):
|
||||
calls = []
|
||||
monkeypatch.setattr(worker, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
worker._dispatch({"job_type": "midjourney_generation", "task_id": "t9", "params": {}})
|
||||
assert calls[-1][0][1] == "failed"
|
||||
84
services/image-render/worker.py
Normal file
84
services/image-render/worker.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Redis BLPOP worker — queue:image-render → job_type dispatch → NAS webhook.
|
||||
|
||||
queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
|
||||
video-render worker.py 패턴 — string-based dispatch + getattr (테스트 patch 호환).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers.gpt_image import run_gpt_image_generation
|
||||
from providers.nano_banana import run_nano_banana_generation
|
||||
from providers.flux import run_flux_generation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
QUEUE_KEY = "queue:image-render"
|
||||
PAUSED_KEY = "queue:paused"
|
||||
|
||||
# string names so `unittest.mock.patch` / `monkeypatch.setattr` on `worker.<name>`
|
||||
# is correctly intercepted by getattr(sys.modules[__name__], ...)
|
||||
_DISPATCH_TABLE = {
|
||||
"gpt_image_generation": "run_gpt_image_generation",
|
||||
"nano_banana_generation": "run_nano_banana_generation",
|
||||
"flux_generation": "run_flux_generation",
|
||||
}
|
||||
|
||||
|
||||
def _dispatch(payload: dict) -> None:
|
||||
"""payload[job_type] → provider 함수 호출 (sync, worker_loop에서 asyncio.to_thread로 wrap)."""
|
||||
job_type = payload.get("job_type", "")
|
||||
task_id = payload.get("task_id", "")
|
||||
params = payload.get("params", {})
|
||||
fn_name = _DISPATCH_TABLE.get(job_type)
|
||||
if fn_name is None:
|
||||
logger.error("unknown job_type=%s task=%s", job_type, task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"unknown job_type: {job_type}")
|
||||
return
|
||||
try:
|
||||
fn = getattr(sys.modules[__name__], fn_name)
|
||||
except AttributeError:
|
||||
logger.error("dispatch table typo for job_type=%s name=%s task=%s", job_type, fn_name, task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"internal dispatch error: {fn_name}")
|
||||
return
|
||||
fn(task_id, params)
|
||||
|
||||
|
||||
async def worker_loop():
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
logger.info("image-render worker started (queue=%s)", QUEUE_KEY)
|
||||
while True:
|
||||
try:
|
||||
paused = await redis.get(PAUSED_KEY)
|
||||
if paused == b"1":
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
item = await redis.blpop(QUEUE_KEY, timeout=5)
|
||||
if item is None:
|
||||
continue
|
||||
_, raw = item
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("invalid queue payload: %r", raw[:200])
|
||||
continue
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("worker_loop cancelled")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("worker_loop iteration 실패, 5초 후 재시도")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
asyncio.run(worker_loop())
|
||||
20
services/music-render/.env.example
Normal file
20
services/music-render/.env.example
Normal file
@@ -0,0 +1,20 @@
|
||||
# Plan-B-Music — Windows music-render worker
|
||||
|
||||
# NAS Redis 큐
|
||||
REDIS_URL=redis://192.168.45.54:6379
|
||||
|
||||
# NAS internal webhook
|
||||
NAS_BASE_URL=http://192.168.45.54:18600
|
||||
INTERNAL_API_KEY=__copy_from_nas_dotenv__
|
||||
|
||||
# Suno API (sunoapi.org 래퍼) — NAS .env에서 옮겨옴
|
||||
SUNO_API_KEY=__paste_suno_key_here__
|
||||
|
||||
# MusicGen 호스트 (Windows native Python — 박재오 PC localhost)
|
||||
MUSIC_AI_SERVER_URL=http://host.docker.internal:8765
|
||||
|
||||
# NAS SMB mount 안의 음악 디렉토리
|
||||
MUSIC_MEDIA_ROOT=/mnt/nas/webpage/data/music
|
||||
|
||||
# nginx 서빙 prefix (NAS webhook payload용)
|
||||
MUSIC_MEDIA_URL_PREFIX=/media/music
|
||||
17
services/music-render/Dockerfile
Normal file
17
services/music-render/Dockerfile
Normal file
@@ -0,0 +1,17 @@
|
||||
FROM python:3.12-slim-bookworm
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# requests SSL 의존성만 필요 (Chromium 불필요)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
88
services/music-render/main.py
Normal file
88
services/music-render/main.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""music-render FastAPI entry — health + lifespan + sync forward endpoints.
|
||||
|
||||
NAS music-lab이 sync helpers(lyrics, credits, timestamped, style-boost)를
|
||||
httpx로 forward해서 이 endpoint들을 호출.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
import worker
|
||||
from providers.sync_ops import (
|
||||
generate_lyrics, get_credits,
|
||||
get_timestamped_lyrics, generate_style_boost,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
worker_task = asyncio.create_task(worker.worker_loop())
|
||||
logger.info("music-render lifespan 시작")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
worker_task.cancel()
|
||||
try:
|
||||
await worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("music-render lifespan 종료")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"ok": True, "service": "music-render"}
|
||||
|
||||
|
||||
# ── Sync forward endpoints ──────────────────────────────────────────────
|
||||
# NAS music-lab의 /api/music/lyrics 등 sync helpers가 이 endpoint들로 forward.
|
||||
|
||||
class LyricsRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
@app.post("/api/music-render/sync/lyrics")
|
||||
def sync_lyrics(req: LyricsRequest):
|
||||
result = generate_lyrics(req.prompt)
|
||||
if not result:
|
||||
raise HTTPException(502, "가사 생성 실패")
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/music-render/sync/credits")
|
||||
def sync_credits():
|
||||
result = get_credits()
|
||||
if result is None:
|
||||
raise HTTPException(502, "크레딧 조회 실패")
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/music-render/sync/timestamped-lyrics")
|
||||
def sync_timestamped_lyrics(task_id: str, suno_id: str):
|
||||
result = get_timestamped_lyrics(task_id, suno_id)
|
||||
if not result:
|
||||
raise HTTPException(502, "타임스탬프 가사 조회 실패")
|
||||
return result
|
||||
|
||||
|
||||
class StyleBoostRequest(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
@app.post("/api/music-render/sync/style-boost")
|
||||
def sync_style_boost(req: StyleBoostRequest):
|
||||
result = generate_style_boost(req.content)
|
||||
if not result:
|
||||
raise HTTPException(502, "스타일 부스트 생성 실패")
|
||||
return result
|
||||
80
services/music-render/nas_client.py
Normal file
80
services/music-render/nas_client.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""NAS webhook 어댑터 — Windows worker가 NAS DB 직접 접근 못하므로 HTTP로 위임.
|
||||
|
||||
기존 NAS suno_provider/local_provider의 `update_task`, `add_track` 호출을
|
||||
이 모듈의 webhook_update_task/webhook_add_track으로 치환.
|
||||
|
||||
webhook 실패는 raise하지 않고 logger.error로 기록 (provider 로직 흐름 유지).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TIMEOUT = 10.0
|
||||
|
||||
|
||||
def _post(payload: Dict[str, Any]) -> None:
|
||||
nas_base_url = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18600")
|
||||
internal_api_key = os.getenv("INTERNAL_API_KEY", "")
|
||||
url = f"{nas_base_url}/api/internal/music/update"
|
||||
try:
|
||||
r = httpx.post(
|
||||
url,
|
||||
headers={"X-Internal-Key": internal_api_key},
|
||||
json=payload,
|
||||
timeout=_TIMEOUT,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
logger.error("webhook %s returned %d: %s",
|
||||
payload.get("task_id"), r.status_code, r.text[:200])
|
||||
except Exception:
|
||||
logger.exception("webhook %s 호출 실패", payload.get("task_id"))
|
||||
|
||||
|
||||
def webhook_update_task(
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: int,
|
||||
message: str = "",
|
||||
audio_url: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
"""기존 update_task(task_id, status, progress, message, audio_url=None, error=None) 대체."""
|
||||
payload: Dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
}
|
||||
if audio_url is not None:
|
||||
payload["audio_url"] = audio_url
|
||||
if error is not None:
|
||||
payload["error"] = error
|
||||
_post(payload)
|
||||
|
||||
|
||||
def webhook_add_track(
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: int,
|
||||
message: str = "",
|
||||
audio_url: Optional[str] = None,
|
||||
track: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""update + add_track을 한 webhook 호출로 결합 (NAS internal_router가 둘 다 처리)."""
|
||||
payload: Dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
}
|
||||
if audio_url is not None:
|
||||
payload["audio_url"] = audio_url
|
||||
if track is not None:
|
||||
payload["track"] = track
|
||||
_post(payload)
|
||||
0
services/music-render/providers/__init__.py
Normal file
0
services/music-render/providers/__init__.py
Normal file
106
services/music-render/providers/local.py
Normal file
106
services/music-render/providers/local.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Local MusicGen Provider — Windows AI 머신의 native MusicGen 서버(:8765) 호출.
|
||||
|
||||
NAS music-lab/app/local_provider.py 이식. DB 호출만 webhook으로 변환.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task, webhook_add_track
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MUSIC_AI_SERVER_URL = os.getenv("MUSIC_AI_SERVER_URL", "")
|
||||
MUSIC_MEDIA_ROOT = os.getenv("MUSIC_MEDIA_ROOT", "/mnt/nas/webpage/data/music")
|
||||
MUSIC_MEDIA_BASE = os.getenv("MUSIC_MEDIA_URL_PREFIX", "/media/music")
|
||||
|
||||
|
||||
def run_local_generation(task_id: str, params: dict) -> None:
|
||||
"""MusicGen 생성 → /mnt/nas/.../music/{task_id}.mp3 저장 → add_track."""
|
||||
try:
|
||||
webhook_update_task(task_id, "processing", 10, "AI 서버에 연결 중...")
|
||||
if not MUSIC_AI_SERVER_URL:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="MUSIC_AI_SERVER_URL 미설정")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 30, "음악 생성 중...")
|
||||
resp = requests.post(f"{MUSIC_AI_SERVER_URL}/generate", json=params, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "",
|
||||
error=f"AI 서버 오류: {resp.status_code} {resp.text[:200]}")
|
||||
return
|
||||
|
||||
ai_task_id = resp.json().get("task_id")
|
||||
if not ai_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="AI 서버 응답에 task_id 없음")
|
||||
return
|
||||
|
||||
remote_url = None
|
||||
for _ in range(120):
|
||||
time.sleep(5)
|
||||
sr = requests.get(f"{MUSIC_AI_SERVER_URL}/status/{ai_task_id}", timeout=10)
|
||||
sd = sr.json()
|
||||
st = sd.get("status")
|
||||
prog = sd.get("progress", 0)
|
||||
msg = sd.get("message", "음악 생성 중...")
|
||||
scaled = 30 + int(prog * 0.49)
|
||||
webhook_update_task(task_id, "processing", scaled, msg)
|
||||
|
||||
if st == "succeeded":
|
||||
remote_url = sd.get("audio_url")
|
||||
break
|
||||
elif st == "failed":
|
||||
webhook_update_task(task_id, "failed", 0, "",
|
||||
error=sd.get("error", "AI 서버 생성 실패"))
|
||||
return
|
||||
|
||||
if not remote_url:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="AI 서버 타임아웃 (10분)")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 80, "파일 저장 중...")
|
||||
filename = f"{task_id}.mp3"
|
||||
os.makedirs(MUSIC_MEDIA_ROOT, exist_ok=True)
|
||||
file_path = os.path.join(MUSIC_MEDIA_ROOT, filename)
|
||||
|
||||
dl = requests.get(remote_url, timeout=120, stream=True)
|
||||
dl.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in dl.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
audio_url = f"{MUSIC_MEDIA_BASE}/{filename}"
|
||||
genre = params.get("genre", "")
|
||||
moods = params.get("moods", [])
|
||||
mood_str = moods[0] if moods else "Original"
|
||||
title = params.get("title") or (
|
||||
f"{genre} — {mood_str} Mix" if genre else f"{mood_str} Mix"
|
||||
)
|
||||
|
||||
track = {
|
||||
"title": title,
|
||||
"genre": genre,
|
||||
"moods": moods,
|
||||
"instruments": params.get("instruments", []),
|
||||
"duration_sec": params.get("duration_sec"),
|
||||
"bpm": params.get("bpm"),
|
||||
"key": params.get("key", ""),
|
||||
"scale": params.get("scale", ""),
|
||||
"prompt": params.get("prompt", ""),
|
||||
"audio_url": audio_url,
|
||||
"file_path": f"/app/data/{filename}",
|
||||
"task_id": task_id,
|
||||
"provider": "local",
|
||||
}
|
||||
webhook_add_track(task_id, "succeeded", 100, "생성 완료",
|
||||
audio_url=audio_url, track=track)
|
||||
|
||||
except requests.Timeout:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="AI 서버 타임아웃")
|
||||
except Exception as e:
|
||||
logger.exception("local generation error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
690
services/music-render/providers/suno.py
Normal file
690
services/music-render/providers/suno.py
Normal file
@@ -0,0 +1,690 @@
|
||||
"""Suno API Provider — sunoapi.org 래퍼.
|
||||
|
||||
NAS music-lab/app/suno_provider.py에서 이식. 차이점:
|
||||
- DB 호출(update_task, add_track 등)을 nas_client.webhook_* 으로 변환
|
||||
- 결과 MP3는 MUSIC_MEDIA_ROOT (/mnt/nas/webpage/data/music/)에 직접 저장
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task, webhook_add_track
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUNO_BASE_URL = "https://api.sunoapi.org/api/v1"
|
||||
SUNO_API_KEY = os.getenv("SUNO_API_KEY", "")
|
||||
MUSIC_MEDIA_ROOT = os.getenv("MUSIC_MEDIA_ROOT", "/mnt/nas/webpage/data/music")
|
||||
MUSIC_MEDIA_BASE = os.getenv("MUSIC_MEDIA_URL_PREFIX", "/media/music")
|
||||
|
||||
POLL_INTERVAL = 8
|
||||
POLL_MAX_ATTEMPTS = 40
|
||||
|
||||
|
||||
def _headers() -> dict:
|
||||
return {
|
||||
"Authorization": f"Bearer {SUNO_API_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def _build_suno_payload(params: dict) -> dict:
|
||||
"""프론트엔드 params → sunoapi.org 요청 형식 (NAS 코드 그대로 이식)."""
|
||||
instrumental = params.get("instrumental", False)
|
||||
has_lyrics = bool(params.get("lyrics"))
|
||||
custom_mode = has_lyrics or bool(params.get("genre")) or bool(params.get("moods"))
|
||||
|
||||
payload = {
|
||||
"customMode": custom_mode,
|
||||
"instrumental": instrumental,
|
||||
"model": params.get("model", "V4"),
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
|
||||
if custom_mode:
|
||||
if instrumental:
|
||||
payload["prompt"] = ""
|
||||
elif has_lyrics:
|
||||
payload["prompt"] = params["lyrics"][:3000]
|
||||
else:
|
||||
prompt_text = params.get("prompt", "")
|
||||
payload["prompt"] = prompt_text[:3000] if prompt_text else ""
|
||||
|
||||
style_parts = []
|
||||
if params.get("genre"):
|
||||
style_parts.append(params["genre"])
|
||||
if params.get("moods"):
|
||||
style_parts.extend(params["moods"])
|
||||
if params.get("instruments"):
|
||||
style_parts.extend(params["instruments"][:3])
|
||||
if style_parts:
|
||||
payload["style"] = ", ".join(style_parts)[:200]
|
||||
|
||||
if params.get("title"):
|
||||
payload["title"] = params["title"][:80]
|
||||
else:
|
||||
parts = []
|
||||
if params.get("prompt"):
|
||||
parts.append(params["prompt"])
|
||||
if params.get("genre"):
|
||||
parts.append(params["genre"])
|
||||
if params.get("moods"):
|
||||
parts.append(", ".join(params["moods"]))
|
||||
payload["prompt"] = " ".join(parts)[:500] if parts else "instrumental music"
|
||||
|
||||
if params.get("vocal_gender"):
|
||||
payload["vocalGender"] = params["vocal_gender"]
|
||||
if params.get("negative_tags"):
|
||||
payload["negativeTags"] = params["negative_tags"]
|
||||
if params.get("style_weight") is not None:
|
||||
payload["styleWeight"] = params["style_weight"]
|
||||
if params.get("audio_weight") is not None:
|
||||
payload["audioWeight"] = params["audio_weight"]
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _poll_suno_record(
|
||||
record_info_path: str,
|
||||
suno_task_id: str,
|
||||
task_id: str,
|
||||
max_attempts: int = POLL_MAX_ATTEMPTS,
|
||||
interval: int = POLL_INTERVAL,
|
||||
progress_msg_map: dict = None,
|
||||
) -> Optional[dict]:
|
||||
"""범용 Suno 작업 폴링. SUCCESS 시 response 객체 반환."""
|
||||
error_statuses = {
|
||||
"CREATE_TASK_FAILED", "GENERATE_AUDIO_FAILED",
|
||||
"CALLBACK_EXCEPTION", "SENSITIVE_WORD_ERROR",
|
||||
}
|
||||
default_msgs = {
|
||||
"PENDING": "대기열에서 대기 중...",
|
||||
"TEXT_SUCCESS": "가사 생성 완료, 음악 생성 중...",
|
||||
"FIRST_SUCCESS": "첫 번째 트랙 완료, 두 번째 생성 중...",
|
||||
"GENERATING": "생성 중...",
|
||||
}
|
||||
msgs = {**default_msgs, **(progress_msg_map or {})}
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
time.sleep(interval)
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{SUNO_BASE_URL}{record_info_path}",
|
||||
headers=_headers(),
|
||||
params={"taskId": suno_task_id},
|
||||
timeout=15,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
continue
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
continue
|
||||
data = body.get("data", {})
|
||||
status = data.get("status", "")
|
||||
progress = min(15 + int((attempt / max_attempts) * 65), 79)
|
||||
|
||||
if status == "SUCCESS":
|
||||
return data.get("response", data)
|
||||
elif status in error_statuses:
|
||||
error_msg = data.get("errorMessage") or data.get("msg") or f"Suno 작업 실패 ({status})"
|
||||
webhook_update_task(task_id, "failed", 0, "", error=error_msg)
|
||||
return None
|
||||
else:
|
||||
msg = msgs.get(status, f"처리 중... ({status})")
|
||||
if status == "FIRST_SUCCESS":
|
||||
progress = max(progress, 60)
|
||||
webhook_update_task(task_id, "processing", progress, msg)
|
||||
except Exception as e:
|
||||
logger.warning("Suno poll error (attempt %d): %s", attempt, e)
|
||||
continue
|
||||
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Suno 작업 타임아웃")
|
||||
return None
|
||||
|
||||
|
||||
def _download_and_register(
|
||||
task_id: str, song: dict, params: dict, filename_suffix: str = "",
|
||||
) -> Optional[dict]:
|
||||
"""Suno CDN에서 MP3 다운로드 → /mnt/nas/...에 직접 저장 → webhook으로 add_track."""
|
||||
audio_url_remote = song.get("audioUrl") or song.get("audio_url", "")
|
||||
if not audio_url_remote:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Suno 응답에 audioUrl이 없습니다")
|
||||
return None
|
||||
|
||||
filename = f"{task_id}{filename_suffix}.mp3"
|
||||
os.makedirs(MUSIC_MEDIA_ROOT, exist_ok=True)
|
||||
file_path = os.path.join(MUSIC_MEDIA_ROOT, filename)
|
||||
|
||||
try:
|
||||
dl = requests.get(audio_url_remote, timeout=120, stream=True)
|
||||
dl.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in dl.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
except Exception as e:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"오디오 다운로드 실패: {e}")
|
||||
return None
|
||||
|
||||
local_audio_url = f"{MUSIC_MEDIA_BASE}/{filename}"
|
||||
|
||||
genre = params.get("genre", song.get("tags", ""))
|
||||
moods = params.get("moods", [])
|
||||
mood_str = moods[0] if moods else "Original"
|
||||
title = (
|
||||
song.get("title")
|
||||
or params.get("title")
|
||||
or (f"{genre} — {mood_str} Mix" if genre else f"{mood_str} Mix")
|
||||
)
|
||||
|
||||
track_data = {
|
||||
"title": title,
|
||||
"genre": genre,
|
||||
"moods": moods,
|
||||
"instruments": params.get("instruments", []),
|
||||
"duration_sec": int(song["duration"]) if song.get("duration") else params.get("duration_sec"),
|
||||
"bpm": params.get("bpm"),
|
||||
"key": params.get("key", ""),
|
||||
"scale": params.get("scale", ""),
|
||||
"prompt": song.get("prompt", params.get("prompt", "")),
|
||||
"audio_url": local_audio_url,
|
||||
# NAS file_path는 NAS 관점 — /app/data 안의 경로
|
||||
"file_path": f"/app/data/{filename}",
|
||||
"task_id": task_id,
|
||||
"provider": "suno",
|
||||
"lyrics": song.get("prompt", params.get("lyrics", "")),
|
||||
"image_url": song.get("imageUrl") or song.get("image_url", ""),
|
||||
"suno_id": song.get("id", ""),
|
||||
}
|
||||
return track_data
|
||||
|
||||
|
||||
def run_suno_generation(task_id: str, params: dict) -> None:
|
||||
"""BackgroundTask: Suno API로 곡 생성 → MP3 → NAS SMB 저장 → webhook add_track."""
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 5, "Suno API에 연결 중...")
|
||||
payload = _build_suno_payload(params)
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/generate", headers=_headers(), json=payload, timeout=30)
|
||||
|
||||
if resp.status_code != 200:
|
||||
err = resp.text[:300] if resp.text else f"HTTP {resp.status_code}"
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Suno API 오류: {err}")
|
||||
return
|
||||
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Suno API 거부: {body.get('msg', '?')}")
|
||||
return
|
||||
|
||||
suno_task_id = body.get("data", {}).get("taskId", "")
|
||||
if not suno_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Suno 응답에 taskId 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 15, "곡 생성 대기열에 등록됨...")
|
||||
|
||||
response = _poll_suno_record("/generate/record-info", suno_task_id, task_id)
|
||||
if not response:
|
||||
return
|
||||
completed = response.get("sunoData") or []
|
||||
if not completed:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Suno 완료했으나 트랙 데이터 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 80, "오디오 파일 다운로드 중...")
|
||||
track = _download_and_register(task_id, completed[0], params)
|
||||
if not track:
|
||||
return
|
||||
|
||||
webhook_add_track(task_id, "succeeded", 100, "생성 완료",
|
||||
audio_url=track["audio_url"], track=track)
|
||||
|
||||
if len(completed) > 1:
|
||||
try:
|
||||
# 보조 변형은 SMB에 파일만 저장. NAS _sync_library_with_disk가 다음
|
||||
# GET /api/music/library 호출 시 자동으로 라이브러리에 등록.
|
||||
_download_and_register(f"{task_id}_v2", completed[1], params)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except requests.Timeout:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Suno API 타임아웃")
|
||||
except Exception as e:
|
||||
logger.exception("Suno generation error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_suno_extend(task_id: str, params: dict) -> None:
|
||||
"""기존 곡을 특정 지점부터 연장."""
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 5, "곡 연장 요청 중...")
|
||||
payload = {
|
||||
"audioId": params["suno_id"],
|
||||
"defaultParamFlag": not bool(params.get("prompt")),
|
||||
"prompt": params.get("prompt", ""),
|
||||
"continueAt": params.get("continue_at", 0),
|
||||
"model": params.get("model", "V4"),
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
if params.get("style"):
|
||||
payload["style"] = params["style"]
|
||||
if params.get("title"):
|
||||
payload["title"] = params["title"]
|
||||
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/generate/extend", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Suno Extend 오류: {resp.text[:300]}")
|
||||
return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Extend 거부: {body.get('msg', '?')}")
|
||||
return
|
||||
suno_task_id = body.get("data", {}).get("taskId", "")
|
||||
if not suno_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Extend 응답에 taskId 없음")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 15, "곡 연장 대기열에 등록됨...")
|
||||
|
||||
response = _poll_suno_record("/generate/record-info", suno_task_id, task_id)
|
||||
if not response:
|
||||
return
|
||||
completed = response.get("sunoData") or []
|
||||
if not completed:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="연장 완료했으나 트랙 없음")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 80, "연장된 오디오 다운로드 중...")
|
||||
track = _download_and_register(task_id, completed[0], params)
|
||||
if track:
|
||||
webhook_add_track(task_id, "succeeded", 100, "곡 연장 완료",
|
||||
audio_url=track["audio_url"], track=track)
|
||||
except Exception as e:
|
||||
logger.exception("Suno extend error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_vocal_removal(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 5, "보컬 분리 요청 중...")
|
||||
payload = {"audioId": params["suno_id"], "callBackUrl": "https://example.com/noop"}
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/vocal-removal/generate", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Vocal Removal 오류: {resp.text[:300]}")
|
||||
return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Vocal Removal 거부: {body.get('msg', '?')}")
|
||||
return
|
||||
suno_task_id = body.get("data", {}).get("taskId", "")
|
||||
if not suno_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="응답에 taskId 없음")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 15, "보컬 분리 처리 중...")
|
||||
response = _poll_suno_record("/vocal-removal/record-info", suno_task_id, task_id)
|
||||
if not response:
|
||||
return
|
||||
completed = response.get("sunoData") or []
|
||||
if not completed:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="분리 완료했으나 트랙 없음")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 80, "분리된 오디오 다운로드 중...")
|
||||
vp = {**params, "title": f"{params.get('title', 'Track')} (Vocals)"}
|
||||
track = _download_and_register(task_id, completed[0], vp)
|
||||
if len(completed) > 1:
|
||||
ip = {**params, "title": f"{params.get('title', 'Track')} (Instrumental)"}
|
||||
# Instrumental 변형은 SMB에 파일만 저장. NAS _sync_library_with_disk가 자동 등록.
|
||||
_download_and_register(f"{task_id}_inst", completed[1], ip)
|
||||
if track:
|
||||
webhook_add_track(task_id, "succeeded", 100, "보컬 분리 완료",
|
||||
audio_url=track["audio_url"], track=track)
|
||||
except Exception as e:
|
||||
logger.exception("vocal removal error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_cover_image(task_id: str, params: dict) -> None:
|
||||
"""Suno 곡의 커버 이미지 2장 (URL JSON 반환)."""
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정"); return
|
||||
webhook_update_task(task_id, "processing", 5, "커버 이미지 생성 요청 중...")
|
||||
suno_task_id = params.get("suno_task_id", "")
|
||||
if not suno_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="suno_task_id 필요"); return
|
||||
payload = {"taskId": suno_task_id, "callBackUrl": "https://example.com/noop"}
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/suno/cover/generate", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Cover API 오류: {resp.text[:300]}"); return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Cover 거부: {body.get('msg', '?')}"); return
|
||||
cover_task_id = body.get("data", {}).get("taskId", suno_task_id)
|
||||
webhook_update_task(task_id, "processing", 15, "커버 이미지 생성 중...")
|
||||
response = _poll_suno_record(
|
||||
"/suno/cover/record-info", cover_task_id, task_id,
|
||||
max_attempts=30, interval=5,
|
||||
progress_msg_map={"PENDING": "이미지 생성 대기 중...", "GENERATING": "이미지 생성 중..."},
|
||||
)
|
||||
if not response:
|
||||
return
|
||||
images = response.get("images") or response.get("sunoData") or []
|
||||
urls = []
|
||||
if isinstance(images, list):
|
||||
for img in images:
|
||||
if isinstance(img, str):
|
||||
urls.append(img)
|
||||
elif isinstance(img, dict):
|
||||
urls.append(img.get("imageUrl") or img.get("image_url", ""))
|
||||
webhook_update_task(task_id, "succeeded", 100, "커버 완료",
|
||||
audio_url=json.dumps(urls))
|
||||
except Exception as e:
|
||||
logger.exception("cover image error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_wav_convert(task_id: str, params: dict) -> None:
|
||||
"""곡을 WAV 포맷으로 변환 (URL만)."""
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정"); return
|
||||
webhook_update_task(task_id, "processing", 5, "WAV 변환 요청 중...")
|
||||
payload = {
|
||||
"taskId": params["suno_task_id"],
|
||||
"audioId": params["suno_id"],
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/wav/generate", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code == 409:
|
||||
body = resp.json()
|
||||
wav_url = body.get("data", {}).get("audioWavUrl", "")
|
||||
if wav_url:
|
||||
webhook_update_task(task_id, "succeeded", 100, "WAV 캐시", audio_url=wav_url)
|
||||
return
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"WAV 오류: {resp.text[:300]}"); return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"WAV 거부: {body.get('msg', '?')}"); return
|
||||
wav_task_id = body.get("data", {}).get("taskId", params["suno_task_id"])
|
||||
webhook_update_task(task_id, "processing", 15, "WAV 변환 처리 중...")
|
||||
response = _poll_suno_record(
|
||||
"/wav/record-info", wav_task_id, task_id,
|
||||
max_attempts=30, interval=5,
|
||||
progress_msg_map={"PENDING": "WAV 대기 중...", "GENERATING": "WAV 변환 중..."},
|
||||
)
|
||||
if not response:
|
||||
return
|
||||
wav_url = ""
|
||||
sd = response.get("sunoData") or []
|
||||
if sd and isinstance(sd, list) and isinstance(sd[0], dict):
|
||||
wav_url = sd[0].get("audioWavUrl", "")
|
||||
if not wav_url:
|
||||
wav_url = response.get("audioWavUrl", "")
|
||||
webhook_update_task(task_id, "succeeded", 100, "WAV 변환 완료", audio_url=wav_url)
|
||||
except Exception as e:
|
||||
logger.exception("wav convert error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_stem_split(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정"); return
|
||||
webhook_update_task(task_id, "processing", 5, "12스템 분리 요청 중...")
|
||||
payload = {
|
||||
"taskId": params["suno_task_id"],
|
||||
"audioId": params["suno_id"],
|
||||
"type": "split_stem",
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/vocal-removal/generate", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Stem API 오류: {resp.text[:300]}"); return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Stem 거부: {body.get('msg', '?')}"); return
|
||||
stem_task_id = body.get("data", {}).get("taskId", "")
|
||||
if not stem_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="응답에 taskId 없음"); return
|
||||
webhook_update_task(task_id, "processing", 15, "12스템 분리 처리 중 (2~3분)...")
|
||||
response = _poll_suno_record(
|
||||
"/vocal-removal/record-info", stem_task_id, task_id,
|
||||
max_attempts=40, interval=8,
|
||||
progress_msg_map={"PENDING": "스템 대기 중...", "GENERATING": "스템 분리 중..."},
|
||||
)
|
||||
if not response:
|
||||
return
|
||||
sd = response.get("sunoData") or []
|
||||
stems = {}
|
||||
names = ["vocal", "backing_vocals", "drums", "bass", "guitar", "keyboard",
|
||||
"strings", "brass", "woodwinds", "percussion", "synth", "fx"]
|
||||
for i, item in enumerate(sd):
|
||||
if isinstance(item, dict):
|
||||
nm = names[i] if i < len(names) else f"stem_{i}"
|
||||
stems[nm] = item.get("audioUrl") or item.get("audio_url", "")
|
||||
webhook_update_task(task_id, "succeeded", 100, "12스템 완료",
|
||||
audio_url=json.dumps(stems))
|
||||
except Exception as e:
|
||||
logger.exception("stem split error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_upload_cover(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정"); return
|
||||
webhook_update_task(task_id, "processing", 5, "AI Cover 요청 중...")
|
||||
payload = {
|
||||
"uploadUrl": params["upload_url"],
|
||||
"customMode": params.get("custom_mode", True),
|
||||
"instrumental": params.get("instrumental", False),
|
||||
"model": params.get("model", "V4"),
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
for k, ak in [("prompt", "prompt"), ("style", "style"), ("title", "title"),
|
||||
("vocal_gender", "vocalGender"), ("negative_tags", "negativeTags"),
|
||||
("style_weight", "styleWeight"), ("audio_weight", "audioWeight")]:
|
||||
if params.get(k):
|
||||
payload[ak] = params[k]
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/generate/upload-cover", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Upload Cover 오류: {resp.text[:300]}"); return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Upload Cover 거부: {body.get('msg', '?')}"); return
|
||||
suno_task_id = body.get("data", {}).get("taskId", "")
|
||||
if not suno_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="응답에 taskId 없음"); return
|
||||
webhook_update_task(task_id, "processing", 15, "AI Cover 생성 중...")
|
||||
response = _poll_suno_record("/generate/record-info", suno_task_id, task_id)
|
||||
if not response:
|
||||
return
|
||||
completed = response.get("sunoData") or []
|
||||
if not completed:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Cover 완료했으나 트랙 없음"); return
|
||||
track = _download_and_register(task_id, completed[0], params)
|
||||
if track:
|
||||
webhook_add_track(task_id, "succeeded", 100, "AI Cover 완료",
|
||||
audio_url=track["audio_url"], track=track)
|
||||
except Exception as e:
|
||||
logger.exception("upload cover error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_upload_extend(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정"); return
|
||||
webhook_update_task(task_id, "processing", 5, "Upload Extend 요청 중...")
|
||||
payload = {
|
||||
"uploadUrl": params["upload_url"],
|
||||
"defaultParamFlag": params.get("default_param_flag", True),
|
||||
"model": params.get("model", "V4"),
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
for k, ak in [("prompt", "prompt"), ("style", "style"), ("title", "title"),
|
||||
("continue_at", "continueAt"), ("instrumental", "instrumental"),
|
||||
("vocal_gender", "vocalGender"), ("negative_tags", "negativeTags")]:
|
||||
if params.get(k) is not None:
|
||||
payload[ak] = params[k]
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/generate/upload-extend", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Upload Extend 오류: {resp.text[:300]}"); return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Upload Extend 거부: {body.get('msg', '?')}"); return
|
||||
suno_task_id = body.get("data", {}).get("taskId", "")
|
||||
if not suno_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="응답에 taskId 없음"); return
|
||||
webhook_update_task(task_id, "processing", 15, "Upload Extend 생성 중...")
|
||||
response = _poll_suno_record("/generate/record-info", suno_task_id, task_id)
|
||||
if not response:
|
||||
return
|
||||
completed = response.get("sunoData") or []
|
||||
if not completed:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Upload Extend 완료했으나 트랙 없음"); return
|
||||
track = _download_and_register(task_id, completed[0], params)
|
||||
if track:
|
||||
webhook_add_track(task_id, "succeeded", 100, "Upload Extend 완료",
|
||||
audio_url=track["audio_url"], track=track)
|
||||
except Exception as e:
|
||||
logger.exception("upload extend error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_add_vocals(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정"); return
|
||||
webhook_update_task(task_id, "processing", 5, "보컬 추가 요청 중...")
|
||||
payload = {
|
||||
"uploadUrl": params["upload_url"],
|
||||
"prompt": params.get("prompt", ""),
|
||||
"title": params.get("title", ""),
|
||||
"style": params.get("style", ""),
|
||||
"negativeTags": params.get("negative_tags", ""),
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
for k, ak in [("vocal_gender", "vocalGender"), ("model", "model"),
|
||||
("style_weight", "styleWeight"), ("audio_weight", "audioWeight")]:
|
||||
if params.get(k) is not None:
|
||||
payload[ak] = params[k]
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/generate/add-vocals", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Add Vocals 오류: {resp.text[:300]}"); return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Add Vocals 거부: {body.get('msg', '?')}"); return
|
||||
suno_task_id = body.get("data", {}).get("taskId", "")
|
||||
if not suno_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="응답에 taskId 없음"); return
|
||||
webhook_update_task(task_id, "processing", 15, "AI 보컬 생성 중...")
|
||||
response = _poll_suno_record("/generate/record-info", suno_task_id, task_id)
|
||||
if not response:
|
||||
return
|
||||
completed = response.get("sunoData") or []
|
||||
if not completed:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="보컬 추가 완료했으나 트랙 없음"); return
|
||||
track = _download_and_register(task_id, completed[0], params)
|
||||
if track:
|
||||
webhook_add_track(task_id, "succeeded", 100, "보컬 추가 완료",
|
||||
audio_url=track["audio_url"], track=track)
|
||||
except Exception as e:
|
||||
logger.exception("add vocals error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_add_instrumental(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정"); return
|
||||
webhook_update_task(task_id, "processing", 5, "인스트루멘탈 추가 요청 중...")
|
||||
payload = {
|
||||
"uploadUrl": params["upload_url"],
|
||||
"title": params.get("title", ""),
|
||||
"tags": params.get("tags", ""),
|
||||
"negativeTags": params.get("negative_tags", ""),
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
for k, ak in [("vocal_gender", "vocalGender"), ("model", "model"),
|
||||
("style_weight", "styleWeight"), ("audio_weight", "audioWeight")]:
|
||||
if params.get(k) is not None:
|
||||
payload[ak] = params[k]
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/generate/add-instrumental", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Add Inst 오류: {resp.text[:300]}"); return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Add Inst 거부: {body.get('msg', '?')}"); return
|
||||
suno_task_id = body.get("data", {}).get("taskId", "")
|
||||
if not suno_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="응답에 taskId 없음"); return
|
||||
webhook_update_task(task_id, "processing", 15, "AI 반주 생성 중...")
|
||||
response = _poll_suno_record("/generate/record-info", suno_task_id, task_id)
|
||||
if not response:
|
||||
return
|
||||
completed = response.get("sunoData") or []
|
||||
if not completed:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Add Inst 완료했으나 트랙 없음"); return
|
||||
track = _download_and_register(task_id, completed[0], params)
|
||||
if track:
|
||||
webhook_add_track(task_id, "succeeded", 100, "Add Instrumental 완료",
|
||||
audio_url=track["audio_url"], track=track)
|
||||
except Exception as e:
|
||||
logger.exception("add instrumental error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
|
||||
|
||||
def run_video_generate(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not SUNO_API_KEY:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SUNO_API_KEY 미설정"); return
|
||||
webhook_update_task(task_id, "processing", 5, "뮤직비디오 생성 요청 중...")
|
||||
payload = {
|
||||
"taskId": params["suno_task_id"],
|
||||
"audioId": params["suno_id"],
|
||||
"callBackUrl": "https://example.com/noop",
|
||||
}
|
||||
if params.get("author"):
|
||||
payload["author"] = params["author"][:50]
|
||||
if params.get("domain_name"):
|
||||
payload["domainName"] = params["domain_name"][:50]
|
||||
resp = requests.post(f"{SUNO_BASE_URL}/mp4/generate", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Video 오류: {resp.text[:300]}"); return
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Video 거부: {body.get('msg', '?')}"); return
|
||||
video_task_id = body.get("data", {}).get("taskId", params.get("suno_task_id", ""))
|
||||
webhook_update_task(task_id, "processing", 15, "뮤직비디오 렌더링 중...")
|
||||
response = _poll_suno_record(
|
||||
"/mp4/record-info", video_task_id, task_id,
|
||||
max_attempts=60, interval=10,
|
||||
progress_msg_map={"PENDING": "비디오 대기 중...", "GENERATING": "비디오 렌더링 중..."},
|
||||
)
|
||||
if not response:
|
||||
return
|
||||
video_url = ""
|
||||
sd = response.get("sunoData") or []
|
||||
if sd and isinstance(sd, list) and isinstance(sd[0], dict):
|
||||
video_url = sd[0].get("videoUrl") or sd[0].get("video_url", "")
|
||||
if not video_url:
|
||||
video_url = response.get("video_url") or response.get("videoUrl", "")
|
||||
webhook_update_task(task_id, "succeeded", 100, "뮤직비디오 완료", audio_url=video_url)
|
||||
except Exception as e:
|
||||
logger.exception("video generate error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
131
services/music-render/providers/sync_ops.py
Normal file
131
services/music-render/providers/sync_ops.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Sync Suno API helpers — main.py FastAPI sync endpoints에서 호출.
|
||||
|
||||
NAS music-lab/app/suno_provider.py의 sync 함수들 이식.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUNO_BASE_URL = "https://api.sunoapi.org/api/v1"
|
||||
SUNO_API_KEY = os.getenv("SUNO_API_KEY", "")
|
||||
|
||||
|
||||
def _headers() -> dict:
|
||||
return {
|
||||
"Authorization": f"Bearer {SUNO_API_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def generate_lyrics(prompt: str) -> Optional[dict]:
|
||||
"""Suno 가사 생성 API — 폴링 결과 반환."""
|
||||
if not SUNO_API_KEY:
|
||||
return None
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"{SUNO_BASE_URL}/lyrics",
|
||||
headers=_headers(),
|
||||
json={"prompt": prompt[:200]},
|
||||
timeout=30,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return None
|
||||
body = resp.json()
|
||||
if body.get("code") != 200:
|
||||
return body
|
||||
task_id = body.get("data", {}).get("taskId", "")
|
||||
if not task_id:
|
||||
return body
|
||||
return _poll_lyrics(task_id)
|
||||
except Exception as e:
|
||||
logger.warning("Suno lyrics API error: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _poll_lyrics(lyrics_task_id: str) -> Optional[dict]:
|
||||
for _ in range(15):
|
||||
time.sleep(3)
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{SUNO_BASE_URL}/lyrics/record-info",
|
||||
headers=_headers(),
|
||||
params={"taskId": lyrics_task_id},
|
||||
timeout=15,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
continue
|
||||
body = resp.json()
|
||||
data = body.get("data", {})
|
||||
if data.get("status") == "complete":
|
||||
items = data.get("data") or data.get("sunoData") or []
|
||||
if items and isinstance(items, list):
|
||||
return {
|
||||
"id": lyrics_task_id,
|
||||
"status": "complete",
|
||||
"text": items[0].get("text", ""),
|
||||
"title": items[0].get("title", ""),
|
||||
}
|
||||
return {"id": lyrics_task_id, "status": "complete", "text": ""}
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def get_credits() -> Optional[dict]:
|
||||
if not SUNO_API_KEY:
|
||||
return None
|
||||
for path in ["/generate/credit", "/get-credits"]:
|
||||
try:
|
||||
resp = requests.get(f"{SUNO_BASE_URL}{path}", headers=_headers(), timeout=15)
|
||||
if resp.status_code == 200:
|
||||
body = resp.json()
|
||||
data = body.get("data", body)
|
||||
if isinstance(data, (int, float)):
|
||||
return {"credits_left": int(data)}
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning("Suno credits API error (%s): %s", path, e)
|
||||
return None
|
||||
|
||||
|
||||
def get_timestamped_lyrics(suno_task_id: str, suno_id: str) -> Optional[dict]:
|
||||
if not SUNO_API_KEY:
|
||||
return None
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"{SUNO_BASE_URL}/generate/get-timestamped-lyrics",
|
||||
headers=_headers(),
|
||||
json={"taskId": suno_task_id, "audioId": suno_id},
|
||||
timeout=30,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
body = resp.json()
|
||||
return body.get("data", body)
|
||||
except Exception as e:
|
||||
logger.warning("Timestamped lyrics error: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def generate_style_boost(content: str) -> Optional[dict]:
|
||||
if not SUNO_API_KEY:
|
||||
return None
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"{SUNO_BASE_URL}/style/generate",
|
||||
headers=_headers(),
|
||||
json={"content": content},
|
||||
timeout=30,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
body = resp.json()
|
||||
return body.get("data", body)
|
||||
except Exception as e:
|
||||
logger.warning("Style boost error: %s", e)
|
||||
return None
|
||||
9
services/music-render/requirements.txt
Normal file
9
services/music-render/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
fastapi==0.115.6
|
||||
uvicorn[standard]==0.34.0
|
||||
requests==2.32.3
|
||||
redis>=5.0
|
||||
httpx>=0.27
|
||||
mutagen==1.47.0
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.24
|
||||
respx>=0.21
|
||||
80
services/music-render/tests/test_nas_client.py
Normal file
80
services/music-render/tests/test_nas_client.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""nas_client — webhook adapter tests."""
|
||||
import os
|
||||
import pytest
|
||||
import respx
|
||||
import httpx
|
||||
|
||||
from nas_client import webhook_update_task, webhook_add_track
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _env(monkeypatch):
|
||||
monkeypatch.setenv("NAS_BASE_URL", "http://nas-test:18600")
|
||||
monkeypatch.setenv("INTERNAL_API_KEY", "test-key")
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_update_task_sends_x_internal_key():
|
||||
route = respx.post("http://nas-test:18600/api/internal/music/update").mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
webhook_update_task("task-1", "processing", 30, message="downloading")
|
||||
assert route.called
|
||||
req = route.calls[0].request
|
||||
assert req.headers["X-Internal-Key"] == "test-key"
|
||||
import json
|
||||
body = json.loads(req.content)
|
||||
assert body["task_id"] == "task-1"
|
||||
assert body["status"] == "processing"
|
||||
assert body["progress"] == 30
|
||||
assert body["message"] == "downloading"
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_update_task_with_audio_url():
|
||||
route = respx.post("http://nas-test:18600/api/internal/music/update").mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
webhook_update_task("task-2", "succeeded", 100, message="완료",
|
||||
audio_url="/media/music/task-2.mp3")
|
||||
import json
|
||||
payload = json.loads(route.calls[0].request.content)
|
||||
assert payload["audio_url"] == "/media/music/task-2.mp3"
|
||||
assert payload["status"] == "succeeded"
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_update_task_with_error():
|
||||
route = respx.post("http://nas-test:18600/api/internal/music/update").mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
webhook_update_task("task-3", "failed", 0, error="API rate limit")
|
||||
import json
|
||||
payload = json.loads(route.calls[0].request.content)
|
||||
assert payload["error"] == "API rate limit"
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_add_track_uses_track_field():
|
||||
"""add_track은 update와 동시에 (succeeded 시)."""
|
||||
route = respx.post("http://nas-test:18600/api/internal/music/update").mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
track = {"title": "x", "audio_url": "/media/music/t.mp3", "provider": "suno"}
|
||||
webhook_add_track("task-4", "succeeded", 100, message="ok",
|
||||
audio_url="/media/music/t.mp3", track=track)
|
||||
import json
|
||||
payload = json.loads(route.calls[0].request.content)
|
||||
assert payload["track"]["title"] == "x"
|
||||
assert payload["status"] == "succeeded"
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_swallows_network_error(caplog):
|
||||
"""webhook 실패해도 raise 안 함 (logger.error)."""
|
||||
respx.post("http://nas-test:18600/api/internal/music/update").mock(
|
||||
side_effect=httpx.ConnectError("no host")
|
||||
)
|
||||
# raise 안 하면 통과
|
||||
webhook_update_task("task-5", "processing", 10)
|
||||
assert "task-5" in caplog.text
|
||||
32
services/music-render/tests/test_suno_provider.py
Normal file
32
services/music-render/tests/test_suno_provider.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""providers/suno.py — _build_suno_payload 단위 테스트 + 1개 함수 mock 검증."""
|
||||
import pytest
|
||||
from providers.suno import _build_suno_payload
|
||||
|
||||
|
||||
def test_payload_custom_mode_with_lyrics():
|
||||
params = {"lyrics": "[Verse]\nhello", "genre": "lofi", "moods": ["chill"], "model": "V4"}
|
||||
p = _build_suno_payload(params)
|
||||
assert p["customMode"] is True
|
||||
assert p["prompt"] == "[Verse]\nhello"
|
||||
assert "lofi" in p["style"]
|
||||
assert "chill" in p["style"]
|
||||
|
||||
|
||||
def test_payload_simple_mode_no_lyrics_no_genre():
|
||||
params = {"prompt": "happy summer", "model": "V4"}
|
||||
p = _build_suno_payload(params)
|
||||
assert p["customMode"] is False
|
||||
assert "happy summer" in p["prompt"]
|
||||
|
||||
|
||||
def test_payload_instrumental_clears_prompt():
|
||||
params = {"genre": "ambient", "instrumental": True, "model": "V5"}
|
||||
p = _build_suno_payload(params)
|
||||
assert p["instrumental"] is True
|
||||
assert p["prompt"] == ""
|
||||
|
||||
|
||||
def test_payload_includes_optional_vocal_gender():
|
||||
params = {"genre": "pop", "vocal_gender": "f", "model": "V4"}
|
||||
p = _build_suno_payload(params)
|
||||
assert p["vocalGender"] == "f"
|
||||
109
services/music-render/tests/test_worker.py
Normal file
109
services/music-render/tests/test_worker.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""worker.py — job_type 디스패처 + paused 체크."""
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import worker
|
||||
|
||||
|
||||
def test_dispatch_suno_generation_calls_run_suno_generation():
|
||||
payload = {
|
||||
"task_id": "t1",
|
||||
"job_type": "suno_generation",
|
||||
"params": {"genre": "lofi", "title": "x"},
|
||||
}
|
||||
with patch("worker.run_suno_generation") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t1", {"genre": "lofi", "title": "x"})
|
||||
|
||||
|
||||
def test_dispatch_local_generation_calls_run_local_generation():
|
||||
payload = {
|
||||
"task_id": "t2",
|
||||
"job_type": "local_generation",
|
||||
"params": {"genre": "ambient"},
|
||||
}
|
||||
with patch("worker.run_local_generation") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t2", {"genre": "ambient"})
|
||||
|
||||
|
||||
def test_dispatch_unknown_job_type_logs_error():
|
||||
payload = {"task_id": "t3", "job_type": "weird_type", "params": {}}
|
||||
with patch("worker.webhook_update_task") as m:
|
||||
worker._dispatch(payload)
|
||||
# 알 수 없는 job_type은 failed로 보고
|
||||
m.assert_called_once()
|
||||
args = m.call_args[0]
|
||||
assert args[0] == "t3"
|
||||
assert args[1] == "failed"
|
||||
|
||||
|
||||
def test_dispatch_suno_extend_calls_run_suno_extend():
|
||||
payload = {"task_id": "t4", "job_type": "suno_extend", "params": {"suno_id": "abc"}}
|
||||
with patch("worker.run_suno_extend") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t4", {"suno_id": "abc"})
|
||||
|
||||
|
||||
def test_dispatch_vocal_removal_calls_run_vocal_removal():
|
||||
payload = {"task_id": "t5", "job_type": "vocal_removal", "params": {"suno_id": "abc"}}
|
||||
with patch("worker.run_vocal_removal") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t5", {"suno_id": "abc"})
|
||||
|
||||
|
||||
def test_dispatch_cover_image_calls_run_cover_image():
|
||||
payload = {"task_id": "t6", "job_type": "cover_image", "params": {"suno_task_id": "x"}}
|
||||
with patch("worker.run_cover_image") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t6", {"suno_task_id": "x"})
|
||||
|
||||
|
||||
def test_dispatch_wav_convert_calls_run_wav_convert():
|
||||
payload = {"task_id": "t7", "job_type": "wav_convert", "params": {"suno_task_id": "x", "suno_id": "y"}}
|
||||
with patch("worker.run_wav_convert") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t7", {"suno_task_id": "x", "suno_id": "y"})
|
||||
|
||||
|
||||
def test_dispatch_stem_split_calls_run_stem_split():
|
||||
payload = {"task_id": "t8", "job_type": "stem_split", "params": {"suno_task_id": "x", "suno_id": "y"}}
|
||||
with patch("worker.run_stem_split") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t8", {"suno_task_id": "x", "suno_id": "y"})
|
||||
|
||||
|
||||
def test_dispatch_video_generate_calls_run_video_generate():
|
||||
payload = {"task_id": "t9", "job_type": "video_generate", "params": {"suno_task_id": "x", "suno_id": "y"}}
|
||||
with patch("worker.run_video_generate") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t9", {"suno_task_id": "x", "suno_id": "y"})
|
||||
|
||||
|
||||
def test_dispatch_upload_cover_calls_run_upload_cover():
|
||||
payload = {"task_id": "t10", "job_type": "upload_cover", "params": {"upload_url": "u"}}
|
||||
with patch("worker.run_upload_cover") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t10", {"upload_url": "u"})
|
||||
|
||||
|
||||
def test_dispatch_upload_extend_calls_run_upload_extend():
|
||||
payload = {"task_id": "t11", "job_type": "upload_extend", "params": {"upload_url": "u"}}
|
||||
with patch("worker.run_upload_extend") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t11", {"upload_url": "u"})
|
||||
|
||||
|
||||
def test_dispatch_add_vocals_calls_run_add_vocals():
|
||||
payload = {"task_id": "t12", "job_type": "add_vocals", "params": {"upload_url": "u"}}
|
||||
with patch("worker.run_add_vocals") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t12", {"upload_url": "u"})
|
||||
|
||||
|
||||
def test_dispatch_add_instrumental_calls_run_add_instrumental():
|
||||
payload = {"task_id": "t13", "job_type": "add_instrumental", "params": {"upload_url": "u"}}
|
||||
with patch("worker.run_add_instrumental") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t13", {"upload_url": "u"})
|
||||
95
services/music-render/worker.py
Normal file
95
services/music-render/worker.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Redis BLPOP worker — queue:music-render → job_type 디스패치 → NAS webhook.
|
||||
|
||||
queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers.suno import (
|
||||
run_suno_generation, run_suno_extend, run_vocal_removal,
|
||||
run_cover_image, run_wav_convert, run_stem_split,
|
||||
run_upload_cover, run_upload_extend, run_add_vocals,
|
||||
run_add_instrumental, run_video_generate,
|
||||
)
|
||||
from providers.local import run_local_generation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
QUEUE_KEY = "queue:music-render"
|
||||
PAUSED_KEY = "queue:paused"
|
||||
|
||||
# Maps job_type → module-level function name (string).
|
||||
# _dispatch resolves the name via globals() at call time so unittest.mock.patch
|
||||
# on "worker.<name>" is correctly intercepted.
|
||||
_DISPATCH_TABLE: dict[str, str] = {
|
||||
"suno_generation": "run_suno_generation",
|
||||
"local_generation": "run_local_generation",
|
||||
"suno_extend": "run_suno_extend",
|
||||
"vocal_removal": "run_vocal_removal",
|
||||
"cover_image": "run_cover_image",
|
||||
"wav_convert": "run_wav_convert",
|
||||
"stem_split": "run_stem_split",
|
||||
"upload_cover": "run_upload_cover",
|
||||
"upload_extend": "run_upload_extend",
|
||||
"add_vocals": "run_add_vocals",
|
||||
"add_instrumental": "run_add_instrumental",
|
||||
"video_generate": "run_video_generate",
|
||||
}
|
||||
|
||||
|
||||
def _dispatch(payload: dict) -> None:
|
||||
"""payload[job_type] → provider 함수 호출 (sync, asyncio.to_thread로 래핑)."""
|
||||
import sys
|
||||
_self = sys.modules[__name__]
|
||||
job_type = payload.get("job_type", "")
|
||||
task_id = payload.get("task_id", "")
|
||||
params = payload.get("params", {})
|
||||
fn_name = _DISPATCH_TABLE.get(job_type)
|
||||
if fn_name is None:
|
||||
logger.error("unknown job_type=%s task=%s", job_type, task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"unknown job_type: {job_type}")
|
||||
return
|
||||
try:
|
||||
fn = getattr(_self, fn_name)
|
||||
except AttributeError:
|
||||
logger.error("dispatch table typo for job_type=%s name=%s task=%s", job_type, fn_name, task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"internal dispatch error: {fn_name}")
|
||||
return
|
||||
fn(task_id, params)
|
||||
|
||||
|
||||
async def worker_loop():
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
logger.info("music-render worker started (queue=%s)", QUEUE_KEY)
|
||||
while True:
|
||||
try:
|
||||
paused = await redis.get(PAUSED_KEY)
|
||||
if paused == b"1":
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
item = await redis.blpop(QUEUE_KEY, timeout=1)
|
||||
if item is None:
|
||||
continue
|
||||
_, raw = item
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("invalid queue payload: %r", raw[:200])
|
||||
continue
|
||||
# sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("worker_loop cancelled")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("worker_loop iteration 실패, 5초 후 재시도")
|
||||
await asyncio.sleep(5)
|
||||
11
services/task-watcher/.env.example
Normal file
11
services/task-watcher/.env.example
Normal file
@@ -0,0 +1,11 @@
|
||||
# Plan-B-Infra — task-watcher
|
||||
|
||||
# NAS Redis
|
||||
REDIS_URL=redis://192.168.45.54:6379
|
||||
|
||||
# NAS stock holidays endpoint
|
||||
STOCK_BASE_URL=http://192.168.45.54:18500
|
||||
|
||||
# 트레이딩 윈도우 (KST, HH:MM) — 이 시간대에만 queue:paused
|
||||
TRADING_START=07:00
|
||||
TRADING_END=16:30
|
||||
16
services/task-watcher/Dockerfile
Normal file
16
services/task-watcher/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.12-slim-bookworm
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates tzdata \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
83
services/task-watcher/NSSM_SETUP.md
Normal file
83
services/task-watcher/NSSM_SETUP.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# NSSM 자동 시작 설정 (SP-9)
|
||||
|
||||
Windows AI 머신 부팅 시 ai_trade(트레이딩) + WSL2 Docker(render workers + task-watcher) 자동 시작.
|
||||
|
||||
## 1. NSSM 다운로드
|
||||
|
||||
https://nssm.cc/download → nssm-2.24.zip → `C:\nssm\nssm.exe` 배치 (또는 PATH 등록).
|
||||
|
||||
## 2. ai_trade (Native Python, HIGH priority)
|
||||
|
||||
⚠️ spec의 signal_v2는 ai_trade로 rename됨. 경로/포트 확인.
|
||||
|
||||
```powershell
|
||||
# 관리자 PowerShell
|
||||
C:\nssm\nssm.exe install ai_trade "C:\Python312\python.exe" "-m uvicorn main:app --host 0.0.0.0 --port 8001"
|
||||
C:\nssm\nssm.exe set ai_trade AppDirectory "C:\Users\jaeoh\Desktop\workspace\web-ai\ai_trade"
|
||||
C:\nssm\nssm.exe set ai_trade Priority HIGH_PRIORITY_CLASS
|
||||
C:\nssm\nssm.exe set ai_trade Start SERVICE_AUTO_START
|
||||
C:\nssm\nssm.exe set ai_trade AppStdout "C:\Users\jaeoh\nssm-logs\ai_trade.log"
|
||||
C:\nssm\nssm.exe set ai_trade AppStderr "C:\Users\jaeoh\nssm-logs\ai_trade.log"
|
||||
```
|
||||
|
||||
(ai_trade의 실제 진입점이 main:app + port 8001인지 확인. 다르면 조정.)
|
||||
|
||||
## 3. WSL2 Docker (NORMAL priority — render workers + task-watcher)
|
||||
|
||||
```powershell
|
||||
C:\nssm\nssm.exe install wsl_docker "C:\Windows\System32\wsl.exe" "-d Ubuntu-24.04 -- sh -c 'sudo service docker start && cd /workspace/web-ai/services && docker compose up -d'"
|
||||
C:\nssm\nssm.exe set wsl_docker Priority NORMAL_PRIORITY_CLASS
|
||||
C:\nssm\nssm.exe set wsl_docker Start SERVICE_AUTO_START
|
||||
C:\nssm\nssm.exe set wsl_docker AppStdout "C:\Users\jaeoh\nssm-logs\wsl_docker.log"
|
||||
```
|
||||
|
||||
⚠️ 변경점: Ubuntu-22.04 → **Ubuntu-24.04**, web-ai-services → **web-ai/services**. WSL 경로는 박재오 WSL 마운트 기준 (`/workspace`가 web-ai에 매핑되어 있으면 그대로, 아니면 `/mnt/c/Users/jaeoh/Desktop/workspace/web-ai/services`).
|
||||
|
||||
`sudo service docker start`가 비밀번호 요구하면 sudoers에 NOPASSWD 추가:
|
||||
```bash
|
||||
# WSL2 안
|
||||
echo "$USER ALL=(ALL) NOPASSWD: /usr/sbin/service docker start" | sudo tee /etc/sudoers.d/docker-start
|
||||
```
|
||||
|
||||
## 4. 서비스 시작 + 확인
|
||||
|
||||
```powershell
|
||||
C:\nssm\nssm.exe start ai_trade
|
||||
C:\nssm\nssm.exe start wsl_docker
|
||||
|
||||
# 상태 확인
|
||||
C:\nssm\nssm.exe status ai_trade
|
||||
C:\nssm\nssm.exe status wsl_docker
|
||||
sc query ai_trade
|
||||
```
|
||||
|
||||
## 5. 검증
|
||||
|
||||
```powershell
|
||||
# ai_trade
|
||||
curl http://localhost:8001/health # 또는 ai_trade의 실제 health endpoint
|
||||
|
||||
# WSL2 docker 컨테이너 (재부팅 후 자동 시작 확인)
|
||||
wsl -d Ubuntu-24.04 -- docker ps
|
||||
# insta-render, music-render, video-render, task-watcher 4개 Up 확인
|
||||
```
|
||||
|
||||
## 6. 재부팅 테스트
|
||||
|
||||
Windows 재부팅 → 로그인 → 수동 조작 없이:
|
||||
- ai_trade 서비스 자동 시작 (HIGH priority)
|
||||
- WSL2 + Docker + 4 컨테이너 자동 시작 (NORMAL priority)
|
||||
- task-watcher가 trading window에 queue:paused 토글 시작
|
||||
|
||||
## task-watcher 동작 확인
|
||||
|
||||
```bash
|
||||
# WSL2
|
||||
docker logs task-watcher --tail 20
|
||||
# 기대: "task-watcher started" + mode 전환 로그 (trading/free)
|
||||
|
||||
# Redis 큐 상태 (NAS 또는 LAN)
|
||||
docker exec redis redis-cli GET queue:paused
|
||||
# 트레이딩 시간대(평일 07:00-16:30): "1"
|
||||
# 그 외: (nil)
|
||||
```
|
||||
36
services/task-watcher/main.py
Normal file
36
services/task-watcher/main.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""task-watcher FastAPI entry — health + lifespan (watcher loop spawn)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
import watcher
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
watcher_task = asyncio.create_task(watcher.watcher_loop())
|
||||
logger.info("task-watcher lifespan 시작")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
watcher_task.cancel()
|
||||
try:
|
||||
await watcher_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("task-watcher lifespan 종료")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"ok": True, "service": "task-watcher"}
|
||||
57
services/task-watcher/mode.py
Normal file
57
services/task-watcher/mode.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""시간대 + 휴장일 기반 모드 판정 (idle 감지 생략 — 박재오 결정 2026-05-22).
|
||||
|
||||
trading: 비휴장 평일 07:00–16:30 (장중) → queue:paused SET
|
||||
free: 그 외 (장 전/후, 주말, 휴장) → queue:paused DEL
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
import logging
|
||||
import os
|
||||
from typing import Set
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
STOCK_BASE_URL = os.getenv("STOCK_BASE_URL", "http://192.168.45.54:18500")
|
||||
|
||||
# 트레이딩 윈도우 (HH:MM, KST). .env로 조정 가능.
|
||||
TRADING_START = os.getenv("TRADING_START", "07:00")
|
||||
TRADING_END = os.getenv("TRADING_END", "16:30")
|
||||
|
||||
|
||||
def _parse_hhmm(s: str) -> dt.time:
|
||||
hh, mm = s.split(":")
|
||||
return dt.time(int(hh), int(mm))
|
||||
|
||||
|
||||
def current_mode(now: dt.datetime, holidays: Set[str]) -> str:
|
||||
"""now(KST aware) + holidays(ISO date set) → 'trading' | 'free'."""
|
||||
# 주말 (토=5, 일=6)
|
||||
if now.weekday() >= 5:
|
||||
return "free"
|
||||
# 휴장일
|
||||
if now.date().isoformat() in holidays:
|
||||
return "free"
|
||||
# 트레이딩 윈도우 [start, end)
|
||||
start = _parse_hhmm(TRADING_START)
|
||||
end = _parse_hhmm(TRADING_END)
|
||||
t = now.timetz().replace(tzinfo=None)
|
||||
if start <= t < end:
|
||||
return "trading"
|
||||
return "free"
|
||||
|
||||
|
||||
def fetch_holidays() -> Set[str]:
|
||||
"""NAS stock /api/stock/holidays 조회. 실패 시 빈 set (안전 — free로 판정)."""
|
||||
try:
|
||||
r = httpx.get(f"{STOCK_BASE_URL}/api/stock/holidays", timeout=10.0)
|
||||
if r.status_code == 200:
|
||||
return set(r.json().get("holidays", []))
|
||||
logger.warning("holidays fetch returned %d", r.status_code)
|
||||
except Exception:
|
||||
logger.exception("holidays fetch 실패")
|
||||
return set()
|
||||
5
services/task-watcher/requirements.txt
Normal file
5
services/task-watcher/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
fastapi==0.115.6
|
||||
uvicorn[standard]==0.34.0
|
||||
redis>=5.0
|
||||
httpx>=0.27
|
||||
pytest>=8.0
|
||||
0
services/task-watcher/tests/__init__.py
Normal file
0
services/task-watcher/tests/__init__.py
Normal file
44
services/task-watcher/tests/test_mode.py
Normal file
44
services/task-watcher/tests/test_mode.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""current_mode — 시간대 + 휴장일 판정 (순수 함수)."""
|
||||
import datetime as dt
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from mode import current_mode
|
||||
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
HOLIDAYS = {"2026-05-25"} # 가상 휴장일 (월요일)
|
||||
|
||||
|
||||
def _kst(y, m, d, hh, mm):
|
||||
return dt.datetime(y, m, d, hh, mm, tzinfo=KST)
|
||||
|
||||
|
||||
def test_weekday_trading_hours_is_trading():
|
||||
# 2026-05-22 금요일 10:00 — 트레이딩 시간대
|
||||
assert current_mode(_kst(2026, 5, 22, 10, 0), HOLIDAYS) == "trading"
|
||||
|
||||
|
||||
def test_weekday_before_open_is_free():
|
||||
# 평일 06:00 — 장 전
|
||||
assert current_mode(_kst(2026, 5, 22, 6, 0), HOLIDAYS) == "free"
|
||||
|
||||
|
||||
def test_weekday_after_close_is_free():
|
||||
# 평일 17:00 — 장 마감 후
|
||||
assert current_mode(_kst(2026, 5, 22, 17, 0), HOLIDAYS) == "free"
|
||||
|
||||
|
||||
def test_weekend_is_free():
|
||||
# 2026-05-23 토요일 10:00
|
||||
assert current_mode(_kst(2026, 5, 23, 10, 0), HOLIDAYS) == "free"
|
||||
|
||||
|
||||
def test_holiday_weekday_is_free():
|
||||
# 2026-05-25 월요일이지만 휴장일 → 트레이딩 시간대라도 free
|
||||
assert current_mode(_kst(2026, 5, 25, 10, 0), HOLIDAYS) == "free"
|
||||
|
||||
|
||||
def test_trading_boundary_inclusive_start_exclusive_end():
|
||||
# 07:00 정각 = 트레이딩 시작, 16:30 정각 = 마감 (16:30은 free)
|
||||
assert current_mode(_kst(2026, 5, 22, 7, 0), HOLIDAYS) == "trading"
|
||||
assert current_mode(_kst(2026, 5, 22, 16, 29), HOLIDAYS) == "trading"
|
||||
assert current_mode(_kst(2026, 5, 22, 16, 30), HOLIDAYS) == "free"
|
||||
59
services/task-watcher/watcher.py
Normal file
59
services/task-watcher/watcher.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""30초마다 current_mode 판정 → queue:paused 토글.
|
||||
|
||||
trading → SET queue:paused 1 EX 600 (10분 TTL — watcher 죽어도 자동 해제)
|
||||
free → DEL queue:paused
|
||||
holidays는 1시간마다 refresh (매 loop fetch 부하 회피).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import logging
|
||||
import os
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from mode import current_mode, fetch_holidays, KST
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
PAUSED_KEY = "queue:paused"
|
||||
LOOP_INTERVAL = 30 # 초
|
||||
HOLIDAYS_REFRESH = 3600 # 1시간
|
||||
PAUSED_TTL = 600 # 10분 (watcher 죽어도 자동 해제)
|
||||
|
||||
|
||||
async def watcher_loop():
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
holidays = fetch_holidays()
|
||||
last_holiday_refresh = dt.datetime.now(KST)
|
||||
last_mode = None
|
||||
logger.info("task-watcher started (trading window 토글)")
|
||||
|
||||
while True:
|
||||
try:
|
||||
now = dt.datetime.now(KST)
|
||||
# holidays 주기적 refresh
|
||||
if (now - last_holiday_refresh).total_seconds() >= HOLIDAYS_REFRESH:
|
||||
holidays = fetch_holidays()
|
||||
last_holiday_refresh = now
|
||||
|
||||
mode = current_mode(now, holidays)
|
||||
if mode == "trading":
|
||||
await redis.set(PAUSED_KEY, b"1", ex=PAUSED_TTL)
|
||||
else:
|
||||
await redis.delete(PAUSED_KEY)
|
||||
|
||||
if mode != last_mode:
|
||||
logger.info("mode 전환: %s → %s (paused=%s)", last_mode, mode, mode == "trading")
|
||||
last_mode = mode
|
||||
|
||||
await asyncio.sleep(LOOP_INTERVAL)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("watcher_loop cancelled")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("watcher_loop iteration 실패, 30초 후 재시도")
|
||||
await asyncio.sleep(LOOP_INTERVAL)
|
||||
27
services/video-render/.env.example
Normal file
27
services/video-render/.env.example
Normal file
@@ -0,0 +1,27 @@
|
||||
# Plan-B-Video — Windows video-render worker
|
||||
|
||||
# NAS Redis 큐
|
||||
REDIS_URL=redis://192.168.45.54:6379
|
||||
|
||||
# NAS internal webhook (video-lab port 18801)
|
||||
NAS_BASE_URL=http://192.168.45.54:18801
|
||||
INTERNAL_API_KEY=__copy_from_nas_dotenv__
|
||||
|
||||
# Sora 2 (OpenAI)
|
||||
OPENAI_API_KEY=__paste_openai_key__
|
||||
|
||||
# Veo (Google Gemini API — ai.google.dev. Vertex AI 경로 아님, GCS bucket 불필요)
|
||||
GEMINI_API_KEY=__paste_gemini_key__
|
||||
|
||||
# Kling (Native KlingAI — JWT auth with Access Key + Secret Key)
|
||||
KLING_ACCESS_KEY=__paste_kling_access_key__
|
||||
KLING_SECRET_KEY=__paste_kling_secret_key__
|
||||
|
||||
# Seedance 2.0 (BytePlus)
|
||||
SEEDANCE_API_KEY=__paste_seedance_key__
|
||||
|
||||
# NAS SMB mount 안의 video 디렉토리
|
||||
VIDEO_MEDIA_ROOT=/mnt/nas/webpage/data/video
|
||||
|
||||
# nginx 서빙 prefix (NAS webhook payload용)
|
||||
VIDEO_MEDIA_URL_PREFIX=/media/video
|
||||
16
services/video-render/Dockerfile
Normal file
16
services/video-render/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.12-slim-bookworm
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
36
services/video-render/main.py
Normal file
36
services/video-render/main.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""video-render FastAPI entry — health + lifespan (worker loop spawn)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
import worker
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
worker_task = asyncio.create_task(worker.worker_loop())
|
||||
logger.info("video-render lifespan 시작")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
worker_task.cancel()
|
||||
try:
|
||||
await worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("video-render lifespan 종료")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"ok": True, "service": "video-render"}
|
||||
54
services/video-render/nas_client.py
Normal file
54
services/video-render/nas_client.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""NAS webhook 어댑터 — Windows worker가 NAS DB 직접 접근 못하므로 HTTP로 위임.
|
||||
|
||||
Plan-B-Music nas_client와 동일 패턴 (call-time os.getenv으로 테스트 격리).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TIMEOUT = 10.0
|
||||
|
||||
|
||||
def _post(payload: Dict[str, Any]) -> None:
|
||||
nas_base_url = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18801")
|
||||
internal_api_key = os.getenv("INTERNAL_API_KEY", "")
|
||||
url = f"{nas_base_url}/api/internal/video/update"
|
||||
try:
|
||||
r = httpx.post(
|
||||
url,
|
||||
headers={"X-Internal-Key": internal_api_key},
|
||||
json=payload,
|
||||
timeout=_TIMEOUT,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
logger.error("webhook %s returned %d: %s",
|
||||
payload.get("task_id"), r.status_code, r.text[:200])
|
||||
except Exception:
|
||||
logger.exception("webhook %s 호출 실패", payload.get("task_id"))
|
||||
|
||||
|
||||
def webhook_update_task(
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: int,
|
||||
message: str = "",
|
||||
video_url: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
payload: Dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
}
|
||||
if video_url is not None:
|
||||
payload["video_url"] = video_url
|
||||
if error is not None:
|
||||
payload["error"] = error
|
||||
_post(payload)
|
||||
0
services/video-render/providers/__init__.py
Normal file
0
services/video-render/providers/__init__.py
Normal file
153
services/video-render/providers/kling.py
Normal file
153
services/video-render/providers/kling.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Kling AI video generation — Native KlingAI API (api.klingai.com).
|
||||
|
||||
JWT auth: HS256, payload {iss: ACCESS_KEY, exp: now+1800, nbf: now-5}.
|
||||
POST /v1/videos/text2video → GET /v1/videos/text2video/{task_id} → task_result.videos[0].url 다운로드.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import jwt as pyjwt
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KLING_BASE_URL = "https://api.klingai.com"
|
||||
VIDEO_MEDIA_ROOT = os.getenv("VIDEO_MEDIA_ROOT", "/mnt/nas/webpage/data/video")
|
||||
VIDEO_MEDIA_URL_PREFIX = os.getenv("VIDEO_MEDIA_URL_PREFIX", "/media/video")
|
||||
|
||||
POLL_INTERVAL = 10
|
||||
POLL_MAX_ATTEMPTS = 60 # 최대 ~10분
|
||||
|
||||
DEFAULT_MODEL = "kling-v1-6"
|
||||
|
||||
JWT_EXP_SECONDS = 1800 # 30분
|
||||
JWT_NBF_OFFSET = -5 # 5초 뒤로
|
||||
|
||||
|
||||
def _generate_jwt() -> Optional[str]:
|
||||
access_key = os.getenv("KLING_ACCESS_KEY", "")
|
||||
secret_key = os.getenv("KLING_SECRET_KEY", "")
|
||||
if not access_key or not secret_key:
|
||||
return None
|
||||
now = int(time.time())
|
||||
headers = {"alg": "HS256", "typ": "JWT"}
|
||||
payload = {"iss": access_key, "exp": now + JWT_EXP_SECONDS, "nbf": now + JWT_NBF_OFFSET}
|
||||
return pyjwt.encode(payload, secret_key, algorithm="HS256", headers=headers)
|
||||
|
||||
|
||||
def _headers() -> dict:
|
||||
token = _generate_jwt()
|
||||
return {
|
||||
"Authorization": f"Bearer {token}" if token else "",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def run_kling_generation(task_id: str, params: dict) -> None:
|
||||
"""Kling으로 영상 생성 → mp4 → NAS SMB → webhook."""
|
||||
try:
|
||||
if not os.getenv("KLING_ACCESS_KEY") or not os.getenv("KLING_SECRET_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "",
|
||||
error="KLING_ACCESS_KEY 또는 KLING_SECRET_KEY 미설정")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 5, "Kling API 호출 중...")
|
||||
|
||||
# image_url 있으면 image2video, 없으면 text2video
|
||||
is_image2video = bool(params.get("image_url"))
|
||||
endpoint_path = "/v1/videos/image2video" if is_image2video else "/v1/videos/text2video"
|
||||
|
||||
body = {
|
||||
"model_name": params.get("model") or DEFAULT_MODEL,
|
||||
"prompt": params["prompt"][:2500],
|
||||
"duration": str(params.get("duration", 5)),
|
||||
"aspect_ratio": params.get("aspect_ratio", "16:9"),
|
||||
"mode": params.get("mode", "std"),
|
||||
}
|
||||
if params.get("negative_prompt"):
|
||||
body["negative_prompt"] = params["negative_prompt"][:2500]
|
||||
if params.get("cfg_scale") is not None:
|
||||
body["cfg_scale"] = float(params["cfg_scale"])
|
||||
if is_image2video:
|
||||
body["image"] = params["image_url"]
|
||||
|
||||
resp = requests.post(f"{KLING_BASE_URL}{endpoint_path}",
|
||||
headers=_headers(), json=body, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "",
|
||||
error=f"Kling API 오류: {resp.status_code} {resp.text[:300]}")
|
||||
return
|
||||
|
||||
body_json = resp.json()
|
||||
if body_json.get("code") != 0:
|
||||
webhook_update_task(task_id, "failed", 0, "",
|
||||
error=f"Kling API 거부: {body_json.get('message', '?')}")
|
||||
return
|
||||
|
||||
kling_task_id = (body_json.get("data") or {}).get("task_id", "")
|
||||
if not kling_task_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Kling 응답에 task_id 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 15, "Kling 작업 등록됨")
|
||||
|
||||
# 폴링 — GET /v1/videos/{text2video|image2video}/{task_id}
|
||||
video_url = None
|
||||
for attempt in range(POLL_MAX_ATTEMPTS):
|
||||
time.sleep(POLL_INTERVAL)
|
||||
fetch = requests.get(f"{KLING_BASE_URL}{endpoint_path}/{kling_task_id}",
|
||||
headers=_headers(), timeout=30)
|
||||
if fetch.status_code != 200:
|
||||
continue
|
||||
fd = fetch.json()
|
||||
if fd.get("code") != 0:
|
||||
continue
|
||||
data = fd.get("data") or {}
|
||||
status = data.get("task_status", "")
|
||||
scaled = min(15 + int((attempt / POLL_MAX_ATTEMPTS) * 65), 79)
|
||||
webhook_update_task(task_id, "processing", scaled, f"Kling 생성 중... ({status})")
|
||||
|
||||
if status == "succeed":
|
||||
videos = ((data.get("task_result") or {}).get("videos") or [])
|
||||
if videos:
|
||||
video_url = videos[0].get("url", "")
|
||||
break
|
||||
elif status == "failed":
|
||||
err = data.get("task_status_msg") or "Kling 작업 실패"
|
||||
webhook_update_task(task_id, "failed", 0, "", error=err)
|
||||
return
|
||||
# submitted/processing → 계속 폴링
|
||||
else:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Kling 폴링 timeout (10분)")
|
||||
return
|
||||
|
||||
if not video_url:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Kling 완료했으나 video url 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 85, "Kling 결과 다운로드 중...")
|
||||
filename = f"{task_id}.mp4"
|
||||
os.makedirs(VIDEO_MEDIA_ROOT, exist_ok=True)
|
||||
file_path = os.path.join(VIDEO_MEDIA_ROOT, filename)
|
||||
|
||||
# Kling 결과 url은 일반적으로 인증 불필요 (signed URL)
|
||||
dl = requests.get(video_url, stream=True, timeout=300)
|
||||
dl.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in dl.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
local_url = f"{VIDEO_MEDIA_URL_PREFIX}/{filename}"
|
||||
webhook_update_task(task_id, "succeeded", 100, "Kling 생성 완료", video_url=local_url)
|
||||
|
||||
except requests.Timeout:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Kling API 타임아웃")
|
||||
except Exception as e:
|
||||
logger.exception("Kling generation error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
121
services/video-render/providers/seedance.py
Normal file
121
services/video-render/providers/seedance.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Seedance 2.0 video generation — ByteDance Volcano Engine (BytePlus 국제 endpoint).
|
||||
|
||||
POST https://api.byteplus.com/seedance/v1/videos → GET /videos/{id} 폴링 → output.video_url 다운로드.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SEEDANCE_BASE_URL = "https://api.byteplus.com/seedance/v1"
|
||||
VIDEO_MEDIA_ROOT = os.getenv("VIDEO_MEDIA_ROOT", "/mnt/nas/webpage/data/video")
|
||||
VIDEO_MEDIA_URL_PREFIX = os.getenv("VIDEO_MEDIA_URL_PREFIX", "/media/video")
|
||||
|
||||
POLL_INTERVAL = 8 # Seedance는 30~120초
|
||||
POLL_MAX_ATTEMPTS = 60
|
||||
|
||||
DEFAULT_MODEL = "seedance-2.0"
|
||||
|
||||
|
||||
def _headers() -> dict:
|
||||
api_key = os.getenv("SEEDANCE_API_KEY", "")
|
||||
return {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def run_seedance_generation(task_id: str, params: dict) -> None:
|
||||
"""Seedance로 영상 생성 → mp4 → NAS SMB → webhook."""
|
||||
try:
|
||||
if not os.getenv("SEEDANCE_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="SEEDANCE_API_KEY 미설정")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 5, "Seedance API 호출 중...")
|
||||
|
||||
body = {
|
||||
"model": params.get("model") or DEFAULT_MODEL,
|
||||
"prompt": params["prompt"][:2000],
|
||||
"resolution": params.get("resolution", "1080p"),
|
||||
"duration": params.get("duration", 5),
|
||||
"aspect_ratio": params.get("aspect_ratio", "16:9"),
|
||||
}
|
||||
if params.get("negative_prompt"):
|
||||
body["negative_prompt"] = params["negative_prompt"]
|
||||
if params.get("image_url"):
|
||||
body["references"] = [{"type": "image", "data": params["image_url"], "role": "subject"}]
|
||||
if params.get("audio") is not None:
|
||||
body["audio"] = bool(params["audio"])
|
||||
if params.get("seed") is not None:
|
||||
body["seed"] = int(params["seed"])
|
||||
|
||||
resp = requests.post(f"{SEEDANCE_BASE_URL}/videos", headers=_headers(), json=body, timeout=30)
|
||||
if resp.status_code not in (200, 201):
|
||||
webhook_update_task(task_id, "failed", 0, "",
|
||||
error=f"Seedance API 오류: {resp.status_code} {resp.text[:300]}")
|
||||
return
|
||||
|
||||
body_json = resp.json()
|
||||
job_id = body_json.get("id", "")
|
||||
if not job_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Seedance 응답에 id 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 15, "Seedance 작업 등록됨")
|
||||
|
||||
# 폴링
|
||||
video_url = None
|
||||
for attempt in range(POLL_MAX_ATTEMPTS):
|
||||
time.sleep(POLL_INTERVAL)
|
||||
fetch = requests.get(f"{SEEDANCE_BASE_URL}/videos/{job_id}",
|
||||
headers=_headers(), timeout=30)
|
||||
if fetch.status_code != 200:
|
||||
continue
|
||||
fd = fetch.json()
|
||||
status = fd.get("status", "")
|
||||
scaled = min(15 + int((attempt / POLL_MAX_ATTEMPTS) * 65), 79)
|
||||
webhook_update_task(task_id, "processing", scaled, f"Seedance 생성 중... ({status})")
|
||||
|
||||
if status == "completed":
|
||||
video_url = (fd.get("output") or {}).get("video_url", "")
|
||||
break
|
||||
elif status == "failed":
|
||||
err = fd.get("error") or "Seedance 작업 실패"
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(err)[:300])
|
||||
return
|
||||
else:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Seedance 폴링 timeout (10분)")
|
||||
return
|
||||
|
||||
if not video_url:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Seedance 완료했으나 video_url 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 85, "Seedance 결과 다운로드 중...")
|
||||
filename = f"{task_id}.mp4"
|
||||
os.makedirs(VIDEO_MEDIA_ROOT, exist_ok=True)
|
||||
file_path = os.path.join(VIDEO_MEDIA_ROOT, filename)
|
||||
|
||||
dl = requests.get(video_url, stream=True, timeout=300)
|
||||
dl.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in dl.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
local_url = f"{VIDEO_MEDIA_URL_PREFIX}/{filename}"
|
||||
webhook_update_task(task_id, "succeeded", 100, "Seedance 생성 완료", video_url=local_url)
|
||||
|
||||
except requests.Timeout:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Seedance API 타임아웃")
|
||||
except Exception as e:
|
||||
logger.exception("Seedance generation error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
119
services/video-render/providers/sora.py
Normal file
119
services/video-render/providers/sora.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Sora 2 video generation — OpenAI Videos API.
|
||||
|
||||
POST /v1/videos → poll GET /v1/videos/{id} → GET /v1/videos/{id}/content download.
|
||||
⚠️ Deprecated, shutdown 2026-09-24. Spec 진행은 박재오 결정 따름.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SORA_BASE_URL = "https://api.openai.com/v1"
|
||||
VIDEO_MEDIA_ROOT = os.getenv("VIDEO_MEDIA_ROOT", "/mnt/nas/webpage/data/video")
|
||||
VIDEO_MEDIA_URL_PREFIX = os.getenv("VIDEO_MEDIA_URL_PREFIX", "/media/video")
|
||||
|
||||
POLL_INTERVAL = 15 # OpenAI 권장: 10~20초
|
||||
POLL_MAX_ATTEMPTS = 40 # 최대 ~10분
|
||||
|
||||
DEFAULT_MODEL = "sora-2"
|
||||
|
||||
|
||||
def _headers() -> dict:
|
||||
api_key = os.getenv("OPENAI_API_KEY", "")
|
||||
return {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def run_sora_generation(task_id: str, params: dict) -> None:
|
||||
"""Sora 2로 영상 생성 → mp4 → NAS SMB 저장 → webhook."""
|
||||
try:
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="OPENAI_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 5, "Sora API 호출 중...")
|
||||
|
||||
payload = {
|
||||
"model": params.get("model") or DEFAULT_MODEL,
|
||||
"prompt": params["prompt"][:5000],
|
||||
}
|
||||
if params.get("duration"):
|
||||
payload["seconds"] = params["duration"]
|
||||
if params.get("size"):
|
||||
payload["size"] = params["size"]
|
||||
elif params.get("aspect_ratio") == "9:16":
|
||||
payload["size"] = "1080x1920"
|
||||
elif params.get("aspect_ratio") == "16:9":
|
||||
payload["size"] = "1920x1080"
|
||||
|
||||
resp = requests.post(f"{SORA_BASE_URL}/videos", headers=_headers(), json=payload, timeout=30)
|
||||
if resp.status_code not in (200, 201):
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Sora API 오류: {resp.status_code} {resp.text[:300]}")
|
||||
return
|
||||
|
||||
body = resp.json()
|
||||
video_id = body.get("id", "")
|
||||
if not video_id:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Sora 응답에 video id 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 15, f"Sora 작업 생성됨 (id={video_id[:16]})")
|
||||
|
||||
# 폴링
|
||||
for attempt in range(POLL_MAX_ATTEMPTS):
|
||||
time.sleep(POLL_INTERVAL)
|
||||
sr = requests.get(f"{SORA_BASE_URL}/videos/{video_id}", headers=_headers(), timeout=30)
|
||||
if sr.status_code != 200:
|
||||
continue
|
||||
sd = sr.json()
|
||||
status = sd.get("status", "")
|
||||
progress = sd.get("progress", 0)
|
||||
scaled = min(15 + int(progress * 0.65), 79)
|
||||
webhook_update_task(task_id, "processing", scaled, f"Sora 생성 중... {progress}%")
|
||||
|
||||
if status == "completed":
|
||||
break
|
||||
elif status == "failed":
|
||||
err = sd.get("error", {}).get("message", "Sora 작업 실패")
|
||||
webhook_update_task(task_id, "failed", 0, "", error=err)
|
||||
return
|
||||
else:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Sora 폴링 timeout (10분)")
|
||||
return
|
||||
|
||||
# 다운로드
|
||||
webhook_update_task(task_id, "processing", 80, "Sora 결과 다운로드 중...")
|
||||
filename = f"{task_id}.mp4"
|
||||
os.makedirs(VIDEO_MEDIA_ROOT, exist_ok=True)
|
||||
file_path = os.path.join(VIDEO_MEDIA_ROOT, filename)
|
||||
|
||||
dl = requests.get(
|
||||
f"{SORA_BASE_URL}/videos/{video_id}/content",
|
||||
headers=_headers(),
|
||||
params={"variant": "video"},
|
||||
stream=True,
|
||||
timeout=300,
|
||||
)
|
||||
dl.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in dl.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
local_url = f"{VIDEO_MEDIA_URL_PREFIX}/{filename}"
|
||||
webhook_update_task(task_id, "succeeded", 100, "Sora 생성 완료", video_url=local_url)
|
||||
|
||||
except requests.Timeout:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Sora API 타임아웃")
|
||||
except Exception as e:
|
||||
logger.exception("Sora generation error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
139
services/video-render/providers/veo.py
Normal file
139
services/video-render/providers/veo.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Veo 3.1 video generation — Gemini API (ai.google.dev).
|
||||
|
||||
POST https://generativelanguage.googleapis.com/v1beta/models/{MODEL}:predictLongRunning
|
||||
GET https://generativelanguage.googleapis.com/v1beta/{operation_name}
|
||||
→ done=true 시 response.generateVideoResponse.generatedSamples[0].video.uri 다운로드
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
VIDEO_MEDIA_ROOT = os.getenv("VIDEO_MEDIA_ROOT", "/mnt/nas/webpage/data/video")
|
||||
VIDEO_MEDIA_URL_PREFIX = os.getenv("VIDEO_MEDIA_URL_PREFIX", "/media/video")
|
||||
|
||||
POLL_INTERVAL = 10 # Veo는 30~120초 소요
|
||||
POLL_MAX_ATTEMPTS = 60 # 최대 ~10분
|
||||
|
||||
DEFAULT_MODEL = "veo-3.1-fast-generate-preview"
|
||||
|
||||
|
||||
def _headers() -> dict:
|
||||
api_key = os.getenv("GEMINI_API_KEY", "")
|
||||
return {
|
||||
"x-goog-api-key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def run_veo_generation(task_id: str, params: dict) -> None:
|
||||
"""Veo로 영상 생성 → mp4 → NAS SMB → webhook."""
|
||||
try:
|
||||
if not os.getenv("GEMINI_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="GEMINI_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 5, "Veo (Gemini API) 호출 중...")
|
||||
|
||||
model_id = params.get("model") or DEFAULT_MODEL
|
||||
body = {
|
||||
"instances": [{"prompt": params["prompt"]}],
|
||||
"parameters": {
|
||||
"aspectRatio": params.get("aspect_ratio") or "16:9",
|
||||
},
|
||||
}
|
||||
# numberOfVideos는 일부 모델(veo-3.0-fast 등) 미지원 — 호출자 명시 시에만 추가
|
||||
if params.get("number_of_videos"):
|
||||
body["parameters"]["numberOfVideos"] = int(params["number_of_videos"])
|
||||
if params.get("duration"):
|
||||
body["parameters"]["durationSeconds"] = int(params["duration"])
|
||||
if params.get("resolution"):
|
||||
body["parameters"]["resolution"] = params["resolution"]
|
||||
if params.get("negative_prompt"):
|
||||
body["parameters"]["negativePrompt"] = params["negative_prompt"]
|
||||
if params.get("person_generation"):
|
||||
body["parameters"]["personGeneration"] = params["person_generation"]
|
||||
|
||||
resp = requests.post(
|
||||
f"{GEMINI_BASE_URL}/models/{model_id}:predictLongRunning",
|
||||
headers=_headers(), json=body, timeout=30,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "",
|
||||
error=f"Veo Gemini API 오류: {resp.status_code} {resp.text[:300]}")
|
||||
return
|
||||
|
||||
op_name = resp.json().get("name", "")
|
||||
if not op_name:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Veo 응답에 operation name 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 15, "Veo 작업 시작됨")
|
||||
|
||||
# 폴링 — GET /v1beta/{operation_name}
|
||||
video_uri = None
|
||||
for attempt in range(POLL_MAX_ATTEMPTS):
|
||||
time.sleep(POLL_INTERVAL)
|
||||
fetch = requests.get(
|
||||
f"{GEMINI_BASE_URL}/{op_name}",
|
||||
headers=_headers(),
|
||||
timeout=30,
|
||||
)
|
||||
if fetch.status_code != 200:
|
||||
continue
|
||||
fd = fetch.json()
|
||||
done = fd.get("done", False)
|
||||
scaled = min(15 + int((attempt / POLL_MAX_ATTEMPTS) * 65), 79)
|
||||
webhook_update_task(task_id, "processing", scaled, "Veo 생성 중...")
|
||||
|
||||
if done:
|
||||
if "error" in fd:
|
||||
webhook_update_task(task_id, "failed", 0, "",
|
||||
error=f"Veo 작업 실패: {fd['error'].get('message','?')}")
|
||||
return
|
||||
# response.generateVideoResponse.generatedSamples[0].video.uri
|
||||
response = fd.get("response") or {}
|
||||
gen = response.get("generateVideoResponse") or {}
|
||||
samples = gen.get("generatedSamples") or []
|
||||
if not samples:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Veo 완료했으나 generatedSamples 비어 있음")
|
||||
return
|
||||
video_uri = (samples[0].get("video") or {}).get("uri", "")
|
||||
break
|
||||
else:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Veo 폴링 timeout (10분)")
|
||||
return
|
||||
|
||||
if not video_uri:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Veo 응답에 video.uri 없음")
|
||||
return
|
||||
|
||||
webhook_update_task(task_id, "processing", 85, "Veo 결과 다운로드 중...")
|
||||
filename = f"{task_id}.mp4"
|
||||
os.makedirs(VIDEO_MEDIA_ROOT, exist_ok=True)
|
||||
file_path = os.path.join(VIDEO_MEDIA_ROOT, filename)
|
||||
|
||||
# 다운로드 — x-goog-api-key 헤더 그대로 사용 (Gemini API가 인증 처리)
|
||||
dl = requests.get(video_uri, headers=_headers(), stream=True, timeout=300)
|
||||
dl.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
for chunk in dl.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
local_url = f"{VIDEO_MEDIA_URL_PREFIX}/{filename}"
|
||||
webhook_update_task(task_id, "succeeded", 100, "Veo 생성 완료", video_url=local_url)
|
||||
|
||||
except requests.Timeout:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Veo API 타임아웃")
|
||||
except Exception as e:
|
||||
logger.exception("Veo generation error task=%s", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
10
services/video-render/requirements.txt
Normal file
10
services/video-render/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
fastapi==0.115.6
|
||||
uvicorn[standard]==0.34.0
|
||||
requests==2.32.3
|
||||
redis>=5.0
|
||||
httpx>=0.27
|
||||
openai>=1.50.0
|
||||
PyJWT>=2.8.0
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.24
|
||||
respx>=0.21
|
||||
0
services/video-render/tests/__init__.py
Normal file
0
services/video-render/tests/__init__.py
Normal file
70
services/video-render/tests/test_nas_client.py
Normal file
70
services/video-render/tests/test_nas_client.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""nas_client — webhook adapter for video-render."""
|
||||
import pytest
|
||||
import respx
|
||||
import httpx
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _env(monkeypatch):
|
||||
monkeypatch.setenv("NAS_BASE_URL", "http://nas-test:18801")
|
||||
monkeypatch.setenv("INTERNAL_API_KEY", "test-key")
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_update_task_sends_x_internal_key():
|
||||
route = respx.post("http://nas-test:18801/api/internal/video/update").mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
webhook_update_task("task-1", "processing", 30, message="downloading")
|
||||
assert route.called
|
||||
req = route.calls[0].request
|
||||
assert req.headers["X-Internal-Key"] == "test-key"
|
||||
import json
|
||||
body = json.loads(req.content)
|
||||
assert body["task_id"] == "task-1"
|
||||
assert body["status"] == "processing"
|
||||
assert body["progress"] == 30
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_update_task_with_video_url():
|
||||
route = respx.post("http://nas-test:18801/api/internal/video/update").mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
webhook_update_task("task-2", "succeeded", 100, message="완료",
|
||||
video_url="/media/video/task-2.mp4")
|
||||
import json
|
||||
payload = json.loads(route.calls[0].request.content)
|
||||
assert payload["video_url"] == "/media/video/task-2.mp4"
|
||||
assert payload["status"] == "succeeded"
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_update_task_with_error():
|
||||
route = respx.post("http://nas-test:18801/api/internal/video/update").mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
webhook_update_task("task-3", "failed", 0, error="Sora API rate limit")
|
||||
import json
|
||||
payload = json.loads(route.calls[0].request.content)
|
||||
assert payload["error"] == "Sora API rate limit"
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_swallows_network_error(caplog):
|
||||
respx.post("http://nas-test:18801/api/internal/video/update").mock(
|
||||
side_effect=httpx.ConnectError("no host")
|
||||
)
|
||||
webhook_update_task("task-5", "processing", 10)
|
||||
assert "task-5" in caplog.text
|
||||
|
||||
|
||||
@respx.mock
|
||||
def test_webhook_swallows_non_200(caplog):
|
||||
respx.post("http://nas-test:18801/api/internal/video/update").mock(
|
||||
return_value=httpx.Response(500, text="server error")
|
||||
)
|
||||
webhook_update_task("task-6", "processing", 50)
|
||||
assert "task-6" in caplog.text
|
||||
43
services/video-render/tests/test_worker.py
Normal file
43
services/video-render/tests/test_worker.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""worker.py — job_type 디스패처 (4 provider)."""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import worker
|
||||
|
||||
|
||||
def test_dispatch_sora_calls_run_sora_generation():
|
||||
payload = {"task_id": "t1", "job_type": "sora_generation", "params": {"prompt": "x"}}
|
||||
with patch("worker.run_sora_generation") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t1", {"prompt": "x"})
|
||||
|
||||
|
||||
def test_dispatch_veo_calls_run_veo_generation():
|
||||
payload = {"task_id": "t2", "job_type": "veo_generation", "params": {"prompt": "x"}}
|
||||
with patch("worker.run_veo_generation") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t2", {"prompt": "x"})
|
||||
|
||||
|
||||
def test_dispatch_kling_calls_run_kling_generation():
|
||||
payload = {"task_id": "t3", "job_type": "kling_generation", "params": {"prompt": "x"}}
|
||||
with patch("worker.run_kling_generation") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t3", {"prompt": "x"})
|
||||
|
||||
|
||||
def test_dispatch_seedance_calls_run_seedance_generation():
|
||||
payload = {"task_id": "t4", "job_type": "seedance_generation", "params": {"prompt": "x"}}
|
||||
with patch("worker.run_seedance_generation") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t4", {"prompt": "x"})
|
||||
|
||||
|
||||
def test_dispatch_unknown_job_type_logs_error():
|
||||
payload = {"task_id": "t5", "job_type": "weird_type", "params": {}}
|
||||
with patch("worker.webhook_update_task") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once()
|
||||
args = m.call_args[0]
|
||||
assert args[0] == "t5"
|
||||
assert args[1] == "failed"
|
||||
80
services/video-render/worker.py
Normal file
80
services/video-render/worker.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Redis BLPOP worker — queue:video-render → job_type 디스패치 → NAS webhook.
|
||||
|
||||
queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
|
||||
Plan-B-Music worker.py 패턴 — string-based dispatch + getattr (테스트 patch 호환).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers.sora import run_sora_generation
|
||||
from providers.veo import run_veo_generation
|
||||
from providers.kling import run_kling_generation
|
||||
from providers.seedance import run_seedance_generation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
QUEUE_KEY = "queue:video-render"
|
||||
PAUSED_KEY = "queue:paused"
|
||||
|
||||
# string names so `unittest.mock.patch` on `worker.<name>` is correctly intercepted
|
||||
_DISPATCH_TABLE = {
|
||||
"sora_generation": "run_sora_generation",
|
||||
"veo_generation": "run_veo_generation",
|
||||
"kling_generation": "run_kling_generation",
|
||||
"seedance_generation": "run_seedance_generation",
|
||||
}
|
||||
|
||||
|
||||
def _dispatch(payload: dict) -> None:
|
||||
"""payload[job_type] → provider 함수 호출 (sync, worker_loop에서 asyncio.to_thread로 wrap)."""
|
||||
job_type = payload.get("job_type", "")
|
||||
task_id = payload.get("task_id", "")
|
||||
params = payload.get("params", {})
|
||||
fn_name = _DISPATCH_TABLE.get(job_type)
|
||||
if fn_name is None:
|
||||
logger.error("unknown job_type=%s task=%s", job_type, task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"unknown job_type: {job_type}")
|
||||
return
|
||||
try:
|
||||
fn = getattr(sys.modules[__name__], fn_name)
|
||||
except AttributeError:
|
||||
logger.error("dispatch table typo for job_type=%s name=%s task=%s", job_type, fn_name, task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"internal dispatch error: {fn_name}")
|
||||
return
|
||||
fn(task_id, params)
|
||||
|
||||
|
||||
async def worker_loop():
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
logger.info("video-render worker started (queue=%s)", QUEUE_KEY)
|
||||
while True:
|
||||
try:
|
||||
paused = await redis.get(PAUSED_KEY)
|
||||
if paused == b"1":
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
item = await redis.blpop(QUEUE_KEY, timeout=1)
|
||||
if item is None:
|
||||
continue
|
||||
_, raw = item
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("invalid queue payload: %r", raw[:200])
|
||||
continue
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("worker_loop cancelled")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("worker_loop iteration 실패, 5초 후 재시도")
|
||||
await asyncio.sleep(5)
|
||||
Reference in New Issue
Block a user