diff --git a/services/image-render/providers/flux.py b/services/image-render/providers/flux.py new file mode 100644 index 0000000..b07b408 --- /dev/null +++ b/services/image-render/providers/flux.py @@ -0,0 +1,79 @@ +"""FLUX 로컬 — ComfyUI HTTP API. + +POST {COMFYUI_URL}/prompt (workflow JSON) → prompt_id +GET {COMFYUI_URL}/history/{prompt_id} → outputs → image filename +GET {COMFYUI_URL}/view?filename=... → PNG bytes → b64 + +워크플로우 JSON은 `flux_workflow.json` (ComfyUI UI에서 "Save (API Format)"로 export, CLIPTextEncode 노드 text를 "%PROMPT%"로 수동 치환). 박재오 산출물. +""" +from __future__ import annotations + +import base64, json, logging, os, time +from datetime import datetime, timezone, timedelta +import requests + +from nas_client import webhook_update_task +from providers._media import save_b64_png + +logger = logging.getLogger(__name__) +COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:8188") +WORKFLOW_PATH = os.path.join(os.path.dirname(__file__), "flux_workflow.json") +POLL_INTERVAL = 2 +POLL_MAX = 120 + + +def _is_trading_hours() -> bool: + kst = timezone(timedelta(hours=9)) + now = datetime.now(kst) + if now.weekday() >= 5: + return False + return (now.hour, now.minute) >= (9, 0) and (now.hour, now.minute) <= (15, 30) + + +def _load_workflow(prompt: str, size: str) -> dict: + with open(WORKFLOW_PATH, encoding="utf-8") as f: + wf = json.load(f) + # CLIPTextEncode 노드의 text를 prompt로 치환 (workflow에 "%PROMPT%" placeholder 사용) + raw = json.dumps(wf).replace("%PROMPT%", prompt.replace('"', "'")) + return json.loads(raw) + + +def _submit_prompt(workflow: dict) -> str: + r = requests.post(f"{COMFYUI_URL}/prompt", json={"prompt": workflow}, timeout=30) + r.raise_for_status() + return r.json()["prompt_id"] + + +def _poll_image_b64(prompt_id: str): + for _ in range(POLL_MAX): + h = requests.get(f"{COMFYUI_URL}/history/{prompt_id}", timeout=10) + data = h.json().get(prompt_id) + if data and data.get("outputs"): + for node_out in data["outputs"].values(): + for img in node_out.get("images", []): + view = requests.get(f"{COMFYUI_URL}/view", + params={"filename": img["filename"], "subfolder": img.get("subfolder", ""), "type": img.get("type", "output")}, + timeout=30) + view.raise_for_status() + return base64.b64encode(view.content).decode() + time.sleep(POLL_INTERVAL) + return None + + +def run_flux_generation(task_id: str, params: dict) -> None: + try: + if os.getenv("FLUX_BLOCK_TRADING_HOURS") == "1" and _is_trading_hours(): + webhook_update_task(task_id, "failed", 0, "", error="장중 GPU 보호 — FLUX 거부 (API provider 사용 권장)") + return + webhook_update_task(task_id, "processing", 10, "FLUX (ComfyUI) 생성 중...") + wf = _load_workflow(params["prompt"], params.get("size") or "1024x1024") + pid = _submit_prompt(wf) + b64 = _poll_image_b64(pid) + if not b64: + webhook_update_task(task_id, "failed", 0, "", error="ComfyUI 타임아웃 또는 출력 없음") + return + url = save_b64_png(task_id, b64) + webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url) + except Exception as e: + logger.exception("flux task=%s 실패", task_id) + webhook_update_task(task_id, "failed", 0, "", error=str(e)) diff --git a/services/image-render/tests/test_flux.py b/services/image-render/tests/test_flux.py new file mode 100644 index 0000000..e15c591 --- /dev/null +++ b/services/image-render/tests/test_flux.py @@ -0,0 +1,21 @@ +import providers.flux as fx + +def test_blocked_during_trading_hours(monkeypatch): + monkeypatch.setenv("FLUX_BLOCK_TRADING_HOURS", "1") + monkeypatch.setattr(fx, "_is_trading_hours", lambda: True) + calls = [] + monkeypatch.setattr(fx, "webhook_update_task", lambda *a, **k: calls.append((a, k))) + fx.run_flux_generation("t1", {"prompt": "a cat"}) + assert calls[-1][0][1] == "failed" + assert "장중" in calls[-1][1]["error"] + +def test_success_polls_history_and_saves(monkeypatch): + monkeypatch.setattr(fx, "_is_trading_hours", lambda: False) + calls = [] + monkeypatch.setattr(fx, "webhook_update_task", lambda *a, **k: calls.append((a, k))) + monkeypatch.setattr(fx, "_load_workflow", lambda prompt, size: {"3": {}}) + monkeypatch.setattr(fx, "_submit_prompt", lambda wf: "pid-1") + monkeypatch.setattr(fx, "_poll_image_b64", lambda pid: "ZmFrZQ==") + monkeypatch.setattr(fx, "save_b64_png", lambda tid, b64: "/media/image/t1.png") + fx.run_flux_generation("t1", {"prompt": "a cat"}) + assert [c for c in calls if c[0][1] == "succeeded"]