Compare commits
6 Commits
27a6df6cff
...
cb70226f42
| Author | SHA1 | Date | |
|---|---|---|---|
| cb70226f42 | |||
| de24bae984 | |||
| 0e6c893b4e | |||
| fb80973e38 | |||
| 31b0e7dbc4 | |||
| 6169f48eb8 |
@@ -95,3 +95,30 @@ services:
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
image-render:
|
||||
build:
|
||||
context: ./image-render
|
||||
container_name: image-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18714:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- NAS_BASE_URL=${NAS_BASE_URL:-http://192.168.45.54:18802}
|
||||
- INTERNAL_API_KEY=${INTERNAL_API_KEY:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
|
||||
- COMFYUI_URL=${COMFYUI_URL:-http://host.docker.internal:8188}
|
||||
- FLUX_BLOCK_TRADING_HOURS=${FLUX_BLOCK_TRADING_HOURS:-1}
|
||||
- IMAGE_MEDIA_ROOT=${IMAGE_MEDIA_ROOT:-/mnt/nas/webpage/data/image}
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- /mnt/nas/webpage/data/image:/mnt/nas/webpage/data/image
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
16
services/image-render/Dockerfile
Normal file
16
services/image-render/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.12-slim-bookworm
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
18
services/image-render/env.example
Normal file
18
services/image-render/env.example
Normal file
@@ -0,0 +1,18 @@
|
||||
# Redis (NAS)
|
||||
REDIS_URL=redis://192.168.45.54:6379
|
||||
|
||||
# NAS image-lab webhook
|
||||
NAS_BASE_URL=http://192.168.45.54:18802
|
||||
INTERNAL_API_KEY=replace-me
|
||||
|
||||
# API provider keys (worker reports failed if missing)
|
||||
OPENAI_API_KEY=
|
||||
GEMINI_API_KEY=
|
||||
# Seedance key not used by image-render
|
||||
|
||||
# FLUX local
|
||||
COMFYUI_URL=http://host.docker.internal:8188
|
||||
FLUX_BLOCK_TRADING_HOURS=1
|
||||
|
||||
# NAS SMB mount target (image-render writes to this, NAS reads via /media/image/)
|
||||
IMAGE_MEDIA_ROOT=/mnt/nas/webpage/data/image
|
||||
36
services/image-render/main.py
Normal file
36
services/image-render/main.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""image-render FastAPI entry — health + lifespan (worker loop spawn)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
import worker
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
worker_task = asyncio.create_task(worker.worker_loop())
|
||||
logger.info("image-render lifespan 시작")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
worker_task.cancel()
|
||||
try:
|
||||
await worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("image-render lifespan 종료")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"ok": True, "service": "image-render"}
|
||||
54
services/image-render/nas_client.py
Normal file
54
services/image-render/nas_client.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""NAS webhook 어댑터 — Windows worker → NAS image-lab HTTP 위임.
|
||||
|
||||
video-render nas_client 복제 (call-time os.getenv으로 테스트 격리).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TIMEOUT = 10.0
|
||||
|
||||
|
||||
def _post(payload: Dict[str, Any]) -> None:
|
||||
nas_base_url = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18802")
|
||||
internal_api_key = os.getenv("INTERNAL_API_KEY", "")
|
||||
url = f"{nas_base_url}/api/internal/image/update"
|
||||
try:
|
||||
r = httpx.post(
|
||||
url,
|
||||
headers={"X-Internal-Key": internal_api_key},
|
||||
json=payload,
|
||||
timeout=_TIMEOUT,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
logger.error("webhook %s returned %d: %s",
|
||||
payload.get("task_id"), r.status_code, r.text[:200])
|
||||
except Exception:
|
||||
logger.exception("webhook %s 호출 실패", payload.get("task_id"))
|
||||
|
||||
|
||||
def webhook_update_task(
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: int,
|
||||
message: str = "",
|
||||
image_url: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
payload: Dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
}
|
||||
if image_url is not None:
|
||||
payload["image_url"] = image_url
|
||||
if error is not None:
|
||||
payload["error"] = error
|
||||
_post(payload)
|
||||
0
services/image-render/providers/__init__.py
Normal file
0
services/image-render/providers/__init__.py
Normal file
18
services/image-render/providers/_media.py
Normal file
18
services/image-render/providers/_media.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""b64 이미지 → NAS SMB 경로 저장 → /media/image URL 반환."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
|
||||
IMAGE_MEDIA_ROOT = os.getenv("IMAGE_MEDIA_ROOT", "/mnt/nas/webpage/data/image")
|
||||
IMAGE_MEDIA_URL_PREFIX = os.getenv("IMAGE_MEDIA_URL_PREFIX", "/media/image")
|
||||
|
||||
|
||||
def save_b64_png(task_id: str, b64_data: str) -> str:
|
||||
os.makedirs(IMAGE_MEDIA_ROOT, exist_ok=True)
|
||||
fname = f"{task_id}-{uuid.uuid4().hex[:8]}.png"
|
||||
path = os.path.join(IMAGE_MEDIA_ROOT, fname)
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(b64_data))
|
||||
return f"{IMAGE_MEDIA_URL_PREFIX}/{fname}"
|
||||
79
services/image-render/providers/flux.py
Normal file
79
services/image-render/providers/flux.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""FLUX 로컬 — ComfyUI HTTP API.
|
||||
|
||||
POST {COMFYUI_URL}/prompt (workflow JSON) → prompt_id
|
||||
GET {COMFYUI_URL}/history/{prompt_id} → outputs → image filename
|
||||
GET {COMFYUI_URL}/view?filename=... → PNG bytes → b64
|
||||
|
||||
워크플로우 JSON은 `flux_workflow.json` (ComfyUI UI에서 "Save (API Format)"로 export, CLIPTextEncode 노드 text를 "%PROMPT%"로 수동 치환). 박재오 산출물.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64, json, logging, os, time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:8188")
|
||||
WORKFLOW_PATH = os.path.join(os.path.dirname(__file__), "flux_workflow.json")
|
||||
POLL_INTERVAL = 2
|
||||
POLL_MAX = 120
|
||||
|
||||
|
||||
def _is_trading_hours() -> bool:
|
||||
kst = timezone(timedelta(hours=9))
|
||||
now = datetime.now(kst)
|
||||
if now.weekday() >= 5:
|
||||
return False
|
||||
return (now.hour, now.minute) >= (9, 0) and (now.hour, now.minute) <= (15, 30)
|
||||
|
||||
|
||||
def _load_workflow(prompt: str, size: str) -> dict:
|
||||
with open(WORKFLOW_PATH, encoding="utf-8") as f:
|
||||
wf = json.load(f)
|
||||
# CLIPTextEncode 노드의 text를 prompt로 치환 (workflow에 "%PROMPT%" placeholder 사용)
|
||||
raw = json.dumps(wf).replace("%PROMPT%", prompt.replace('"', "'"))
|
||||
return json.loads(raw)
|
||||
|
||||
|
||||
def _submit_prompt(workflow: dict) -> str:
|
||||
r = requests.post(f"{COMFYUI_URL}/prompt", json={"prompt": workflow}, timeout=30)
|
||||
r.raise_for_status()
|
||||
return r.json()["prompt_id"]
|
||||
|
||||
|
||||
def _poll_image_b64(prompt_id: str):
|
||||
for _ in range(POLL_MAX):
|
||||
h = requests.get(f"{COMFYUI_URL}/history/{prompt_id}", timeout=10)
|
||||
data = h.json().get(prompt_id)
|
||||
if data and data.get("outputs"):
|
||||
for node_out in data["outputs"].values():
|
||||
for img in node_out.get("images", []):
|
||||
view = requests.get(f"{COMFYUI_URL}/view",
|
||||
params={"filename": img["filename"], "subfolder": img.get("subfolder", ""), "type": img.get("type", "output")},
|
||||
timeout=30)
|
||||
view.raise_for_status()
|
||||
return base64.b64encode(view.content).decode()
|
||||
time.sleep(POLL_INTERVAL)
|
||||
return None
|
||||
|
||||
|
||||
def run_flux_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if os.getenv("FLUX_BLOCK_TRADING_HOURS") == "1" and _is_trading_hours():
|
||||
webhook_update_task(task_id, "failed", 0, "", error="장중 GPU 보호 — FLUX 거부 (API provider 사용 권장)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "FLUX (ComfyUI) 생성 중...")
|
||||
wf = _load_workflow(params["prompt"], params.get("size") or "1024x1024")
|
||||
pid = _submit_prompt(wf)
|
||||
b64 = _poll_image_b64(pid)
|
||||
if not b64:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="ComfyUI 타임아웃 또는 출력 없음")
|
||||
return
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("flux task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
47
services/image-render/providers/gpt_image.py
Normal file
47
services/image-render/providers/gpt_image.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""GPT Image 2.0 — OpenAI Images API.
|
||||
|
||||
POST https://api.openai.com/v1/images/generations
|
||||
body {model:"gpt-image-1", prompt, size, n:1} → data[0].b64_json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
OPENAI_URL = "https://api.openai.com/v1/images/generations"
|
||||
DEFAULT_MODEL = "gpt-image-1"
|
||||
|
||||
|
||||
def run_gpt_image_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="OPENAI_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "GPT Image 호출 중...")
|
||||
body = {
|
||||
"model": params.get("model") or DEFAULT_MODEL,
|
||||
"prompt": params["prompt"],
|
||||
"size": params.get("size") or "1024x1024",
|
||||
"n": 1,
|
||||
}
|
||||
resp = requests.post(
|
||||
OPENAI_URL,
|
||||
headers={"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}", "Content-Type": "application/json"},
|
||||
json=body,
|
||||
timeout=120,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"OpenAI {resp.status_code}: {resp.text[:200]}")
|
||||
return
|
||||
b64 = resp.json()["data"][0]["b64_json"]
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("gpt_image task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
52
services/image-render/providers/nano_banana.py
Normal file
52
services/image-render/providers/nano_banana.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Nano Banana — Gemini 2.5 Flash Image (generativelanguage API).
|
||||
|
||||
POST /v1beta/models/{MODEL}:generateContent
|
||||
→ candidates[0].content.parts[*].inlineData.data (b64 png)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging, os
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta"
|
||||
DEFAULT_MODEL = "gemini-2.5-flash-image"
|
||||
|
||||
|
||||
def _extract_b64(data: dict):
|
||||
for cand in data.get("candidates", []):
|
||||
for part in cand.get("content", {}).get("parts", []):
|
||||
inline = part.get("inlineData") or part.get("inline_data")
|
||||
if inline and inline.get("data"):
|
||||
return inline["data"]
|
||||
return None
|
||||
|
||||
|
||||
def run_nano_banana_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not os.getenv("GEMINI_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="GEMINI_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "Nano Banana (Gemini) 호출 중...")
|
||||
model_id = params.get("model") or DEFAULT_MODEL
|
||||
body = {"contents": [{"parts": [{"text": params["prompt"]}]}]}
|
||||
resp = requests.post(
|
||||
f"{GEMINI_BASE}/models/{model_id}:generateContent",
|
||||
headers={"x-goog-api-key": os.getenv("GEMINI_API_KEY"), "Content-Type": "application/json"},
|
||||
json=body, timeout=120,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Gemini {resp.status_code}: {resp.text[:200]}")
|
||||
return
|
||||
b64 = _extract_b64(resp.json())
|
||||
if not b64:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Gemini 응답에 이미지 없음")
|
||||
return
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("nano_banana task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
9
services/image-render/requirements.txt
Normal file
9
services/image-render/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
fastapi==0.115.6
|
||||
uvicorn[standard]==0.34.0
|
||||
requests==2.32.3
|
||||
redis>=5.0
|
||||
httpx>=0.27
|
||||
openai>=1.50.0
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.24
|
||||
respx>=0.21
|
||||
0
services/image-render/tests/__init__.py
Normal file
0
services/image-render/tests/__init__.py
Normal file
21
services/image-render/tests/test_flux.py
Normal file
21
services/image-render/tests/test_flux.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import providers.flux as fx
|
||||
|
||||
def test_blocked_during_trading_hours(monkeypatch):
|
||||
monkeypatch.setenv("FLUX_BLOCK_TRADING_HOURS", "1")
|
||||
monkeypatch.setattr(fx, "_is_trading_hours", lambda: True)
|
||||
calls = []
|
||||
monkeypatch.setattr(fx, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
fx.run_flux_generation("t1", {"prompt": "a cat"})
|
||||
assert calls[-1][0][1] == "failed"
|
||||
assert "장중" in calls[-1][1]["error"]
|
||||
|
||||
def test_success_polls_history_and_saves(monkeypatch):
|
||||
monkeypatch.setattr(fx, "_is_trading_hours", lambda: False)
|
||||
calls = []
|
||||
monkeypatch.setattr(fx, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
monkeypatch.setattr(fx, "_load_workflow", lambda prompt, size: {"3": {}})
|
||||
monkeypatch.setattr(fx, "_submit_prompt", lambda wf: "pid-1")
|
||||
monkeypatch.setattr(fx, "_poll_image_b64", lambda pid: "ZmFrZQ==")
|
||||
monkeypatch.setattr(fx, "save_b64_png", lambda tid, b64: "/media/image/t1.png")
|
||||
fx.run_flux_generation("t1", {"prompt": "a cat"})
|
||||
assert [c for c in calls if c[0][1] == "succeeded"]
|
||||
32
services/image-render/tests/test_gpt_image.py
Normal file
32
services/image-render/tests/test_gpt_image.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import providers.gpt_image as gi
|
||||
|
||||
|
||||
def test_missing_key_reports_failed(monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
calls = []
|
||||
monkeypatch.setattr(gi, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
gi.run_gpt_image_generation("t1", {"prompt": "a cat"})
|
||||
# 마지막 호출이 failed
|
||||
assert calls[-1][0][1] == "failed"
|
||||
|
||||
|
||||
def test_success_saves_and_reports_url(monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||
calls = []
|
||||
monkeypatch.setattr(gi, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
monkeypatch.setattr(gi, "save_b64_png", lambda tid, b64: "/media/image/t1.png")
|
||||
|
||||
class FakeResp:
|
||||
status_code = 200
|
||||
|
||||
def json(self):
|
||||
return {"data": [{"b64_json": "ZmFrZQ=="}]}
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(gi.requests, "post", lambda *a, **k: FakeResp())
|
||||
|
||||
gi.run_gpt_image_generation("t1", {"prompt": "a cat"})
|
||||
succeeded = [c for c in calls if c[0][1] == "succeeded"]
|
||||
assert succeeded and succeeded[-1][1]["image_url"] == "/media/image/t1.png"
|
||||
25
services/image-render/tests/test_nano_banana.py
Normal file
25
services/image-render/tests/test_nano_banana.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import providers.nano_banana as nb
|
||||
|
||||
def test_missing_key_reports_failed(monkeypatch):
|
||||
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
|
||||
calls = []
|
||||
monkeypatch.setattr(nb, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
nb.run_nano_banana_generation("t1", {"prompt": "a cat"})
|
||||
assert calls[-1][0][1] == "failed"
|
||||
|
||||
def test_success_extracts_inline_data(monkeypatch):
|
||||
monkeypatch.setenv("GEMINI_API_KEY", "g-test")
|
||||
calls = []
|
||||
monkeypatch.setattr(nb, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
monkeypatch.setattr(nb, "save_b64_png", lambda tid, b64: "/media/image/t1.png")
|
||||
|
||||
class FakeResp:
|
||||
status_code = 200
|
||||
def json(self):
|
||||
return {"candidates": [{"content": {"parts": [
|
||||
{"inlineData": {"mimeType": "image/png", "data": "ZmFrZQ=="}}
|
||||
]}}]}
|
||||
monkeypatch.setattr(nb.requests, "post", lambda *a, **k: FakeResp())
|
||||
|
||||
nb.run_nano_banana_generation("t1", {"prompt": "a cat"})
|
||||
assert [c for c in calls if c[0][1] == "succeeded"]
|
||||
20
services/image-render/tests/test_nas_client.py
Normal file
20
services/image-render/tests/test_nas_client.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import nas_client
|
||||
|
||||
|
||||
def test_webhook_includes_image_url(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_post(payload):
|
||||
captured.update(payload)
|
||||
|
||||
monkeypatch.setattr(nas_client, "_post", fake_post)
|
||||
nas_client.webhook_update_task("t1", "succeeded", 100, "done", image_url="/media/image/t1.png")
|
||||
assert captured["task_id"] == "t1"
|
||||
assert captured["image_url"] == "/media/image/t1.png"
|
||||
|
||||
|
||||
def test_webhook_omits_none_fields(monkeypatch):
|
||||
captured = {}
|
||||
monkeypatch.setattr(nas_client, "_post", lambda p: captured.update(p))
|
||||
nas_client.webhook_update_task("t2", "processing", 10, "working")
|
||||
assert "image_url" not in captured and "error" not in captured
|
||||
15
services/image-render/tests/test_worker.py
Normal file
15
services/image-render/tests/test_worker.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import worker
|
||||
|
||||
|
||||
def test_dispatch_routes_to_provider(monkeypatch):
|
||||
called = {}
|
||||
monkeypatch.setattr(worker, "run_gpt_image_generation", lambda tid, p: called.setdefault("gpt", (tid, p)))
|
||||
worker._dispatch({"job_type": "gpt_image_generation", "task_id": "t1", "params": {"prompt": "x"}})
|
||||
assert called["gpt"][0] == "t1"
|
||||
|
||||
|
||||
def test_dispatch_unknown_job_type_reports_failed(monkeypatch):
|
||||
calls = []
|
||||
monkeypatch.setattr(worker, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
worker._dispatch({"job_type": "midjourney_generation", "task_id": "t9", "params": {}})
|
||||
assert calls[-1][0][1] == "failed"
|
||||
84
services/image-render/worker.py
Normal file
84
services/image-render/worker.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Redis BLPOP worker — queue:image-render → job_type dispatch → NAS webhook.
|
||||
|
||||
queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
|
||||
video-render worker.py 패턴 — string-based dispatch + getattr (테스트 patch 호환).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers.gpt_image import run_gpt_image_generation
|
||||
from providers.nano_banana import run_nano_banana_generation
|
||||
from providers.flux import run_flux_generation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://192.168.45.54:6379")
|
||||
QUEUE_KEY = "queue:image-render"
|
||||
PAUSED_KEY = "queue:paused"
|
||||
|
||||
# string names so `unittest.mock.patch` / `monkeypatch.setattr` on `worker.<name>`
|
||||
# is correctly intercepted by getattr(sys.modules[__name__], ...)
|
||||
_DISPATCH_TABLE = {
|
||||
"gpt_image_generation": "run_gpt_image_generation",
|
||||
"nano_banana_generation": "run_nano_banana_generation",
|
||||
"flux_generation": "run_flux_generation",
|
||||
}
|
||||
|
||||
|
||||
def _dispatch(payload: dict) -> None:
|
||||
"""payload[job_type] → provider 함수 호출 (sync, worker_loop에서 asyncio.to_thread로 wrap)."""
|
||||
job_type = payload.get("job_type", "")
|
||||
task_id = payload.get("task_id", "")
|
||||
params = payload.get("params", {})
|
||||
fn_name = _DISPATCH_TABLE.get(job_type)
|
||||
if fn_name is None:
|
||||
logger.error("unknown job_type=%s task=%s", job_type, task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"unknown job_type: {job_type}")
|
||||
return
|
||||
try:
|
||||
fn = getattr(sys.modules[__name__], fn_name)
|
||||
except AttributeError:
|
||||
logger.error("dispatch table typo for job_type=%s name=%s task=%s", job_type, fn_name, task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"internal dispatch error: {fn_name}")
|
||||
return
|
||||
fn(task_id, params)
|
||||
|
||||
|
||||
async def worker_loop():
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
logger.info("image-render worker started (queue=%s)", QUEUE_KEY)
|
||||
while True:
|
||||
try:
|
||||
paused = await redis.get(PAUSED_KEY)
|
||||
if paused == b"1":
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
item = await redis.blpop(QUEUE_KEY, timeout=5)
|
||||
if item is None:
|
||||
continue
|
||||
_, raw = item
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("invalid queue payload: %r", raw[:200])
|
||||
continue
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("worker_loop cancelled")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("worker_loop iteration 실패, 5초 후 재시도")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
asyncio.run(worker_loop())
|
||||
Reference in New Issue
Block a user