diff --git a/image-lab/app/main.py b/image-lab/app/main.py new file mode 100644 index 0000000..c986f48 --- /dev/null +++ b/image-lab/app/main.py @@ -0,0 +1,113 @@ +"""FastAPI entrypoint for image-lab. + +POST /api/image/generate — provider + prompt → Redis push → task_id +GET /api/image/tasks/{id} — DB 조회 +GET /api/image/providers — 3 provider 메타 +""" +from __future__ import annotations + +import json +import logging +import os +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional + +import redis.asyncio as aioredis +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field + +from . import db +from .internal_router import router as internal_router + +logger = logging.getLogger(__name__) + +CORS_ALLOW_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "http://localhost:3007,http://localhost:8080") +REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379") +redis_client = aioredis.from_url(REDIS_URL, decode_responses=False) + +SUPPORTED_PROVIDERS = {"gpt_image", "nano_banana", "flux"} + +app = FastAPI() +app.include_router(internal_router) + +app.add_middleware( + CORSMiddleware, + allow_origins=[o.strip() for o in CORS_ALLOW_ORIGINS.split(",")], + allow_credentials=False, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"], + allow_headers=["Content-Type"], +) + + +@app.on_event("startup") +def on_startup(): + db.init_db() + + +@app.get("/health") +def health(): + return {"ok": True, "service": "image-lab"} + + +@app.get("/api/image/providers") +def list_providers(): + """3 provider 항상 노출 (key 누락은 worker가 failed 보고).""" + return {"providers": [ + {"id": "gpt_image", "name": "GPT Image 2.0", "models": ["gpt-image-1"], + "sizes": ["1024x1024", "1024x1536", "1536x1024"]}, + {"id": "nano_banana", "name": "Nano Banana (Gemini)", "models": ["gemini-2.5-flash-image"], + "sizes": ["1024x1024"]}, + {"id": "flux", "name": "FLUX (local)", "models": ["flux-schnell", "flux-dev"], + "sizes": ["1024x1024", "832x1216", "1216x832"]}, + ]} + + +class GenerateRequest(BaseModel): + provider: str = Field(..., description="gpt_image|nano_banana|flux") + model: Optional[str] = None + prompt: str + size: Optional[str] = None + negative_prompt: Optional[str] = None + # Provider 별 추가 키는 extra 허용 + extra: Optional[Dict[str, Any]] = None + + class Config: + extra = "allow" + + +async def _push_render_job(task_id: str, job_type: str, params: dict) -> None: + """Redis queue:image-render에 push.""" + kst = timezone(timedelta(hours=9)) + payload = { + "task_id": task_id, + "kind": "image", + "job_type": job_type, + "params": params, + "submitted_at": datetime.now(kst).isoformat(), + } + await redis_client.rpush("queue:image-render", json.dumps(payload)) + + +@app.post("/api/image/generate") +async def generate_image(req: GenerateRequest): + """이미지 생성 — Redis 큐로 Windows image-render에 위임.""" + if req.provider not in SUPPORTED_PROVIDERS: + raise HTTPException(400, f"지원하지 않는 provider: {req.provider} (supported: {sorted(SUPPORTED_PROVIDERS)})") + + task_id = str(uuid.uuid4()) + params = req.model_dump(exclude_none=True) + db.create_task(task_id, req.provider, params) + + job_type = f"{req.provider}_generation" # gpt_image_generation, nano_banana_generation, flux_generation + await _push_render_job(task_id, job_type, params) + return {"task_id": task_id, "provider": req.provider} + + +@app.get("/api/image/tasks/{task_id}") +def get_task_status(task_id: str): + t = db.get_task(task_id) + if not t: + raise HTTPException(404, "task not found") + return t diff --git a/image-lab/tests/test_main.py b/image-lab/tests/test_main.py new file mode 100644 index 0000000..ad2f69f --- /dev/null +++ b/image-lab/tests/test_main.py @@ -0,0 +1,43 @@ +import os, tempfile, importlib +from fastapi.testclient import TestClient + + +def _client(monkeypatch, tmp): + monkeypatch.setenv("IMAGE_DATA_DIR", tmp) + import app.db as db + importlib.reload(db) + db.init_db() + import app.main as main + importlib.reload(main) + pushed = [] + + async def fake_push(task_id, job_type, params): + pushed.append((task_id, job_type, params)) + + monkeypatch.setattr(main, "_push_render_job", fake_push) + return TestClient(main.app), db, pushed + + +def test_providers_lists_three(monkeypatch): + with tempfile.TemporaryDirectory() as tmp: + client, _, _ = _client(monkeypatch, tmp) + r = client.get("/api/image/providers") + ids = {p["id"] for p in r.json()["providers"]} + assert ids == {"gpt_image", "nano_banana", "flux"} + + +def test_generate_rejects_unknown_provider(monkeypatch): + with tempfile.TemporaryDirectory() as tmp: + client, _, _ = _client(monkeypatch, tmp) + r = client.post("/api/image/generate", json={"provider": "midjourney", "prompt": "x"}) + assert r.status_code == 400 + + +def test_generate_creates_task_and_pushes(monkeypatch): + with tempfile.TemporaryDirectory() as tmp: + client, db, pushed = _client(monkeypatch, tmp) + r = client.post("/api/image/generate", json={"provider": "gpt_image", "prompt": "a cat"}) + assert r.status_code == 200 + task_id = r.json()["task_id"] + assert db.get_task(task_id)["status"] == "queued" + assert pushed[0][1] == "gpt_image_generation"