feat(image-lab): generate/tasks/providers 엔드포인트 (video-lab 복제)
This commit is contained in:
113
image-lab/app/main.py
Normal file
113
image-lab/app/main.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""FastAPI entrypoint for image-lab.
|
||||
|
||||
POST /api/image/generate — provider + prompt → Redis push → task_id
|
||||
GET /api/image/tasks/{id} — DB 조회
|
||||
GET /api/image/providers — 3 provider 메타
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from . import db
|
||||
from .internal_router import router as internal_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CORS_ALLOW_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "http://localhost:3007,http://localhost:8080")
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://redis:6379")
|
||||
redis_client = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
|
||||
SUPPORTED_PROVIDERS = {"gpt_image", "nano_banana", "flux"}
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(internal_router)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[o.strip() for o in CORS_ALLOW_ORIGINS.split(",")],
|
||||
allow_credentials=False,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||
allow_headers=["Content-Type"],
|
||||
)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
db.init_db()
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"ok": True, "service": "image-lab"}
|
||||
|
||||
|
||||
@app.get("/api/image/providers")
|
||||
def list_providers():
|
||||
"""3 provider 항상 노출 (key 누락은 worker가 failed 보고)."""
|
||||
return {"providers": [
|
||||
{"id": "gpt_image", "name": "GPT Image 2.0", "models": ["gpt-image-1"],
|
||||
"sizes": ["1024x1024", "1024x1536", "1536x1024"]},
|
||||
{"id": "nano_banana", "name": "Nano Banana (Gemini)", "models": ["gemini-2.5-flash-image"],
|
||||
"sizes": ["1024x1024"]},
|
||||
{"id": "flux", "name": "FLUX (local)", "models": ["flux-schnell", "flux-dev"],
|
||||
"sizes": ["1024x1024", "832x1216", "1216x832"]},
|
||||
]}
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
provider: str = Field(..., description="gpt_image|nano_banana|flux")
|
||||
model: Optional[str] = None
|
||||
prompt: str
|
||||
size: Optional[str] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
# Provider 별 추가 키는 extra 허용
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
async def _push_render_job(task_id: str, job_type: str, params: dict) -> None:
|
||||
"""Redis queue:image-render에 push."""
|
||||
kst = timezone(timedelta(hours=9))
|
||||
payload = {
|
||||
"task_id": task_id,
|
||||
"kind": "image",
|
||||
"job_type": job_type,
|
||||
"params": params,
|
||||
"submitted_at": datetime.now(kst).isoformat(),
|
||||
}
|
||||
await redis_client.rpush("queue:image-render", json.dumps(payload))
|
||||
|
||||
|
||||
@app.post("/api/image/generate")
|
||||
async def generate_image(req: GenerateRequest):
|
||||
"""이미지 생성 — Redis 큐로 Windows image-render에 위임."""
|
||||
if req.provider not in SUPPORTED_PROVIDERS:
|
||||
raise HTTPException(400, f"지원하지 않는 provider: {req.provider} (supported: {sorted(SUPPORTED_PROVIDERS)})")
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
params = req.model_dump(exclude_none=True)
|
||||
db.create_task(task_id, req.provider, params)
|
||||
|
||||
job_type = f"{req.provider}_generation" # gpt_image_generation, nano_banana_generation, flux_generation
|
||||
await _push_render_job(task_id, job_type, params)
|
||||
return {"task_id": task_id, "provider": req.provider}
|
||||
|
||||
|
||||
@app.get("/api/image/tasks/{task_id}")
|
||||
def get_task_status(task_id: str):
|
||||
t = db.get_task(task_id)
|
||||
if not t:
|
||||
raise HTTPException(404, "task not found")
|
||||
return t
|
||||
Reference in New Issue
Block a user