diff --git a/insta-lab/app/main.py b/insta-lab/app/main.py new file mode 100644 index 0000000..fe0ef80 --- /dev/null +++ b/insta-lab/app/main.py @@ -0,0 +1,245 @@ +"""FastAPI entrypoint for insta-lab.""" + +import asyncio +import json +import logging +import os +from typing import Optional + +from fastapi import FastAPI, HTTPException, BackgroundTasks, Body, Query +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from pydantic import BaseModel + +from .config import ( + CORS_ALLOW_ORIGINS, NAVER_CLIENT_ID, ANTHROPIC_API_KEY, + INSTA_DATA_PATH, DB_PATH, DEFAULT_CATEGORY_SEEDS, KEYWORDS_PER_CATEGORY, +) +from . import db, news_collector, keyword_extractor, card_writer, card_renderer + +logger = logging.getLogger(__name__) +app = FastAPI() + +app.add_middleware( + CORSMiddleware, + allow_origins=[o.strip() for o in CORS_ALLOW_ORIGINS.split(",")], + allow_credentials=False, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"], + allow_headers=["Content-Type"], +) + + +@app.on_event("startup") +def on_startup(): + os.makedirs(INSTA_DATA_PATH, exist_ok=True) + db.init_db() + + +@app.get("/health") +def health(): + return {"ok": True} + + +@app.get("/api/insta/status") +def status(): + return { + "ok": True, + "naver_api": bool(NAVER_CLIENT_ID), + "anthropic_api": bool(ANTHROPIC_API_KEY), + } + + +# ── News ───────────────────────────────────────────────────────── +class CollectRequest(BaseModel): + categories: Optional[list[str]] = None + + +def _seeds_for(category: str) -> list[str]: + pt = db.get_prompt_template("category_seeds") + if pt and pt.get("template"): + try: + data = json.loads(pt["template"]) + if category in data: + return list(data[category]) + except Exception: + pass + return list(DEFAULT_CATEGORY_SEEDS.get(category, [])) + + +async def _bg_collect(task_id: str, categories: list[str]): + try: + db.update_task(task_id, "processing", 10, "수집 중") + total = 0 + for cat in categories: + seeds = _seeds_for(cat) + if not seeds: + continue + total += news_collector.collect_for_category(cat, seeds) + db.update_task(task_id, "succeeded", 100, f"{total}건 수집", result_id=total) + except Exception as e: + logger.exception("collect failed") + db.update_task(task_id, "failed", 0, "", error=str(e)) + + +@app.post("/api/insta/news/collect") +def collect_news(req: CollectRequest, bg: BackgroundTasks): + cats = req.categories or list(DEFAULT_CATEGORY_SEEDS.keys()) + tid = db.create_task("news_collect", {"categories": cats}) + bg.add_task(_bg_collect, tid, cats) + return {"task_id": tid, "categories": cats} + + +@app.get("/api/insta/news/articles") +def list_articles(category: Optional[str] = None, days: int = Query(7, ge=1, le=90)): + return {"items": db.list_news_articles(category=category, days=days)} + + +# ── Keywords ───────────────────────────────────────────────────── +class ExtractRequest(BaseModel): + categories: Optional[list[str]] = None + + +async def _bg_extract(task_id: str, categories: list[str]): + try: + db.update_task(task_id, "processing", 10, "추출 중") + for cat in categories: + keyword_extractor.extract_for_category(cat, limit=KEYWORDS_PER_CATEGORY) + db.update_task(task_id, "succeeded", 100, "완료", result_id=0) + except Exception as e: + logger.exception("extract failed") + db.update_task(task_id, "failed", 0, "", error=str(e)) + + +@app.post("/api/insta/keywords/extract") +def extract_keywords(req: ExtractRequest, bg: BackgroundTasks): + cats = req.categories or list(DEFAULT_CATEGORY_SEEDS.keys()) + tid = db.create_task("keyword_extract", {"categories": cats}) + bg.add_task(_bg_extract, tid, cats) + return {"task_id": tid, "categories": cats} + + +@app.get("/api/insta/keywords") +def list_keywords(category: Optional[str] = None, used: Optional[bool] = None): + return {"items": db.list_trending_keywords(category=category, used=used)} + + +# ── Slates ─────────────────────────────────────────────────────── +class SlateRequest(BaseModel): + keyword: str + category: str + keyword_id: Optional[int] = None + + +async def _bg_create_slate(task_id: str, keyword: str, category: str, keyword_id: Optional[int]): + try: + db.update_task(task_id, "processing", 30, "카피 생성 중") + sid = card_writer.write_slate(keyword=keyword, category=category) + db.update_task(task_id, "processing", 70, "카드 렌더 중") + await card_renderer.render_slate(sid) + db.update_slate_status(sid, "rendered") + if keyword_id: + db.mark_keyword_used(keyword_id) + db.update_task(task_id, "succeeded", 100, "완료", result_id=sid) + except Exception as e: + logger.exception("create slate failed") + db.update_task(task_id, "failed", 0, "", error=str(e)) + + +@app.post("/api/insta/slates") +def create_slate(req: SlateRequest, bg: BackgroundTasks): + tid = db.create_task("slate_create", req.dict()) + bg.add_task(_bg_create_slate, tid, req.keyword, req.category, req.keyword_id) + return {"task_id": tid} + + +@app.get("/api/insta/slates") +def list_slates(limit: int = Query(50, ge=1, le=500)): + return {"items": db.list_card_slates(limit=limit)} + + +@app.get("/api/insta/slates/{slate_id}") +def get_slate(slate_id: int): + s = db.get_card_slate(slate_id) + if not s: + raise HTTPException(404, "slate not found") + s["assets"] = db.list_card_assets(slate_id) + for k in ("cover_copy", "body_copies", "cta_copy", "hashtags"): + if isinstance(s.get(k), str): + try: + s[k] = json.loads(s[k]) + except Exception: + pass + return s + + +async def _bg_render(task_id: str, slate_id: int): + try: + db.update_task(task_id, "processing", 30, "재렌더 중") + await card_renderer.render_slate(slate_id) + db.update_slate_status(slate_id, "rendered") + db.update_task(task_id, "succeeded", 100, "완료", result_id=slate_id) + except Exception as e: + logger.exception("render failed") + db.update_task(task_id, "failed", 0, "", error=str(e)) + + +@app.post("/api/insta/slates/{slate_id}/render") +def render_slate_endpoint(slate_id: int, bg: BackgroundTasks): + if not db.get_card_slate(slate_id): + raise HTTPException(404, "slate not found") + tid = db.create_task("slate_render", {"slate_id": slate_id}) + bg.add_task(_bg_render, tid, slate_id) + return {"task_id": tid} + + +@app.get("/api/insta/slates/{slate_id}/assets/{page}") +def get_asset(slate_id: int, page: int): + if not (1 <= page <= 10): + raise HTTPException(400, "page must be 1..10") + assets = db.list_card_assets(slate_id) + match = next((a for a in assets if a["page_index"] == page), None) + if not match: + raise HTTPException(404, "asset not found") + return FileResponse(match["file_path"], media_type="image/png") + + +@app.delete("/api/insta/slates/{slate_id}") +def delete_slate(slate_id: int): + if not db.get_card_slate(slate_id): + raise HTTPException(404) + for a in db.list_card_assets(slate_id): + try: + os.unlink(a["file_path"]) + except OSError: + pass + db.delete_card_slate(slate_id) + return {"ok": True} + + +# ── Tasks ──────────────────────────────────────────────────────── +@app.get("/api/insta/tasks/{task_id}") +def get_task_status(task_id: str): + t = db.get_task(task_id) + if not t: + raise HTTPException(404) + return t + + +# ── Prompt Templates ───────────────────────────────────────────── +class TemplateBody(BaseModel): + template: str + description: str = "" + + +@app.get("/api/insta/templates/prompts/{name}") +def get_prompt(name: str): + pt = db.get_prompt_template(name) + if not pt: + raise HTTPException(404) + return pt + + +@app.put("/api/insta/templates/prompts/{name}") +def upsert_prompt(name: str, body: TemplateBody): + db.upsert_prompt_template(name, body.template, body.description) + return db.get_prompt_template(name) diff --git a/insta-lab/tests/test_main.py b/insta-lab/tests/test_main.py new file mode 100644 index 0000000..7ae31ce --- /dev/null +++ b/insta-lab/tests/test_main.py @@ -0,0 +1,91 @@ +import os +import tempfile + +import pytest +from fastapi.testclient import TestClient + +from app import db as db_module + + +@pytest.fixture +def client(monkeypatch): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + monkeypatch.setattr(db_module, "DB_PATH", path) + db_module.init_db() + from app import main + monkeypatch.setattr(main, "DB_PATH", path) + with TestClient(main.app) as c: + yield c + import gc + gc.collect() + for ext in ("", "-wal", "-shm"): + try: + os.remove(path + ext) + except OSError: + pass + + +def test_health(client): + resp = client.get("/health") + assert resp.status_code == 200 + assert resp.json()["ok"] is True + + +def test_status_endpoint(client): + resp = client.get("/api/insta/status") + assert resp.status_code == 200 + j = resp.json() + assert "naver_api" in j and "anthropic_api" in j + + +def test_news_articles_listing(client): + db_module.add_news_article({ + "category": "economy", "title": "T1", "link": "https://x/1", "summary": "S", + }) + resp = client.get("/api/insta/news/articles?category=economy&days=7") + assert resp.status_code == 200 + assert len(resp.json()["items"]) == 1 + + +def test_keywords_listing(client): + db_module.add_trending_keyword({ + "keyword": "K", "category": "economy", "score": 0.5, "articles_count": 3, + }) + resp = client.get("/api/insta/keywords?category=economy") + assert resp.status_code == 200 + assert resp.json()["items"][0]["keyword"] == "K" + + +def test_create_slate_kicks_background_task(client, monkeypatch): + from app import main, card_writer, card_renderer + + def fake_write(keyword, category, articles=None): + return db_module.add_card_slate({ + "keyword": keyword, "category": category, "status": "draft", + "cover_copy": {"headline": "H", "body": "B", "accent_color": "#000"}, + "body_copies": [{"headline": f"h{i}", "body": f"b{i}"} for i in range(8)], + "cta_copy": {"headline": "C", "body": "B", "cta": "F"}, + }) + + async def fake_render(slate_id, template="default/card.html.j2"): + for i in range(1, 11): + db_module.add_card_asset(slate_id, i, f"/tmp/{slate_id}_{i}.png", "h") + return [f"/tmp/{slate_id}_{i}.png" for i in range(1, 11)] + + monkeypatch.setattr(card_writer, "write_slate", fake_write) + monkeypatch.setattr(card_renderer, "render_slate", fake_render) + + resp = client.post("/api/insta/slates", json={"keyword": "K", "category": "economy"}) + assert resp.status_code == 200 + task_id = resp.json()["task_id"] + # poll task + for _ in range(20): + st = client.get(f"/api/insta/tasks/{task_id}").json() + if st["status"] in ("succeeded", "failed"): + break + assert st["status"] == "succeeded" + slate_id = st["result_id"] + detail = client.get(f"/api/insta/slates/{slate_id}").json() + assert detail["status"] == "rendered" + assert len(detail["assets"]) == 10