From 6bb5c2fb40c436fe34e781470ba44c86d1291cd0 Mon Sep 17 00:00:00 2001
From: gahusb
Date: Sat, 16 May 2026 17:51:16 +0900
Subject: [PATCH] feat(insta-lab): keyword_extractor.extract_with_weights for category proportions

---
 insta-lab/app/keyword_extractor.py           | 44 ++++++++++
 insta-lab/tests/test_extract_with_weights.py | 91 ++++++++++++++++++++
 2 files changed, 135 insertions(+)
 create mode 100644 insta-lab/tests/test_extract_with_weights.py

diff --git a/insta-lab/app/keyword_extractor.py b/insta-lab/app/keyword_extractor.py
index 2c307e1..31c10bd 100644
--- a/insta-lab/app/keyword_extractor.py
+++ b/insta-lab/app/keyword_extractor.py
@@ -81,3 +81,47 @@ def extract_for_category(category: str, limit: int = KEYWORDS_PER_CATEGORY) -> L
         })
         saved.append({"id": kid, **kw, "category": category})
     return saved
+
+
+def extract_with_weights(weights: Dict[str, float], total_limit: int) -> List[Dict[str, Any]]:
+    """Extract keywords across categories in proportion to their weights.
+
+    ``total_limit`` is split with the largest-remainder method: every category
+    first receives the floor of its exact share, then the slots still
+    unassigned go one each to the categories with the largest fractional
+    parts. The per-category limits therefore always sum to ``total_limit``,
+    which independent rounding of each share does not guarantee.
+
+    Args:
+        weights: Relative weight per category name. Entries with a weight of
+            zero or less are ignored. If no entry has a positive weight, every
+            category in ``DEFAULT_CATEGORY_SEEDS`` is weighted equally.
+        total_limit: Total number of keywords to request across all categories.
+
+    Returns:
+        The concatenated ``extract_for_category`` results, in category
+        iteration order. Empty when ``total_limit`` is not positive.
+    """
+    positive = {c: w for c, w in (weights or {}).items() if w > 0}
+    if not positive:
+        # Only this fallback needs the defaults, so the import stays local to it.
+        from .config import DEFAULT_CATEGORY_SEEDS
+        positive = {c: 1.0 for c in DEFAULT_CATEGORY_SEEDS}
+    if total_limit <= 0 or not positive:
+        return []
+
+    total_weight = sum(positive.values())
+    quotas = {c: total_limit * w / total_weight for c, w in positive.items()}
+    limits = {c: int(q) for c, q in quotas.items()}
+    # Hand out the slots lost to flooring; sorted() is stable, so ties keep
+    # the category iteration order.
+    leftover = total_limit - sum(limits.values())
+    by_remainder = sorted(positive, key=lambda c: quotas[c] - limits[c], reverse=True)
+    for category in by_remainder[:leftover]:
+        limits[category] += 1
+
+    out: List[Dict[str, Any]] = []
+    for category, per_cat in limits.items():
+        if per_cat > 0:
+            out.extend(extract_for_category(category, limit=per_cat))
+    return out
diff --git a/insta-lab/tests/test_extract_with_weights.py b/insta-lab/tests/test_extract_with_weights.py
new file mode 100644
index 0000000..534d6a5
--- /dev/null
+++ b/insta-lab/tests/test_extract_with_weights.py
@@ -0,0 +1,91 @@
+import os
+import gc
+import tempfile
+from unittest.mock import patch
+
+import pytest
+
+from app import db as db_module
+from app import keyword_extractor
+
+
+@pytest.fixture
+def tmp_db(monkeypatch):
+    """Point ``db_module.DB_PATH`` at a fresh temp file and delete it afterwards."""
+    fd, path = tempfile.mkstemp(suffix=".db")
+    os.close(fd)
+    monkeypatch.setattr(db_module, "DB_PATH", path)
+    db_module.init_db()
+    yield path
+    # NOTE(review): presumably drops lingering DB handles before removal - confirm.
+    gc.collect()
+    for ext in ("", "-wal", "-shm"):
+        try:
+            os.remove(path + ext)
+        except OSError:
+            pass
+
+
+def test_extract_with_weights_proportional(tmp_db, monkeypatch):
+    """Limits follow the weight ratios when the shares are whole numbers."""
+    calls = []
+
+    def fake_extract(category, limit):
+        calls.append((category, limit))
+        return [{"id": i, "keyword": f"{category}{i}", "category": category, "score": 0.5}
+                for i in range(limit)]
+
+    monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
+    out = keyword_extractor.extract_with_weights(
+        {"economy": 0.6, "psychology": 0.3, "celebrity": 0.1}, total_limit=10,
+    )
+    by_cat = {c: l for c, l in calls}
+    assert by_cat == {"economy": 6, "psychology": 3, "celebrity": 1}
+    assert len(out) == 10
+
+
+def test_extract_with_weights_skips_zero(tmp_db, monkeypatch):
+    """A zero-weight category is never extracted."""
+    calls = []
+
+    def fake_extract(category, limit):
+        calls.append((category, limit))
+        return []
+
+    monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
+    keyword_extractor.extract_with_weights(
+        {"economy": 1.0, "celebrity": 0.0}, total_limit=10,
+    )
+    cats_called = [c for c, _ in calls]
+    assert "celebrity" not in cats_called
+    assert "economy" in cats_called
+
+
+def test_extract_with_weights_fallback_to_equal(tmp_db, monkeypatch):
+    """Empty weights fall back to equal shares over the default categories."""
+    calls = []
+
+    def fake_extract(category, limit):
+        calls.append((category, limit))
+        return []
+
+    monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
+    keyword_extractor.extract_with_weights({}, total_limit=9)
+    by_cat = {c: l for c, l in calls}
+    assert set(by_cat.keys()) == {"economy", "psychology", "celebrity"}
+    assert all(l == 3 for l in by_cat.values())
+
+
+def test_extract_with_weights_total_matches_limit(tmp_db, monkeypatch):
+    """Shares that do not divide evenly still add up to ``total_limit``."""
+    calls = []
+
+    def fake_extract(category, limit):
+        calls.append((category, limit))
+        return []
+
+    monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
+    keyword_extractor.extract_with_weights(
+        {"economy": 1.0, "psychology": 1.0, "celebrity": 1.0}, total_limit=10,
+    )
+    assert sum(limit for _, limit in calls) == 10