feat/insta-trends #4

Merged
gahusb merged 8 commits from feat/insta-trends into main 2026-05-17 08:52:50 +09:00
14 changed files with 821 additions and 8 deletions

View File

@@ -56,6 +56,8 @@ class InstaAgent(BaseAgent):
requires_approval=False)
await self.transition("working", "뉴스 수집·키워드 추출", task_id)
try:
prefs = await service_proxy.insta_get_preferences()
add_log(self.agent_id, f"insta preferences: {prefs}", "info", task_id)
await self._run_collect_and_extract()
kws = await service_proxy.insta_list_keywords(used=False)
if auto_select:
@@ -147,6 +149,12 @@ class InstaAgent(BaseAgent):
return {"ok": False, "message": "keyword_id 필수"}
await self._render_and_push(kid)
return {"ok": True}
if command == "collect_trends":
await messaging.send_raw("🌐 외부 트렌드 수집 시작")
created = await service_proxy.insta_collect_trends()
st = await self._wait_task(created["task_id"], step="trends_collect", timeout_sec=300)
await messaging.send_raw(f"✅ 트렌드 수집 완료: {st.get('message', '')}")
return {"ok": True, "result": st}
return {"ok": False, "message": f"Unknown command: {command}"}
async def on_callback(self, action: str, params: dict) -> dict:

View File

@@ -29,6 +29,12 @@ async def _run_insta_schedule():
if agent:
await agent.on_schedule()
async def _run_insta_trends_collect():
agent = AGENT_REGISTRY.get("insta")
if agent:
await agent.on_command("collect_trends", {})
async def _run_lotto_schedule():
agent = AGENT_REGISTRY.get("lotto")
if agent:
@@ -68,6 +74,7 @@ def init_scheduler():
id="stock_ai_news_sentiment",
)
scheduler.add_job(_run_insta_schedule, "cron", hour=9, minute=30, id="insta_pipeline")
scheduler.add_job(_run_insta_trends_collect, "cron", hour=9, minute=0, id="insta_trends_collect")
scheduler.add_job(_run_lotto_schedule, "cron", day_of_week="mon", hour=9, minute=0, id="lotto_curate")
scheduler.add_job(_run_youtube_research, "cron", hour=9, minute=0, id="youtube_research")
scheduler.add_job(_send_youtube_weekly_report, "cron", day_of_week="mon", hour=8, minute=0, id="youtube_weekly_report")

View File

@@ -167,6 +167,41 @@ async def insta_get_asset_bytes(slate_id: int, page: int) -> bytes:
return resp.content
async def insta_collect_trends(categories: Optional[list] = None) -> Dict[str, Any]:
payload = {"categories": categories} if categories else {}
resp = await _client.post(f"{INSTA_LAB_URL}/api/insta/trends/collect", json=payload)
resp.raise_for_status()
return resp.json()
async def insta_list_trends(source: Optional[str] = None,
category: Optional[str] = None,
days: int = 1) -> List[Dict[str, Any]]:
params: Dict[str, Any] = {"days": days}
if source:
params["source"] = source
if category:
params["category"] = category
resp = await _client.get(f"{INSTA_LAB_URL}/api/insta/trends", params=params)
resp.raise_for_status()
return resp.json().get("items", [])
async def insta_get_preferences() -> Dict[str, float]:
resp = await _client.get(f"{INSTA_LAB_URL}/api/insta/preferences")
resp.raise_for_status()
return {p["category"]: p["weight"] for p in resp.json().get("categories", [])}
async def insta_put_preferences(weights: Dict[str, float]) -> Dict[str, Any]:
resp = await _client.put(
f"{INSTA_LAB_URL}/api/insta/preferences",
json={"categories": weights},
)
resp.raise_for_status()
return resp.json()
# --- realestate-lab ---
async def realestate_collect() -> Dict[str, Any]:

View File

@@ -0,0 +1,73 @@
import os
import sys
import tempfile
_fd, _TMP = tempfile.mkstemp(suffix=".db")
os.close(_fd)
os.unlink(_TMP)
os.environ["AGENT_OFFICE_DB_PATH"] = _TMP
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from unittest.mock import AsyncMock
import pytest
from app.agents.insta import InstaAgent
@pytest.fixture(autouse=True)
def _init_db():
import gc
gc.collect()
if os.path.exists(_TMP):
os.remove(_TMP)
from app.db import init_db
init_db()
yield
gc.collect()
@pytest.mark.asyncio
async def test_on_command_collect_trends_dispatches(monkeypatch):
agent = InstaAgent()
fake_collect = AsyncMock(return_value={"task_id": "tcollect"})
fake_status = AsyncMock(return_value={"status": "succeeded", "result_id": 8,
"message": "naver:5, google:3"})
monkeypatch.setattr("app.agents.insta.service_proxy.insta_collect_trends", fake_collect)
monkeypatch.setattr("app.agents.insta.service_proxy.insta_task_status", fake_status)
monkeypatch.setattr("app.agents.insta.messaging.send_raw", AsyncMock(return_value={"ok": True}))
result = await agent.on_command("collect_trends", {})
assert result["ok"] is True
fake_collect.assert_awaited()
@pytest.mark.asyncio
async def test_on_schedule_loads_preferences(monkeypatch):
"""on_schedule이 preferences를 가져오는지 확인."""
agent = InstaAgent()
fake_collect = AsyncMock(return_value={"task_id": "t1"})
fake_extract = AsyncMock(return_value={"task_id": "t2"})
fake_status = AsyncMock(side_effect=[
{"status": "succeeded", "result_id": 0},
{"status": "succeeded", "result_id": 0},
])
fake_keywords = AsyncMock(return_value=[
{"id": 1, "keyword": "K", "category": "economy", "score": 0.9},
])
fake_prefs = AsyncMock(return_value={"economy": 0.6, "psychology": 0.4})
monkeypatch.setattr("app.agents.insta.service_proxy.insta_collect", fake_collect)
monkeypatch.setattr("app.agents.insta.service_proxy.insta_extract", fake_extract)
monkeypatch.setattr("app.agents.insta.service_proxy.insta_task_status", fake_status)
monkeypatch.setattr("app.agents.insta.service_proxy.insta_list_keywords", fake_keywords)
monkeypatch.setattr("app.agents.insta.service_proxy.insta_get_preferences", fake_prefs)
monkeypatch.setattr("app.agents.insta.messaging.send_raw", AsyncMock(return_value={"ok": True}))
agent.state = "idle"
await agent.on_schedule()
fake_prefs.assert_awaited()

View File

@@ -101,6 +101,29 @@ def init_db() -> None:
)
""")
# source column for trending_keywords (idempotent ALTER)
cols = [r[1] for r in conn.execute("PRAGMA table_info(trending_keywords)").fetchall()]
if "source" not in cols:
conn.execute("ALTER TABLE trending_keywords ADD COLUMN source TEXT NOT NULL DEFAULT 'manual'")
conn.execute("CREATE INDEX IF NOT EXISTS idx_tk_source ON trending_keywords(source, suggested_at DESC)")
# account_preferences — 카테고리 가중치
conn.execute("""
CREATE TABLE IF NOT EXISTS account_preferences (
category TEXT PRIMARY KEY,
weight REAL NOT NULL DEFAULT 1.0,
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))
)
""")
# seed defaults if table empty
existing = conn.execute("SELECT COUNT(*) FROM account_preferences").fetchone()[0]
if existing == 0:
for cat in ("economy", "psychology", "celebrity"):
conn.execute(
"INSERT INTO account_preferences(category, weight) VALUES(?,?)",
(cat, 1.0),
)
# ── news_articles ────────────────────────────────────────────────
def add_news_article(row: Dict[str, Any]) -> int:
@@ -132,8 +155,12 @@ def list_news_articles(category: Optional[str] = None, days: int = 1) -> List[Di
def add_trending_keyword(row: Dict[str, Any]) -> int:
with _conn() as conn:
cur = conn.execute(
"INSERT INTO trending_keywords(keyword, category, score, articles_count) VALUES(?,?,?,?)",
(row["keyword"], row["category"], float(row.get("score", 0.0)), int(row.get("articles_count", 0))),
"INSERT INTO trending_keywords(keyword, category, score, articles_count, source) VALUES(?,?,?,?,?)",
(
row["keyword"], row["category"],
float(row.get("score", 0.0)), int(row.get("articles_count", 0)),
row.get("source", "manual"),
),
)
return cur.lastrowid
@@ -276,3 +303,50 @@ def get_prompt_template(name: str) -> Optional[Dict[str, Any]]:
with _conn() as conn:
row = conn.execute("SELECT * FROM prompt_templates WHERE name=?", (name,)).fetchone()
return dict(row) if row else None
# ── external trends ─────────────────────────────────────────────
def add_external_trend(row: Dict[str, Any]) -> int:
"""`source` 필수 — naver_popular | google_trends. trending_keywords에 인서트."""
if "source" not in row:
raise ValueError("add_external_trend requires 'source' field")
return add_trending_keyword(row)
def list_trends(source: Optional[str] = None, category: Optional[str] = None,
days: int = 1) -> List[Dict[str, Any]]:
sql = "SELECT * FROM trending_keywords WHERE suggested_at >= datetime('now', ?)"
params: List[Any] = [f"-{int(days)} days"]
if source and source != "all":
sql += " AND source=?"
params.append(source)
if category:
sql += " AND category=?"
params.append(category)
sql += " ORDER BY suggested_at DESC, score DESC"
with _conn() as conn:
rows = conn.execute(sql, params).fetchall()
return [dict(r) for r in rows]
# ── account_preferences ─────────────────────────────────────────
def get_preferences() -> List[Dict[str, Any]]:
with _conn() as conn:
rows = conn.execute(
"SELECT category, weight, updated_at FROM account_preferences ORDER BY category ASC"
).fetchall()
return [dict(r) for r in rows]
def upsert_preferences(weights: Dict[str, float]) -> None:
"""전체 upsert. 기존에 있던 카테고리는 weight 갱신, 신규는 INSERT.
명시되지 않은 기존 카테고리는 그대로 둔다 (삭제 X). 삭제 필요 시 별도 API로."""
with _conn() as conn:
for cat, w in weights.items():
conn.execute("""
INSERT INTO account_preferences(category, weight)
VALUES(?,?)
ON CONFLICT(category) DO UPDATE SET
weight=excluded.weight,
updated_at=strftime('%Y-%m-%dT%H:%M:%fZ','now')
""", (cat, float(w)))

View File

@@ -81,3 +81,22 @@ def extract_for_category(category: str, limit: int = KEYWORDS_PER_CATEGORY) -> L
})
saved.append({"id": kid, **kw, "category": category})
return saved
def extract_with_weights(weights: Dict[str, float], total_limit: int) -> List[Dict[str, Any]]:
"""카테고리 가중치 비율대로 키워드를 분배 추출."""
from .config import DEFAULT_CATEGORY_SEEDS
if not weights or sum(weights.values()) == 0:
cats = list(DEFAULT_CATEGORY_SEEDS.keys())
weights = {c: 1.0 for c in cats}
total_weight = sum(weights.values())
out: List[Dict[str, Any]] = []
for category, w in weights.items():
if w <= 0:
continue
per_cat = round(total_limit * w / total_weight)
if per_cat <= 0:
continue
out.extend(extract_for_category(category, limit=per_cat))
return out

View File

@@ -15,7 +15,7 @@ from .config import (
CORS_ALLOW_ORIGINS, NAVER_CLIENT_ID, ANTHROPIC_API_KEY,
INSTA_DATA_PATH, DB_PATH, DEFAULT_CATEGORY_SEEDS, KEYWORDS_PER_CATEGORY,
)
from . import db, news_collector, keyword_extractor, card_writer, card_renderer
from . import db, news_collector, keyword_extractor, card_writer, card_renderer, trend_collector
logger = logging.getLogger(__name__)
app = FastAPI()
@@ -99,11 +99,16 @@ class ExtractRequest(BaseModel):
categories: Optional[list[str]] = None
async def _bg_extract(task_id: str, categories: list[str]):
async def _bg_extract(task_id: str, categories: Optional[list[str]] = None):
try:
db.update_task(task_id, "processing", 10, "추출 중")
for cat in categories:
keyword_extractor.extract_for_category(cat, limit=KEYWORDS_PER_CATEGORY)
prefs_rows = db.get_preferences()
weights = {p["category"]: p["weight"] for p in prefs_rows}
if categories:
# 사용자가 카테고리 명시한 경우만 그 서브셋으로 균등 가중치 (override)
weights = {c: 1.0 for c in categories}
total = KEYWORDS_PER_CATEGORY * max(1, len([w for w in weights.values() if w > 0]))
keyword_extractor.extract_with_weights(weights, total_limit=total)
db.update_task(task_id, "succeeded", 100, "완료", result_id=0)
except Exception as e:
logger.exception("extract failed")
@@ -119,7 +124,13 @@ def extract_keywords(req: ExtractRequest, bg: BackgroundTasks):
@app.get("/api/insta/keywords")
def list_keywords(category: Optional[str] = None, used: Optional[bool] = None):
def list_keywords(
category: Optional[str] = None,
used: Optional[bool] = None,
source: Optional[str] = None,
):
if source:
return {"items": db.list_trends(source=source, category=category, days=30)}
return {"items": db.list_trending_keywords(category=category, used=used)}
@@ -243,3 +254,52 @@ def get_prompt(name: str):
def upsert_prompt(name: str, body: TemplateBody):
db.upsert_prompt_template(name, body.template, body.description)
return db.get_prompt_template(name)
# ── Trends ───────────────────────────────────────────────────────
class TrendsCollectRequest(BaseModel):
categories: Optional[list[str]] = None
async def _bg_collect_trends(task_id: str, categories: list[str]):
try:
db.update_task(task_id, "processing", 10, "외부 트렌드 수집 중")
result = trend_collector.collect_all(categories)
msg = f"naver:{result['naver_popular']}, google:{result['google_trends']}"
db.update_task(task_id, "succeeded", 100, msg, result_id=sum(result.values()))
except Exception as e:
logger.exception("trends collect failed")
db.update_task(task_id, "failed", 0, "", error=str(e))
@app.post("/api/insta/trends/collect")
def collect_trends(req: TrendsCollectRequest, bg: BackgroundTasks):
cats = req.categories or list(DEFAULT_CATEGORY_SEEDS.keys())
tid = db.create_task("trends_collect", {"categories": cats})
bg.add_task(_bg_collect_trends, tid, cats)
return {"task_id": tid, "categories": cats}
@app.get("/api/insta/trends")
def list_trends_endpoint(
source: Optional[str] = None,
category: Optional[str] = None,
days: int = Query(1, ge=1, le=90),
):
return {"items": db.list_trends(source=source, category=category, days=days)}
# ── Preferences ──────────────────────────────────────────────────
class PreferencesBody(BaseModel):
categories: dict[str, float]
@app.get("/api/insta/preferences")
def get_preferences_endpoint():
return {"categories": db.get_preferences()}
@app.put("/api/insta/preferences")
def put_preferences_endpoint(body: PreferencesBody):
db.upsert_preferences(body.categories)
return {"categories": db.get_preferences()}

View File

@@ -0,0 +1,183 @@
"""외부 트렌드 수집 — NAVER 인기 + Google Trends + LLM 카테고리 분류.
Phase B Task 3: Google Trends integration via pytrends + Anthropic Haiku 분류 캐시 (24h TTL).
"""
import json
import logging
import re
import time
from typing import Any, Dict, List, Optional
import requests
from anthropic import Anthropic
from pytrends.request import TrendReq
from .config import (
NAVER_CLIENT_ID, NAVER_CLIENT_SECRET, DEFAULT_CATEGORY_SEEDS,
ANTHROPIC_API_KEY, ANTHROPIC_MODEL_HAIKU,
)
from . import db
from .news_collector import _clean
from .keyword_extractor import _count_nouns, _top_candidates
logger = logging.getLogger(__name__)
NEWS_URL = "https://openapi.naver.com/v1/search/news.json"
_NAVER_HEADERS = {
"X-Naver-Client-Id": NAVER_CLIENT_ID,
"X-Naver-Client-Secret": NAVER_CLIENT_SECRET,
}
def _seeds_for(category: str) -> List[str]:
pt = db.get_prompt_template("category_seeds")
if pt and pt.get("template"):
try:
data = json.loads(pt["template"])
if category in data:
return list(data[category])
except Exception:
pass
return list(DEFAULT_CATEGORY_SEEDS.get(category, []))
def fetch_naver_popular(category: str, per_seed: int = 30, top_n: int = 10) -> List[Dict[str, Any]]:
"""카테고리 시드 키워드들로 NAVER news.json `sort=sim` 호출,
응답 기사 묶음에서 빈도어 추출 후 상위 N개 반환."""
seeds = _seeds_for(category)
if not seeds:
return []
blob_parts: List[str] = []
for seed in seeds:
try:
resp = requests.get(
NEWS_URL,
headers=_NAVER_HEADERS,
params={"query": seed, "display": per_seed, "sort": "sim"},
timeout=10,
)
resp.raise_for_status()
for item in resp.json().get("items", []):
blob_parts.append(_clean(item.get("title", "")))
blob_parts.append(_clean(item.get("description", "")))
except Exception as e:
logger.warning("fetch_naver_popular seed=%s err=%s", seed, e)
continue
text = "\n".join(blob_parts)
counts = _count_nouns(text)
candidates = _top_candidates(counts, n=top_n)
if not candidates:
return []
max_count = candidates[0][1] or 1
return [
{
"keyword": k,
"category": category,
"source": "naver_popular",
"score": round(min(1.0, c / max_count), 4),
"articles_count": c,
}
for k, c in candidates
]
def collect_naver_popular_for(categories: List[str]) -> int:
total = 0
for cat in categories:
trends = fetch_naver_popular(cat)
for t in trends:
db.add_external_trend(t)
total += 1
return total
# ── LLM 분류 캐시 ────────────────────────────────────────────────────────────
_CACHE_TTL_SEC = 24 * 3600
_category_cache: Dict[str, tuple] = {} # keyword -> (category, expires_ts)
def _llm_classify_one(keyword: str) -> str:
"""Claude Haiku 1회 호출로 단일 키워드 분류."""
if not ANTHROPIC_API_KEY:
return "uncategorized"
seeds_template = db.get_prompt_template("category_seeds")
if seeds_template and seeds_template.get("template"):
try:
allowed = sorted(json.loads(seeds_template["template"]).keys())
except Exception:
allowed = sorted(DEFAULT_CATEGORY_SEEDS.keys())
else:
allowed = sorted(DEFAULT_CATEGORY_SEEDS.keys())
allowed.append("uncategorized")
client = Anthropic(api_key=ANTHROPIC_API_KEY)
msg = client.messages.create(
model=ANTHROPIC_MODEL_HAIKU,
max_tokens=20,
messages=[{
"role": "user",
"content": (
f"다음 한국어 트렌딩 키워드를 카테고리 중 하나로 분류해라. "
f"카테고리: {allowed}. 키워드: '{keyword}'. "
f"카테고리명 한 단어만 출력. 다른 텍스트 금지."
),
}],
)
raw = msg.content[0].text.strip().lower()
for cat in allowed:
if cat.lower() in raw:
return cat
return "uncategorized"
def classify_keyword(keyword: str) -> str:
now = time.time()
cached = _category_cache.get(keyword)
if cached and cached[1] > now:
return cached[0]
cat = _llm_classify_one(keyword)
_category_cache[keyword] = (cat, now + _CACHE_TTL_SEC)
return cat
# ── Google Trends ─────────────────────────────────────────────────────────────
def fetch_google_trends() -> List[Dict[str, Any]]:
"""pytrends 한국 daily trending searches. 실패 시 빈 리스트."""
try:
pytrends = TrendReq(hl="ko-KR", tz=540)
df = pytrends.trending_searches(pn="south_korea")
except Exception as e:
logger.warning("Google Trends fetch failed: %s", e)
return []
items: List[Dict[str, Any]] = []
for idx, row in df.iterrows():
kw = str(row.iloc[0]).strip()
if not kw:
continue
cat = classify_keyword(kw)
rank_score = round(max(0.0, 1.0 - (idx / max(1, len(df)))), 4)
items.append({
"keyword": kw,
"category": cat,
"source": "google_trends",
"score": rank_score,
"articles_count": 0,
})
return items
def collect_google_trends() -> int:
items = fetch_google_trends()
for it in items:
db.add_external_trend(it)
return len(items)
def collect_all(categories: List[str]) -> Dict[str, int]:
naver_n = collect_naver_popular_for(categories)
google_n = collect_google_trends()
return {"naver_popular": naver_n, "google_trends": google_n}

View File

@@ -7,3 +7,4 @@ jinja2>=3.1.4
playwright==1.48.0
pytest>=8.0
pytest-asyncio>=0.24
pytrends>=4.9

View File

@@ -24,7 +24,7 @@ def tmp_db(monkeypatch):
pass
def test_init_db_creates_six_tables(tmp_db):
def test_init_db_creates_seven_tables(tmp_db):
with db_module._conn() as conn:
rows = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
@@ -33,6 +33,7 @@ def test_init_db_creates_six_tables(tmp_db):
assert names == sorted([
"news_articles", "trending_keywords", "card_slates",
"card_assets", "generation_tasks", "prompt_templates",
"account_preferences",
])

View File

@@ -0,0 +1,71 @@
import os
import gc
import tempfile
from unittest.mock import patch
import pytest
from app import db as db_module
from app import keyword_extractor
@pytest.fixture
def tmp_db(monkeypatch):
fd, path = tempfile.mkstemp(suffix=".db")
os.close(fd)
monkeypatch.setattr(db_module, "DB_PATH", path)
db_module.init_db()
yield path
gc.collect()
for ext in ("", "-wal", "-shm"):
try:
os.remove(path + ext)
except OSError:
pass
def test_extract_with_weights_proportional(tmp_db, monkeypatch):
calls = []
def fake_extract(category, limit):
calls.append((category, limit))
return [{"id": i, "keyword": f"{category}{i}", "category": category, "score": 0.5}
for i in range(limit)]
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
out = keyword_extractor.extract_with_weights(
{"economy": 0.6, "psychology": 0.3, "celebrity": 0.1}, total_limit=10,
)
by_cat = {c: l for c, l in calls}
assert by_cat == {"economy": 6, "psychology": 3, "celebrity": 1}
assert len(out) == 10
def test_extract_with_weights_skips_zero(tmp_db, monkeypatch):
calls = []
def fake_extract(category, limit):
calls.append((category, limit))
return []
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
keyword_extractor.extract_with_weights(
{"economy": 1.0, "celebrity": 0.0}, total_limit=10,
)
cats_called = [c for c, _ in calls]
assert "celebrity" not in cats_called
assert "economy" in cats_called
def test_extract_with_weights_fallback_to_equal(tmp_db, monkeypatch):
calls = []
def fake_extract(category, limit):
calls.append((category, limit))
return []
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
keyword_extractor.extract_with_weights({}, total_limit=9)
by_cat = {c: l for c, l in calls}
assert set(by_cat.keys()) == {"economy", "psychology", "celebrity"}
assert all(l == 3 for l in by_cat.values())

View File

@@ -0,0 +1,83 @@
import os
import gc
import tempfile
import pytest
from fastapi.testclient import TestClient
from app import db as db_module
@pytest.fixture
def client(monkeypatch):
fd, path = tempfile.mkstemp(suffix=".db")
os.close(fd)
monkeypatch.setattr(db_module, "DB_PATH", path)
db_module.init_db()
from app import main
monkeypatch.setattr(main, "DB_PATH", path)
with TestClient(main.app) as c:
yield c
gc.collect()
for ext in ("", "-wal", "-shm"):
try:
os.remove(path + ext)
except OSError:
pass
def test_get_preferences_returns_defaults(client):
resp = client.get("/api/insta/preferences")
assert resp.status_code == 200
cats = {p["category"]: p["weight"] for p in resp.json()["categories"]}
assert cats == {"economy": 1.0, "psychology": 1.0, "celebrity": 1.0}
def test_put_preferences_upsert(client):
resp = client.put("/api/insta/preferences",
json={"categories": {"economy": 0.7, "psychology": 0.2, "tech": 0.5}})
assert resp.status_code == 200
cats = {p["category"]: p["weight"] for p in resp.json()["categories"]}
assert cats["economy"] == 0.7
assert cats["tech"] == 0.5
def test_list_trends_filter(client):
db_module.add_external_trend({"keyword": "A", "category": "economy",
"source": "naver_popular", "score": 1.0})
db_module.add_external_trend({"keyword": "B", "category": "celebrity",
"source": "google_trends", "score": 0.8})
resp = client.get("/api/insta/trends?source=naver_popular")
items = resp.json()["items"]
assert {it["keyword"] for it in items} == {"A"}
def test_collect_trends_kicks_background(client, monkeypatch):
from app import main, trend_collector
captured = {"called": False}
def fake_collect_all(cats):
captured["called"] = True
return {"naver_popular": 3, "google_trends": 2}
monkeypatch.setattr(trend_collector, "collect_all", fake_collect_all)
resp = client.post("/api/insta/trends/collect", json={})
assert resp.status_code == 200
task_id = resp.json()["task_id"]
for _ in range(20):
st = client.get(f"/api/insta/tasks/{task_id}").json()
if st["status"] in ("succeeded", "failed"):
break
assert st["status"] == "succeeded"
assert captured["called"] is True
def test_list_keywords_filters_by_source(client):
db_module.add_trending_keyword({"keyword": "M", "category": "economy",
"score": 0.4, "articles_count": 1, "source": "manual"})
db_module.add_external_trend({"keyword": "N", "category": "economy",
"source": "naver_popular", "score": 0.9})
resp = client.get("/api/insta/keywords?source=manual")
items = resp.json()["items"]
assert {it["keyword"] for it in items} == {"M"}

View File

@@ -0,0 +1,77 @@
import os
import gc
import tempfile
import pytest
from app import db as db_module
@pytest.fixture
def tmp_db(monkeypatch):
fd, path = tempfile.mkstemp(suffix=".db")
os.close(fd)
monkeypatch.setattr(db_module, "DB_PATH", path)
db_module.init_db()
yield path
gc.collect()
for ext in ("", "-wal", "-shm"):
try:
os.remove(path + ext)
except OSError:
pass
def test_init_db_creates_account_preferences(tmp_db):
with db_module._conn() as conn:
rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
names = {r[0] for r in rows}
assert "account_preferences" in names
def test_init_db_seeds_default_weights(tmp_db):
prefs = db_module.get_preferences()
cats = {p["category"]: p["weight"] for p in prefs}
assert cats["economy"] == pytest.approx(1.0)
assert cats["psychology"] == pytest.approx(1.0)
assert cats["celebrity"] == pytest.approx(1.0)
def test_upsert_preferences_replaces_weights(tmp_db):
db_module.upsert_preferences({"economy": 0.6, "psychology": 0.3, "celebrity": 0.1, "tech": 0.5})
prefs = {p["category"]: p["weight"] for p in db_module.get_preferences()}
assert prefs["economy"] == pytest.approx(0.6)
assert prefs["tech"] == pytest.approx(0.5)
assert "celebrity" in prefs and prefs["celebrity"] == pytest.approx(0.1)
def test_trending_keywords_source_column_exists(tmp_db):
with db_module._conn() as conn:
cols = [r[1] for r in conn.execute("PRAGMA table_info(trending_keywords)").fetchall()]
assert "source" in cols
def test_add_trending_keyword_default_source(tmp_db):
kid = db_module.add_trending_keyword({
"keyword": "K", "category": "economy", "score": 0.5, "articles_count": 3,
})
with db_module._conn() as conn:
row = conn.execute("SELECT source FROM trending_keywords WHERE id=?", (kid,)).fetchone()
assert row[0] == "manual"
def test_add_external_trend_stores_source(tmp_db):
tid = db_module.add_external_trend({
"keyword": "급등주", "category": "economy", "source": "naver_popular", "score": 0.9,
})
rows = db_module.list_trends(source="naver_popular")
assert any(r["id"] == tid and r["keyword"] == "급등주" for r in rows)
def test_list_trends_filters_by_source_and_category(tmp_db):
db_module.add_external_trend({"keyword": "A", "category": "economy", "source": "naver_popular", "score": 1.0})
db_module.add_external_trend({"keyword": "B", "category": "celebrity", "source": "google_trends", "score": 1.0})
only_naver = db_module.list_trends(source="naver_popular")
assert {r["keyword"] for r in only_naver} == {"A"}
only_celeb_google = db_module.list_trends(source="google_trends", category="celebrity")
assert {r["keyword"] for r in only_celeb_google} == {"B"}

View File

@@ -0,0 +1,121 @@
import os
import gc
import tempfile
from unittest.mock import patch, MagicMock
import pytest
from app import db as db_module
from app import trend_collector
@pytest.fixture
def tmp_db(monkeypatch):
fd, path = tempfile.mkstemp(suffix=".db")
os.close(fd)
monkeypatch.setattr(db_module, "DB_PATH", path)
db_module.init_db()
yield path
gc.collect()
for ext in ("", "-wal", "-shm"):
try:
os.remove(path + ext)
except OSError:
pass
NAVER_RESPONSE = {
"items": [
{"title": "<b>기준금리</b> 인상", "link": "https://n.news.naver.com/a/1", "description": "한국은행 발표"},
{"title": "환율 급등", "link": "https://n.news.naver.com/a/2", "description": "달러 강세"},
{"title": "기준금리 추가 인상", "link": "https://n.news.naver.com/a/3", "description": "추가 발표"},
],
}
def test_fetch_naver_popular_extracts_top_terms(tmp_db, monkeypatch):
fake_resp = MagicMock()
fake_resp.json.return_value = NAVER_RESPONSE
fake_resp.raise_for_status.return_value = None
with patch.object(trend_collector.requests, "get", return_value=fake_resp):
trends = trend_collector.fetch_naver_popular("economy", per_seed=10, top_n=5)
keywords = [t["keyword"] for t in trends]
assert "기준금리" in keywords
for t in trends:
assert t["category"] == "economy"
assert t["source"] == "naver_popular"
assert 0.0 <= t["score"] <= 1.0
def test_collect_naver_writes_to_db(tmp_db, monkeypatch):
fake_resp = MagicMock()
fake_resp.json.return_value = NAVER_RESPONSE
fake_resp.raise_for_status.return_value = None
with patch.object(trend_collector.requests, "get", return_value=fake_resp):
n = trend_collector.collect_naver_popular_for(["economy"])
assert n > 0
rows = db_module.list_trends(source="naver_popular")
assert len(rows) > 0
assert all(r["source"] == "naver_popular" for r in rows)
def test_classify_keyword_with_cache(monkeypatch):
calls = {"n": 0}
def fake_claude(keyword: str) -> str:
calls["n"] += 1
return "economy"
monkeypatch.setattr(trend_collector, "_llm_classify_one", fake_claude)
trend_collector._category_cache.clear()
c1 = trend_collector.classify_keyword("기준금리")
c2 = trend_collector.classify_keyword("기준금리")
assert c1 == c2 == "economy"
assert calls["n"] == 1
def test_fetch_google_trends_parses_and_classifies(tmp_db, monkeypatch):
class FakePyTrends:
def __init__(self, *_a, **_kw):
pass
def trending_searches(self, pn="south_korea"):
import pandas as pd
return pd.DataFrame({"0": ["기준금리", "BTS 컴백", "스트레스 관리"]})
monkeypatch.setattr(trend_collector, "TrendReq", FakePyTrends)
monkeypatch.setattr(trend_collector, "classify_keyword",
lambda kw: {"기준금리": "economy", "BTS 컴백": "celebrity",
"스트레스 관리": "psychology"}.get(kw, "uncategorized"))
trends = trend_collector.fetch_google_trends()
by_kw = {t["keyword"]: t for t in trends}
assert by_kw["기준금리"]["category"] == "economy"
assert by_kw["BTS 컴백"]["category"] == "celebrity"
assert by_kw["스트레스 관리"]["category"] == "psychology"
assert all(t["source"] == "google_trends" for t in trends)
def test_collect_all_invokes_both_sources(tmp_db, monkeypatch):
monkeypatch.setattr(trend_collector, "collect_naver_popular_for",
lambda cats: 5)
monkeypatch.setattr(trend_collector, "collect_google_trends",
lambda: 3)
out = trend_collector.collect_all(["economy"])
assert out == {"naver_popular": 5, "google_trends": 3}
def test_fetch_google_trends_graceful_on_pytrends_failure(monkeypatch):
class FakePyTrends:
def __init__(self, *_a, **_kw):
pass
def trending_searches(self, pn="south_korea"):
raise RuntimeError("rate limited")
monkeypatch.setattr(trend_collector, "TrendReq", FakePyTrends)
out = trend_collector.fetch_google_trends()
assert out == []