feat(image-lab): generate/tasks/providers 엔드포인트 (video-lab 복제)
This commit is contained in:
113
image-lab/app/main.py
Normal file
113
image-lab/app/main.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""FastAPI entrypoint for image-lab.
|
||||||
|
|
||||||
|
POST /api/image/generate — provider + prompt → Redis push → task_id
|
||||||
|
GET /api/image/tasks/{id} — DB 조회
|
||||||
|
GET /api/image/providers — 3 provider 메타
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from . import db
|
||||||
|
from .internal_router import router as internal_router
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CORS_ALLOW_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "http://localhost:3007,http://localhost:8080")
|
||||||
|
REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379")
|
||||||
|
redis_client = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||||
|
|
||||||
|
SUPPORTED_PROVIDERS = {"gpt_image", "nano_banana", "flux"}
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(internal_router)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=[o.strip() for o in CORS_ALLOW_ORIGINS.split(",")],
|
||||||
|
allow_credentials=False,
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
allow_headers=["Content-Type"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def on_startup():
|
||||||
|
db.init_db()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
def health():
|
||||||
|
return {"ok": True, "service": "image-lab"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/image/providers")
|
||||||
|
def list_providers():
|
||||||
|
"""3 provider 항상 노출 (key 누락은 worker가 failed 보고)."""
|
||||||
|
return {"providers": [
|
||||||
|
{"id": "gpt_image", "name": "GPT Image 2.0", "models": ["gpt-image-1"],
|
||||||
|
"sizes": ["1024x1024", "1024x1536", "1536x1024"]},
|
||||||
|
{"id": "nano_banana", "name": "Nano Banana (Gemini)", "models": ["gemini-2.5-flash-image"],
|
||||||
|
"sizes": ["1024x1024"]},
|
||||||
|
{"id": "flux", "name": "FLUX (local)", "models": ["flux-schnell", "flux-dev"],
|
||||||
|
"sizes": ["1024x1024", "832x1216", "1216x832"]},
|
||||||
|
]}
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateRequest(BaseModel):
|
||||||
|
provider: str = Field(..., description="gpt_image|nano_banana|flux")
|
||||||
|
model: Optional[str] = None
|
||||||
|
prompt: str
|
||||||
|
size: Optional[str] = None
|
||||||
|
negative_prompt: Optional[str] = None
|
||||||
|
# Provider 별 추가 키는 extra 허용
|
||||||
|
extra: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
|
||||||
|
|
||||||
|
async def _push_render_job(task_id: str, job_type: str, params: dict) -> None:
|
||||||
|
"""Redis queue:image-render에 push."""
|
||||||
|
kst = timezone(timedelta(hours=9))
|
||||||
|
payload = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"kind": "image",
|
||||||
|
"job_type": job_type,
|
||||||
|
"params": params,
|
||||||
|
"submitted_at": datetime.now(kst).isoformat(),
|
||||||
|
}
|
||||||
|
await redis_client.rpush("queue:image-render", json.dumps(payload))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/image/generate")
|
||||||
|
async def generate_image(req: GenerateRequest):
|
||||||
|
"""이미지 생성 — Redis 큐로 Windows image-render에 위임."""
|
||||||
|
if req.provider not in SUPPORTED_PROVIDERS:
|
||||||
|
raise HTTPException(400, f"지원하지 않는 provider: {req.provider} (supported: {sorted(SUPPORTED_PROVIDERS)})")
|
||||||
|
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
params = req.model_dump(exclude_none=True)
|
||||||
|
db.create_task(task_id, req.provider, params)
|
||||||
|
|
||||||
|
job_type = f"{req.provider}_generation" # gpt_image_generation, nano_banana_generation, flux_generation
|
||||||
|
await _push_render_job(task_id, job_type, params)
|
||||||
|
return {"task_id": task_id, "provider": req.provider}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/image/tasks/{task_id}")
|
||||||
|
def get_task_status(task_id: str):
|
||||||
|
t = db.get_task(task_id)
|
||||||
|
if not t:
|
||||||
|
raise HTTPException(404, "task not found")
|
||||||
|
return t
|
||||||
43
image-lab/tests/test_main.py
Normal file
43
image-lab/tests/test_main.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import os, tempfile, importlib
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
def _client(monkeypatch, tmp):
|
||||||
|
monkeypatch.setenv("IMAGE_DATA_DIR", tmp)
|
||||||
|
import app.db as db
|
||||||
|
importlib.reload(db)
|
||||||
|
db.init_db()
|
||||||
|
import app.main as main
|
||||||
|
importlib.reload(main)
|
||||||
|
pushed = []
|
||||||
|
|
||||||
|
async def fake_push(task_id, job_type, params):
|
||||||
|
pushed.append((task_id, job_type, params))
|
||||||
|
|
||||||
|
monkeypatch.setattr(main, "_push_render_job", fake_push)
|
||||||
|
return TestClient(main.app), db, pushed
|
||||||
|
|
||||||
|
|
||||||
|
def test_providers_lists_three(monkeypatch):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
client, _, _ = _client(monkeypatch, tmp)
|
||||||
|
r = client.get("/api/image/providers")
|
||||||
|
ids = {p["id"] for p in r.json()["providers"]}
|
||||||
|
assert ids == {"gpt_image", "nano_banana", "flux"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_rejects_unknown_provider(monkeypatch):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
client, _, _ = _client(monkeypatch, tmp)
|
||||||
|
r = client.post("/api/image/generate", json={"provider": "midjourney", "prompt": "x"})
|
||||||
|
assert r.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_creates_task_and_pushes(monkeypatch):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
client, db, pushed = _client(monkeypatch, tmp)
|
||||||
|
r = client.post("/api/image/generate", json={"provider": "gpt_image", "prompt": "a cat"})
|
||||||
|
assert r.status_code == 200
|
||||||
|
task_id = r.json()["task_id"]
|
||||||
|
assert db.get_task(task_id)["status"] == "queued"
|
||||||
|
assert pushed[0][1] == "gpt_image_generation"
|
||||||
Reference in New Issue
Block a user