feat(insta-lab): keyword_extractor.extract_with_weights for category proportions
This commit is contained in:
@@ -81,3 +81,22 @@ def extract_for_category(category: str, limit: int = KEYWORDS_PER_CATEGORY) -> L
|
|||||||
})
|
})
|
||||||
saved.append({"id": kid, **kw, "category": category})
|
saved.append({"id": kid, **kw, "category": category})
|
||||||
return saved
|
return saved
|
||||||
|
|
||||||
|
|
||||||
|
def extract_with_weights(weights: Dict[str, float], total_limit: int) -> List[Dict[str, Any]]:
|
||||||
|
"""카테고리 가중치 비율대로 키워드를 분배 추출."""
|
||||||
|
from .config import DEFAULT_CATEGORY_SEEDS
|
||||||
|
if not weights or sum(weights.values()) == 0:
|
||||||
|
cats = list(DEFAULT_CATEGORY_SEEDS.keys())
|
||||||
|
weights = {c: 1.0 for c in cats}
|
||||||
|
|
||||||
|
total_weight = sum(weights.values())
|
||||||
|
out: List[Dict[str, Any]] = []
|
||||||
|
for category, w in weights.items():
|
||||||
|
if w <= 0:
|
||||||
|
continue
|
||||||
|
per_cat = round(total_limit * w / total_weight)
|
||||||
|
if per_cat <= 0:
|
||||||
|
continue
|
||||||
|
out.extend(extract_for_category(category, limit=per_cat))
|
||||||
|
return out
|
||||||
|
|||||||
71
insta-lab/tests/test_extract_with_weights.py
Normal file
71
insta-lab/tests/test_extract_with_weights.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app import db as db_module
|
||||||
|
from app import keyword_extractor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_db(monkeypatch):
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
monkeypatch.setattr(db_module, "DB_PATH", path)
|
||||||
|
db_module.init_db()
|
||||||
|
yield path
|
||||||
|
gc.collect()
|
||||||
|
for ext in ("", "-wal", "-shm"):
|
||||||
|
try:
|
||||||
|
os.remove(path + ext)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_with_weights_proportional(tmp_db, monkeypatch):
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_extract(category, limit):
|
||||||
|
calls.append((category, limit))
|
||||||
|
return [{"id": i, "keyword": f"{category}{i}", "category": category, "score": 0.5}
|
||||||
|
for i in range(limit)]
|
||||||
|
|
||||||
|
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
|
||||||
|
out = keyword_extractor.extract_with_weights(
|
||||||
|
{"economy": 0.6, "psychology": 0.3, "celebrity": 0.1}, total_limit=10,
|
||||||
|
)
|
||||||
|
by_cat = {c: l for c, l in calls}
|
||||||
|
assert by_cat == {"economy": 6, "psychology": 3, "celebrity": 1}
|
||||||
|
assert len(out) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_with_weights_skips_zero(tmp_db, monkeypatch):
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_extract(category, limit):
|
||||||
|
calls.append((category, limit))
|
||||||
|
return []
|
||||||
|
|
||||||
|
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
|
||||||
|
keyword_extractor.extract_with_weights(
|
||||||
|
{"economy": 1.0, "celebrity": 0.0}, total_limit=10,
|
||||||
|
)
|
||||||
|
cats_called = [c for c, _ in calls]
|
||||||
|
assert "celebrity" not in cats_called
|
||||||
|
assert "economy" in cats_called
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_with_weights_fallback_to_equal(tmp_db, monkeypatch):
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_extract(category, limit):
|
||||||
|
calls.append((category, limit))
|
||||||
|
return []
|
||||||
|
|
||||||
|
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
|
||||||
|
keyword_extractor.extract_with_weights({}, total_limit=9)
|
||||||
|
by_cat = {c: l for c, l in calls}
|
||||||
|
assert set(by_cat.keys()) == {"economy", "psychology", "celebrity"}
|
||||||
|
assert all(l == 3 for l in by_cat.values())
|
||||||
Reference in New Issue
Block a user