44 lines
1.5 KiB
Python
44 lines
1.5 KiB
Python
import os, tempfile, importlib
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
def _client(monkeypatch, tmp):
    """Build a test client wired to a fresh database rooted at *tmp*.

    ``IMAGE_DATA_DIR`` is pointed at *tmp* first, then ``app.db`` is reloaded
    (so it re-reads the environment) and its schema initialised via
    ``init_db()``. Only after that is ``app.main`` reloaded, so the app module
    binds to the freshly reloaded db module. Finally ``main._push_render_job``
    is replaced with an async stub that records each call instead of
    dispatching real work.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; env and attribute
            patches are undone automatically at test teardown.
        tmp: directory path to use as ``IMAGE_DATA_DIR``.

    Returns:
        ``(client, db, pushed)``: a ``TestClient`` for ``main.app``, the
        reloaded ``app.db`` module, and the list that collects one
        ``(task_id, job_type, params)`` tuple per stubbed push.
    """
    monkeypatch.setenv("IMAGE_DATA_DIR", tmp)

    # Order matters: db must be reloaded and initialised before main is
    # reloaded, otherwise main could capture stale db state.
    db_module = importlib.reload(importlib.import_module("app.db"))
    db_module.init_db()
    main_module = importlib.reload(importlib.import_module("app.main"))

    recorded_jobs = []

    # Parameter names are kept as-is in case the app passes them by keyword.
    async def _record_push(task_id, job_type, params):
        recorded_jobs.append((task_id, job_type, params))

    monkeypatch.setattr(main_module, "_push_render_job", _record_push)
    return TestClient(main_module.app), db_module, recorded_jobs
|
|
|
|
|
|
def test_providers_lists_three(monkeypatch):
    """GET /api/image/providers lists exactly the three supported providers."""
    with tempfile.TemporaryDirectory() as tmp:
        client, _, _ = _client(monkeypatch, tmp)
        r = client.get("/api/image/providers")
        # Check the status first so a server error fails here with the
        # response body, rather than as a KeyError/JSON error further down.
        assert r.status_code == 200, r.text
        ids = [p["id"] for p in r.json()["providers"]]
        # A set comparison alone would hide duplicate entries; the length
        # check makes the "three" in the test name actually enforced.
        assert len(ids) == 3, ids
        assert set(ids) == {"gpt_image", "nano_banana", "flux"}
|
|
|
|
|
|
def test_generate_rejects_unknown_provider(monkeypatch):
    """POST /api/image/generate with an unknown provider is rejected with 400.

    A rejected request must not enqueue any render job.
    """
    with tempfile.TemporaryDirectory() as tmp:
        client, _, pushed = _client(monkeypatch, tmp)
        r = client.post("/api/image/generate", json={"provider": "midjourney", "prompt": "x"})
        assert r.status_code == 400, r.text
        # Validation should happen before any work is dispatched.
        assert not pushed, pushed
|
|
|
|
|
|
def test_generate_creates_task_and_pushes(monkeypatch):
    """A valid generate request stores a queued task and pushes one render job.

    The pushed job must carry the returned task id and the
    ``gpt_image_generation`` job type.
    """
    with tempfile.TemporaryDirectory() as tmp:
        client, db, pushed = _client(monkeypatch, tmp)
        r = client.post("/api/image/generate", json={"provider": "gpt_image", "prompt": "a cat"})
        assert r.status_code == 200, r.text
        task_id = r.json()["task_id"]

        task = db.get_task(task_id)
        # Fail with a clear message instead of a TypeError on subscripting None.
        assert task is not None, f"task {task_id!r} was not persisted"
        assert task["status"] == "queued"

        # Guard before indexing so a missing push is an assertion failure,
        # not an IndexError.
        assert len(pushed) == 1, pushed
        pushed_task_id, job_type, _params = pushed[0]
        # The queued job should be correlated with the task the client got back.
        assert pushed_task_id == task_id
        assert job_type == "gpt_image_generation"