"""Tests for ``keyword_extractor.extract_with_weights`` limit allocation."""

import contextlib
import gc
import os
import tempfile
from unittest.mock import patch  # noqa: F401  # NOTE(review): unused in this module — kept; confirm before removing

import pytest

from app import db as db_module
from app import keyword_extractor


@pytest.fixture
def tmp_db(monkeypatch):
    """Point ``db_module.DB_PATH`` at a fresh temporary file and initialise it.

    Yields:
        The temporary database path.

    On teardown the file and its ``-wal`` / ``-shm`` sidecars are removed on a
    best-effort basis.
    """
    fd, path = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    monkeypatch.setattr(db_module, "DB_PATH", path)
    db_module.init_db()
    yield path
    # NOTE(review): presumably forces any lingering connection objects to be
    # finalised so the files can be deleted (matters on Windows) — confirm.
    gc.collect()
    for ext in ("", "-wal", "-shm"):
        # Sidecar files may not exist; cleanup is deliberately best-effort.
        with contextlib.suppress(OSError):
            os.remove(path + ext)


def _no_rows(category, limit):
    """Fake result producer that returns nothing."""
    return []


def _fake_rows(category, limit):
    """Fake result producer returning ``limit`` rows tagged with ``category``."""
    return [
        {"id": i, "keyword": f"{category}{i}", "category": category, "score": 0.5}
        for i in range(limit)
    ]


def _record_extract_calls(monkeypatch, make_rows=_no_rows):
    """Replace ``keyword_extractor.extract_for_category`` with a recording fake.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture used to install the fake.
        make_rows: Called as ``make_rows(category, limit)`` to build the fake's
            return value.

    Returns:
        A list that receives one ``(category, limit)`` tuple per call.
    """
    calls = []

    def fake_extract(category, limit):
        calls.append((category, limit))
        return make_rows(category, limit)

    monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
    return calls


def test_extract_with_weights_proportional(tmp_db, monkeypatch):
    calls = _record_extract_calls(monkeypatch, _fake_rows)

    out = keyword_extractor.extract_with_weights(
        {"economy": 0.6, "psychology": 0.3, "celebrity": 0.1},
        total_limit=10,
    )

    # Same as {c: limit for c, limit in calls}: the last call per category wins.
    by_cat = dict(calls)
    assert by_cat == {"economy": 6, "psychology": 3, "celebrity": 1}
    assert len(out) == 10


def test_extract_with_weights_skips_zero(tmp_db, monkeypatch):
    calls = _record_extract_calls(monkeypatch)

    keyword_extractor.extract_with_weights(
        {"economy": 1.0, "celebrity": 0.0},
        total_limit=10,
    )

    cats_called = [category for category, _ in calls]
    assert "celebrity" not in cats_called
    assert "economy" in cats_called


def test_extract_with_weights_fallback_to_equal(tmp_db, monkeypatch):
    calls = _record_extract_calls(monkeypatch)

    # An empty weight map is expected to fall back to an equal split.
    keyword_extractor.extract_with_weights({}, total_limit=9)

    by_cat = dict(calls)
    assert set(by_cat.keys()) == {"economy", "psychology", "celebrity"}
    assert all(limit == 3 for limit in by_cat.values())