72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
import os
|
|
import gc
|
|
import tempfile
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from app import db as db_module
|
|
from app import keyword_extractor
|
|
|
|
|
|
@pytest.fixture
def tmp_db(monkeypatch):
    """Yield the path of a freshly initialised database in a temp file.

    Creates an empty ``.db`` file with ``tempfile.mkstemp``, points
    ``db_module.DB_PATH`` at it, runs ``db_module.init_db()`` and yields
    the path to the test. The file and its ``-wal`` / ``-shm`` siblings
    are removed on teardown, and also when setup itself fails, so a
    raising ``init_db()`` no longer leaves a stray temp file behind.
    """
    fd, path = tempfile.mkstemp(suffix=".db")
    # mkstemp hands back an open descriptor; only the path is used here.
    os.close(fd)
    try:
        monkeypatch.setattr(db_module, "DB_PATH", path)
        db_module.init_db()
        yield path
    finally:
        # NOTE(review): the collection pass presumably releases lingering
        # connection objects that still hold the file open -- confirm.
        gc.collect()
        # Presumably the SQLite main file plus its WAL / shared-memory
        # sidecar files; whichever of them exist are deleted.
        for ext in ("", "-wal", "-shm"):
            try:
                os.remove(path + ext)
            except OSError:
                # Best-effort cleanup: a missing or still-locked file must
                # not turn a passing test into a teardown error.
                pass
|
|
|
|
|
|
def test_extract_with_weights_proportional(tmp_db, monkeypatch):
    """Weights 0.6 / 0.3 / 0.1 over a total of 10 request 6 / 3 / 1 items."""
    recorded = []

    def fake_extract(category, limit):
        # Stand-in for extract_for_category: log the request, then hand
        # back exactly `limit` synthetic keyword rows for that category.
        recorded.append((category, limit))
        rows = []
        for i in range(limit):
            rows.append(
                {"id": i, "keyword": f"{category}{i}", "category": category, "score": 0.5}
            )
        return rows

    monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)

    weights = {"economy": 0.6, "psychology": 0.3, "celebrity": 0.1}
    out = keyword_extractor.extract_with_weights(weights, total_limit=10)

    # Map each requested category to the limit it was asked for.
    by_cat = dict(recorded)
    assert by_cat == {"economy": 6, "psychology": 3, "celebrity": 1}
    assert len(out) == 10
|
|
|
|
|
|
def test_extract_with_weights_skips_zero(tmp_db, monkeypatch):
    """A category weighted 0.0 is never handed to extract_for_category."""
    recorded = []

    def fake_extract(category, limit):
        # Only the fact of the call matters here, so return no rows.
        recorded.append((category, limit))
        return []

    monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)

    weights = {"economy": 1.0, "celebrity": 0.0}
    keyword_extractor.extract_with_weights(weights, total_limit=10)

    requested = {name for name, _ in recorded}
    assert "celebrity" not in requested
    assert "economy" in requested
|
|
|
|
|
|
def test_extract_with_weights_fallback_to_equal(tmp_db, monkeypatch):
    """Empty weights split a total of 9 evenly: 3 for each of 3 categories."""
    recorded = []

    def fake_extract(category, limit):
        # Only the requested (category, limit) pairs are inspected.
        recorded.append((category, limit))
        return []

    monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)

    keyword_extractor.extract_with_weights({}, total_limit=9)

    # Map each requested category to the limit it was asked for.
    by_cat = dict(recorded)
    assert set(by_cat) == {"economy", "psychology", "celebrity"}
    for requested_limit in by_cat.values():
        assert requested_limit == 3
|