feat/insta-trends #4

Merged
gahusb merged 8 commits from feat/insta-trends into main 2026-05-17 08:52:50 +09:00
3 changed files with 155 additions and 3 deletions
Showing only changes of commit b3982c8f72 - Show all commits

View File

@@ -101,6 +101,29 @@ def init_db() -> None:
) )
""") """)
# source column for trending_keywords (idempotent ALTER)
cols = [r[1] for r in conn.execute("PRAGMA table_info(trending_keywords)").fetchall()]
if "source" not in cols:
conn.execute("ALTER TABLE trending_keywords ADD COLUMN source TEXT NOT NULL DEFAULT 'manual'")
conn.execute("CREATE INDEX IF NOT EXISTS idx_tk_source ON trending_keywords(source, suggested_at DESC)")
# account_preferences — 카테고리 가중치
conn.execute("""
CREATE TABLE IF NOT EXISTS account_preferences (
category TEXT PRIMARY KEY,
weight REAL NOT NULL DEFAULT 1.0,
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))
)
""")
# seed defaults if table empty
existing = conn.execute("SELECT COUNT(*) FROM account_preferences").fetchone()[0]
if existing == 0:
for cat in ("economy", "psychology", "celebrity"):
conn.execute(
"INSERT INTO account_preferences(category, weight) VALUES(?,?)",
(cat, 1.0),
)
# ── news_articles ──────────────────────────────────────────────── # ── news_articles ────────────────────────────────────────────────
def add_news_article(row: Dict[str, Any]) -> int: def add_news_article(row: Dict[str, Any]) -> int:
@@ -132,8 +155,12 @@ def list_news_articles(category: Optional[str] = None, days: int = 1) -> List[Di
def add_trending_keyword(row: Dict[str, Any]) -> int: def add_trending_keyword(row: Dict[str, Any]) -> int:
with _conn() as conn: with _conn() as conn:
cur = conn.execute( cur = conn.execute(
"INSERT INTO trending_keywords(keyword, category, score, articles_count) VALUES(?,?,?,?)", "INSERT INTO trending_keywords(keyword, category, score, articles_count, source) VALUES(?,?,?,?,?)",
(row["keyword"], row["category"], float(row.get("score", 0.0)), int(row.get("articles_count", 0))), (
row["keyword"], row["category"],
float(row.get("score", 0.0)), int(row.get("articles_count", 0)),
row.get("source", "manual"),
),
) )
return cur.lastrowid return cur.lastrowid
@@ -276,3 +303,50 @@ def get_prompt_template(name: str) -> Optional[Dict[str, Any]]:
with _conn() as conn: with _conn() as conn:
row = conn.execute("SELECT * FROM prompt_templates WHERE name=?", (name,)).fetchone() row = conn.execute("SELECT * FROM prompt_templates WHERE name=?", (name,)).fetchone()
return dict(row) if row else None return dict(row) if row else None
# ── external trends ─────────────────────────────────────────────
def add_external_trend(row: Dict[str, Any]) -> int:
"""`source` 필수 — naver_popular | google_trends. trending_keywords에 인서트."""
if "source" not in row:
raise ValueError("add_external_trend requires 'source' field")
return add_trending_keyword(row)
def list_trends(source: Optional[str] = None, category: Optional[str] = None,
days: int = 1) -> List[Dict[str, Any]]:
sql = "SELECT * FROM trending_keywords WHERE suggested_at >= datetime('now', ?)"
params: List[Any] = [f"-{int(days)} days"]
if source and source != "all":
sql += " AND source=?"
params.append(source)
if category:
sql += " AND category=?"
params.append(category)
sql += " ORDER BY suggested_at DESC, score DESC"
with _conn() as conn:
rows = conn.execute(sql, params).fetchall()
return [dict(r) for r in rows]
# ── account_preferences ─────────────────────────────────────────
def get_preferences() -> List[Dict[str, Any]]:
with _conn() as conn:
rows = conn.execute(
"SELECT category, weight, updated_at FROM account_preferences ORDER BY category ASC"
).fetchall()
return [dict(r) for r in rows]
def upsert_preferences(weights: Dict[str, float]) -> None:
"""전체 upsert. 기존에 있던 카테고리는 weight 갱신, 신규는 INSERT.
명시되지 않은 기존 카테고리는 그대로 둔다 (삭제 X). 삭제 필요 시 별도 API로."""
with _conn() as conn:
for cat, w in weights.items():
conn.execute("""
INSERT INTO account_preferences(category, weight)
VALUES(?,?)
ON CONFLICT(category) DO UPDATE SET
weight=excluded.weight,
updated_at=strftime('%Y-%m-%dT%H:%M:%fZ','now')
""", (cat, float(w)))

View File

@@ -24,7 +24,7 @@ def tmp_db(monkeypatch):
pass pass
def test_init_db_creates_six_tables(tmp_db): def test_init_db_creates_seven_tables(tmp_db):
with db_module._conn() as conn: with db_module._conn() as conn:
rows = conn.execute( rows = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
@@ -33,6 +33,7 @@ def test_init_db_creates_six_tables(tmp_db):
assert names == sorted([ assert names == sorted([
"news_articles", "trending_keywords", "card_slates", "news_articles", "trending_keywords", "card_slates",
"card_assets", "generation_tasks", "prompt_templates", "card_assets", "generation_tasks", "prompt_templates",
"account_preferences",
]) ])

View File

@@ -0,0 +1,77 @@
import os
import gc
import tempfile
import pytest
from app import db as db_module
@pytest.fixture
def tmp_db(monkeypatch):
fd, path = tempfile.mkstemp(suffix=".db")
os.close(fd)
monkeypatch.setattr(db_module, "DB_PATH", path)
db_module.init_db()
yield path
gc.collect()
for ext in ("", "-wal", "-shm"):
try:
os.remove(path + ext)
except OSError:
pass
def test_init_db_creates_account_preferences(tmp_db):
with db_module._conn() as conn:
rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
names = {r[0] for r in rows}
assert "account_preferences" in names
def test_init_db_seeds_default_weights(tmp_db):
prefs = db_module.get_preferences()
cats = {p["category"]: p["weight"] for p in prefs}
assert cats["economy"] == pytest.approx(1.0)
assert cats["psychology"] == pytest.approx(1.0)
assert cats["celebrity"] == pytest.approx(1.0)
def test_upsert_preferences_replaces_weights(tmp_db):
db_module.upsert_preferences({"economy": 0.6, "psychology": 0.3, "celebrity": 0.1, "tech": 0.5})
prefs = {p["category"]: p["weight"] for p in db_module.get_preferences()}
assert prefs["economy"] == pytest.approx(0.6)
assert prefs["tech"] == pytest.approx(0.5)
assert "celebrity" in prefs and prefs["celebrity"] == pytest.approx(0.1)
def test_trending_keywords_source_column_exists(tmp_db):
with db_module._conn() as conn:
cols = [r[1] for r in conn.execute("PRAGMA table_info(trending_keywords)").fetchall()]
assert "source" in cols
def test_add_trending_keyword_default_source(tmp_db):
kid = db_module.add_trending_keyword({
"keyword": "K", "category": "economy", "score": 0.5, "articles_count": 3,
})
with db_module._conn() as conn:
row = conn.execute("SELECT source FROM trending_keywords WHERE id=?", (kid,)).fetchone()
assert row[0] == "manual"
def test_add_external_trend_stores_source(tmp_db):
tid = db_module.add_external_trend({
"keyword": "급등주", "category": "economy", "source": "naver_popular", "score": 0.9,
})
rows = db_module.list_trends(source="naver_popular")
assert any(r["id"] == tid and r["keyword"] == "급등주" for r in rows)
def test_list_trends_filters_by_source_and_category(tmp_db):
db_module.add_external_trend({"keyword": "A", "category": "economy", "source": "naver_popular", "score": 1.0})
db_module.add_external_trend({"keyword": "B", "category": "celebrity", "source": "google_trends", "score": 1.0})
only_naver = db_module.list_trends(source="naver_popular")
assert {r["keyword"] for r in only_naver} == {"A"}
only_celeb_google = db_module.list_trends(source="google_trends", category="celebrity")
assert {r["keyword"] for r in only_celeb_google} == {"B"}