diff --git a/insta-lab/app/db.py b/insta-lab/app/db.py new file mode 100644 index 0000000..963218d --- /dev/null +++ b/insta-lab/app/db.py @@ -0,0 +1,278 @@ +import os +import sqlite3 +import json +import uuid +from typing import Any, Dict, List, Optional + +from .config import DB_PATH + + +def _conn() -> sqlite3.Connection: + os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) + conn = sqlite3.connect(DB_PATH, timeout=120.0) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=120000") + conn.execute("PRAGMA foreign_keys=ON") + return conn + + +def init_db() -> None: + with _conn() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS news_articles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + category TEXT NOT NULL, + title TEXT NOT NULL, + link TEXT NOT NULL UNIQUE, + summary TEXT NOT NULL DEFAULT '', + pub_date TEXT, + fetched_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_na_category_fetched ON news_articles(category, fetched_at DESC)") + + conn.execute(""" + CREATE TABLE IF NOT EXISTS trending_keywords ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + keyword TEXT NOT NULL, + category TEXT NOT NULL, + score REAL NOT NULL DEFAULT 0, + articles_count INTEGER NOT NULL DEFAULT 0, + suggested_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), + used INTEGER NOT NULL DEFAULT 0 + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_tk_score ON trending_keywords(category, score DESC)") + + conn.execute(""" + CREATE TABLE IF NOT EXISTS card_slates ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + keyword TEXT NOT NULL, + category TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'draft', + cover_copy TEXT NOT NULL DEFAULT '{}', + body_copies TEXT NOT NULL DEFAULT '[]', + cta_copy TEXT NOT NULL DEFAULT '{}', + suggested_caption TEXT NOT NULL DEFAULT '', + hashtags TEXT NOT NULL DEFAULT '[]', + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_cs_created ON card_slates(created_at DESC)") + + conn.execute(""" + CREATE TABLE IF NOT EXISTS card_assets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + slate_id INTEGER NOT NULL REFERENCES card_slates(id) ON DELETE CASCADE, + page_index INTEGER NOT NULL, + file_path TEXT NOT NULL, + file_hash TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), + UNIQUE (slate_id, page_index) + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_ca_slate ON card_assets(slate_id, page_index)") + + conn.execute(""" + CREATE TABLE IF NOT EXISTS generation_tasks ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'queued', + progress INTEGER NOT NULL DEFAULT 0, + message TEXT NOT NULL DEFAULT '', + result_id INTEGER, + error TEXT, + params TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) + ) + """) + conn.execute("CREATE INDEX IF NOT EXISTS idx_gt_created ON generation_tasks(created_at DESC)") + + conn.execute(""" + CREATE TABLE IF NOT EXISTS prompt_templates ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + description TEXT NOT NULL DEFAULT '', + template TEXT NOT NULL DEFAULT '', + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')) + ) + """) + + +# ── news_articles ──────────────────────────────────────────────── +def add_news_article(row: Dict[str, Any]) -> int: + with _conn() as conn: + try: + cur = conn.execute( + "INSERT INTO news_articles(category, title, link, summary, pub_date) VALUES(?,?,?,?,?)", + (row["category"], row["title"], row["link"], row.get("summary", ""), row.get("pub_date")), + ) + return cur.lastrowid + except sqlite3.IntegrityError: + existing = conn.execute("SELECT id FROM news_articles WHERE link=?", (row["link"],)).fetchone() + return existing["id"] if existing else 0 + + +def list_news_articles(category: Optional[str] = None, days: int = 1) -> List[Dict[str, Any]]: + sql = "SELECT * FROM news_articles WHERE fetched_at >= datetime('now', ?)" + params: List[Any] = [f"-{int(days)} days"] + if category: + sql += " AND category=?" + params.append(category) + sql += " ORDER BY fetched_at DESC" + with _conn() as conn: + rows = conn.execute(sql, params).fetchall() + return [dict(r) for r in rows] + + +# ── trending_keywords ─────────────────────────────────────────── +def add_trending_keyword(row: Dict[str, Any]) -> int: + with _conn() as conn: + cur = conn.execute( + "INSERT INTO trending_keywords(keyword, category, score, articles_count) VALUES(?,?,?,?)", + (row["keyword"], row["category"], float(row.get("score", 0.0)), int(row.get("articles_count", 0))), + ) + return cur.lastrowid + + +def list_trending_keywords(category: Optional[str] = None, used: Optional[bool] = None) -> List[Dict[str, Any]]: + sql = "SELECT * FROM trending_keywords WHERE 1=1" + params: List[Any] = [] + if category: + sql += " AND category=?" + params.append(category) + if used is not None: + sql += " AND used=?" + params.append(1 if used else 0) + sql += " ORDER BY score DESC, suggested_at DESC" + with _conn() as conn: + rows = conn.execute(sql, params).fetchall() + return [dict(r) for r in rows] + + +def mark_keyword_used(keyword_id: int) -> None: + with _conn() as conn: + conn.execute("UPDATE trending_keywords SET used=1 WHERE id=?", (keyword_id,)) + + +def get_trending_keyword(keyword_id: int) -> Optional[Dict[str, Any]]: + with _conn() as conn: + row = conn.execute("SELECT * FROM trending_keywords WHERE id=?", (keyword_id,)).fetchone() + return dict(row) if row else None + + +# ── card_slates ───────────────────────────────────────────────── +def add_card_slate(row: Dict[str, Any]) -> int: + with _conn() as conn: + cur = conn.execute(""" + INSERT INTO card_slates(keyword, category, status, cover_copy, body_copies, cta_copy, + suggested_caption, hashtags) + VALUES(?,?,?,?,?,?,?,?) + """, ( + row["keyword"], row["category"], row.get("status", "draft"), + json.dumps(row.get("cover_copy", {}), ensure_ascii=False), + json.dumps(row.get("body_copies", []), ensure_ascii=False), + json.dumps(row.get("cta_copy", {}), ensure_ascii=False), + row.get("suggested_caption", ""), + json.dumps(row.get("hashtags", []), ensure_ascii=False), + )) + return cur.lastrowid + + +def update_slate_status(slate_id: int, status: str) -> None: + with _conn() as conn: + conn.execute( + "UPDATE card_slates SET status=?, updated_at=strftime('%Y-%m-%dT%H:%M:%fZ','now') WHERE id=?", + (status, slate_id), + ) + + +def get_card_slate(slate_id: int) -> Optional[Dict[str, Any]]: + with _conn() as conn: + row = conn.execute("SELECT * FROM card_slates WHERE id=?", (slate_id,)).fetchone() + return dict(row) if row else None + + +def list_card_slates(limit: int = 50) -> List[Dict[str, Any]]: + with _conn() as conn: + rows = conn.execute( + "SELECT * FROM card_slates ORDER BY created_at DESC LIMIT ?", + (limit,), + ).fetchall() + return [dict(r) for r in rows] + + +def delete_card_slate(slate_id: int) -> None: + with _conn() as conn: + conn.execute("DELETE FROM card_slates WHERE id=?", (slate_id,)) + + +# ── card_assets ───────────────────────────────────────────────── +def add_card_asset(slate_id: int, page_index: int, file_path: str, file_hash: str = "") -> int: + with _conn() as conn: + cur = conn.execute(""" + INSERT INTO card_assets(slate_id, page_index, file_path, file_hash) + VALUES(?,?,?,?) + ON CONFLICT(slate_id, page_index) DO UPDATE SET + file_path=excluded.file_path, file_hash=excluded.file_hash + """, (slate_id, page_index, file_path, file_hash)) + return cur.lastrowid + + +def list_card_assets(slate_id: int) -> List[Dict[str, Any]]: + with _conn() as conn: + rows = conn.execute( + "SELECT * FROM card_assets WHERE slate_id=? ORDER BY page_index ASC", + (slate_id,), + ).fetchall() + return [dict(r) for r in rows] + + +# ── generation_tasks ──────────────────────────────────────────── +def create_task(task_type: str, params: Dict[str, Any]) -> str: + tid = uuid.uuid4().hex + with _conn() as conn: + conn.execute( + "INSERT INTO generation_tasks(id, type, params) VALUES(?,?,?)", + (tid, task_type, json.dumps(params, ensure_ascii=False)), + ) + return tid + + +def update_task(task_id: str, status: str, progress: int = 0, message: str = "", + result_id: Optional[int] = None, error: Optional[str] = None) -> None: + with _conn() as conn: + conn.execute(""" + UPDATE generation_tasks + SET status=?, progress=?, message=?, result_id=?, error=?, + updated_at=strftime('%Y-%m-%dT%H:%M:%fZ','now') + WHERE id=? + """, (status, progress, message, result_id, error, task_id)) + + +def get_task(task_id: str) -> Optional[Dict[str, Any]]: + with _conn() as conn: + row = conn.execute("SELECT * FROM generation_tasks WHERE id=?", (task_id,)).fetchone() + return dict(row) if row else None + + +# ── prompt_templates ──────────────────────────────────────────── +def upsert_prompt_template(name: str, template: str, description: str = "") -> None: + with _conn() as conn: + conn.execute(""" + INSERT INTO prompt_templates(name, description, template) + VALUES(?,?,?) + ON CONFLICT(name) DO UPDATE SET + template=excluded.template, + description=excluded.description, + updated_at=strftime('%Y-%m-%dT%H:%M:%fZ','now') + """, (name, description, template)) + + +def get_prompt_template(name: str) -> Optional[Dict[str, Any]]: + with _conn() as conn: + row = conn.execute("SELECT * FROM prompt_templates WHERE name=?", (name,)).fetchone() + return dict(row) if row else None diff --git a/insta-lab/tests/test_db.py b/insta-lab/tests/test_db.py new file mode 100644 index 0000000..9a853a9 --- /dev/null +++ b/insta-lab/tests/test_db.py @@ -0,0 +1,96 @@ +import os +import json +import tempfile + +import pytest + +from app import db as db_module + + +@pytest.fixture +def tmp_db(monkeypatch): + fd, path = tempfile.mkstemp(suffix=".db") + os.close(fd) + monkeypatch.setattr(db_module, "DB_PATH", path) + db_module.init_db() + yield path + # Close all SQLite WAL files before removal (needed on Windows) + import gc + gc.collect() + for ext in ("", "-wal", "-shm"): + try: + os.remove(path + ext) + except FileNotFoundError: + pass + + +def test_init_db_creates_six_tables(tmp_db): + with db_module._conn() as conn: + rows = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ).fetchall() + names = sorted(r[0] for r in rows if not r[0].startswith("sqlite_")) + assert names == sorted([ + "news_articles", "trending_keywords", "card_slates", + "card_assets", "generation_tasks", "prompt_templates", + ]) + + +def test_news_article_roundtrip(tmp_db): + aid = db_module.add_news_article({ + "category": "economy", + "title": "금리 인상 발표", + "link": "https://example.com/1", + "summary": "한국은행이 기준금리를 인상했다.", + "pub_date": "2026-05-15T08:00:00", + }) + assert isinstance(aid, int) + rows = db_module.list_news_articles(category="economy", days=7) + assert len(rows) == 1 + assert rows[0]["title"] == "금리 인상 발표" + + +def test_trending_keyword_roundtrip(tmp_db): + kid = db_module.add_trending_keyword({ + "keyword": "기준금리", + "category": "economy", + "score": 0.87, + "articles_count": 12, + }) + assert isinstance(kid, int) + items = db_module.list_trending_keywords(category="economy", used=False) + assert items[0]["score"] == pytest.approx(0.87) + + +def test_card_slate_with_assets(tmp_db): + sid = db_module.add_card_slate({ + "keyword": "기준금리", + "category": "economy", + "cover_copy": {"headline": "금리 인상", "body": "왜?", "accent_color": "#0F62FE"}, + "body_copies": [{"headline": f"H{i}", "body": f"B{i}"} for i in range(8)], + "cta_copy": {"headline": "정리", "body": "바로 확인", "cta": "팔로우"}, + "suggested_caption": "금리에 대해 알아보자", + "hashtags": ["#금리", "#경제"], + }) + db_module.add_card_asset(sid, page_index=1, file_path="/tmp/01.png", file_hash="abc") + slate = db_module.get_card_slate(sid) + assert slate["status"] == "draft" + assert json.loads(slate["body_copies"])[0]["headline"] == "H0" + assets = db_module.list_card_assets(sid) + assert assets[0]["page_index"] == 1 + + +def test_generation_task_lifecycle(tmp_db): + tid = db_module.create_task("collect", {"category": "economy"}) + db_module.update_task(tid, status="processing", progress=50, message="..") + db_module.update_task(tid, status="succeeded", progress=100, message="ok", result_id=123) + t = db_module.get_task(tid) + assert t["status"] == "succeeded" + assert t["result_id"] == 123 + + +def test_prompt_template_upsert(tmp_db): + db_module.upsert_prompt_template("slate_writer", "v1 template", "writer") + db_module.upsert_prompt_template("slate_writer", "v2 template", "writer") + pt = db_module.get_prompt_template("slate_writer") + assert pt["template"] == "v2 template"