# Listing metadata: Python, 122 lines, 4.1 KiB.
import os
|
|
import gc
|
|
import tempfile
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
import pytest
|
|
|
|
from app import db as db_module
|
|
from app import trend_collector
|
|
|
|
|
|
@pytest.fixture
def tmp_db(monkeypatch):
    """Provide a freshly initialised throwaway database file for one test.

    A temporary ``.db`` file is created, ``db_module.DB_PATH`` is pointed at
    it, and ``db_module.init_db()`` is run.  The path is yielded to the test;
    afterwards the file and its ``-wal`` / ``-shm`` siblings are removed on a
    best-effort basis.
    """
    handle, db_path = tempfile.mkstemp(suffix=".db")
    # Only the path is needed; release the OS-level descriptor right away.
    os.close(handle)

    monkeypatch.setattr(db_module, "DB_PATH", db_path)
    db_module.init_db()

    yield db_path

    # NOTE(review): presumably forces lingering connection objects to be
    # finalised so the files can be deleted -- confirm against db_module.
    gc.collect()
    leftovers = [db_path + suffix for suffix in ("", "-wal", "-shm")]
    for leftover in leftovers:
        try:
            os.remove(leftover)
        except OSError:
            # Missing or still-locked files are not worth failing a test over.
            pass
|
|
|
|
|
|
# Canned JSON payload handed back by the mocked ``requests.get`` in the Naver
# tests below.  Two of the three titles contain "기준금리", and the first one
# wraps it in <b> markup.
NAVER_RESPONSE = {
    "items": [
        {
            "title": "<b>기준금리</b> 인상",
            "link": "https://n.news.naver.com/a/1",
            "description": "한국은행 발표",
        },
        {
            "title": "환율 급등",
            "link": "https://n.news.naver.com/a/2",
            "description": "달러 강세",
        },
        {
            "title": "기준금리 추가 인상",
            "link": "https://n.news.naver.com/a/3",
            "description": "추가 발표",
        },
    ],
}
|
|
|
|
|
|
def test_fetch_naver_popular_extracts_top_terms(tmp_db, monkeypatch):
    """fetch_naver_popular turns the canned Naver payload into trend dicts.

    With ``requests.get`` replaced by a mock that serves ``NAVER_RESPONSE``,
    the result must contain the keyword "기준금리", and every entry must be
    tagged with the requested category, the ``naver_popular`` source and a
    score within [0.0, 1.0].
    """
    # Configure the stub response in one go: .json() -> payload,
    # .raise_for_status() -> None (i.e. no HTTP error).
    stub_response = MagicMock(
        **{
            "json.return_value": NAVER_RESPONSE,
            "raise_for_status.return_value": None,
        }
    )

    with patch.object(trend_collector.requests, "get", return_value=stub_response):
        trends = trend_collector.fetch_naver_popular("economy", per_seed=10, top_n=5)

    found_keywords = [trend["keyword"] for trend in trends]
    assert "기준금리" in found_keywords

    for trend in trends:
        assert trend["category"] == "economy"
        assert trend["source"] == "naver_popular"
        assert 0.0 <= trend["score"] <= 1.0
|
|
|
|
|
|
def test_collect_naver_writes_to_db(tmp_db, monkeypatch):
    """collect_naver_popular_for reports a positive count and stores rows.

    ``requests.get`` is mocked to serve ``NAVER_RESPONSE``; afterwards
    ``db_module.list_trends(source="naver_popular")`` must return at least
    one row, all carrying the ``naver_popular`` source.
    """
    stub_response = MagicMock(
        **{
            "json.return_value": NAVER_RESPONSE,
            "raise_for_status.return_value": None,
        }
    )

    with patch.object(trend_collector.requests, "get", return_value=stub_response):
        stored_count = trend_collector.collect_naver_popular_for(["economy"])

    assert stored_count > 0

    stored_rows = db_module.list_trends(source="naver_popular")
    assert len(stored_rows) > 0
    assert all(row["source"] == "naver_popular" for row in stored_rows)
|
|
|
|
|
|
def test_classify_keyword_with_cache(monkeypatch):
    """classify_keyword consults the LLM classifier only once per keyword.

    ``_llm_classify_one`` is replaced by a counting stub that always answers
    "economy".  Classifying the same keyword twice must return "economy" both
    times while invoking the stub exactly once.

    The module-level ``_category_cache`` is emptied before the test so a
    previously cached answer cannot mask the first call, and emptied again
    afterwards so the stub's answer does not leak into other tests.
    """
    calls = {"n": 0}

    def fake_claude(keyword: str) -> str:
        calls["n"] += 1
        return "economy"

    monkeypatch.setattr(trend_collector, "_llm_classify_one", fake_claude)
    trend_collector._category_cache.clear()

    try:
        first = trend_collector.classify_keyword("기준금리")
        second = trend_collector.classify_keyword("기준금리")

        assert first == second == "economy"
        assert calls["n"] == 1
    finally:
        # Drop the entry produced by the stub even if an assertion failed.
        trend_collector._category_cache.clear()
|
|
|
|
|
|
def test_fetch_google_trends_parses_and_classifies(tmp_db, monkeypatch):
    """fetch_google_trends labels each trending keyword via classify_keyword.

    ``TrendReq`` is swapped for a stub whose ``trending_searches`` returns a
    one-column DataFrame of three keywords, and ``classify_keyword`` is
    swapped for a fixed lookup.  Every returned trend must carry the category
    from that lookup and the ``google_trends`` source.
    """
    category_by_keyword = {
        "기준금리": "economy",
        "BTS 컴백": "celebrity",
        "스트레스 관리": "psychology",
    }

    class StubTrendReq:
        def __init__(self, *_args, **_kwargs):
            pass

        def trending_searches(self, pn="south_korea"):
            import pandas as pd

            return pd.DataFrame({"0": ["기준금리", "BTS 컴백", "스트레스 관리"]})

    def stub_classify(kw):
        return category_by_keyword.get(kw, "uncategorized")

    monkeypatch.setattr(trend_collector, "TrendReq", StubTrendReq)
    monkeypatch.setattr(trend_collector, "classify_keyword", stub_classify)

    trends = trend_collector.fetch_google_trends()
    trend_by_keyword = {trend["keyword"]: trend for trend in trends}

    assert trend_by_keyword["기준금리"]["category"] == "economy"
    assert trend_by_keyword["BTS 컴백"]["category"] == "celebrity"
    assert trend_by_keyword["스트레스 관리"]["category"] == "psychology"
    assert all(trend["source"] == "google_trends" for trend in trends)
|
|
|
|
|
|
def test_collect_all_invokes_both_sources(tmp_db, monkeypatch):
    """collect_all reports the counts returned by both per-source collectors.

    Both collectors are stubbed with fixed return values (5 and 3); the
    result must be a dict mapping each source name to its stubbed count.
    """

    def stub_naver(cats):
        return 5

    def stub_google():
        return 3

    monkeypatch.setattr(trend_collector, "collect_naver_popular_for", stub_naver)
    monkeypatch.setattr(trend_collector, "collect_google_trends", stub_google)

    summary = trend_collector.collect_all(["economy"])

    assert summary == {"naver_popular": 5, "google_trends": 3}
|
|
|
|
|
|
def test_fetch_google_trends_graceful_on_pytrends_failure(monkeypatch):
    """fetch_google_trends returns an empty list when pytrends blows up.

    ``TrendReq`` is replaced by a stub whose ``trending_searches`` raises
    ``RuntimeError``; the collector must return ``[]``.
    """

    class FailingTrendReq:
        def __init__(self, *_args, **_kwargs):
            pass

        def trending_searches(self, pn="south_korea"):
            raise RuntimeError("rate limited")

    monkeypatch.setattr(trend_collector, "TrendReq", FailingTrendReq)

    result = trend_collector.fetch_google_trends()

    assert result == []
|