Files
web-page-backend/lotto/tests/test_weight_evolver.py

199 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# lotto/tests/test_weight_evolver.py
import json
import math
import pytest
from app import weight_evolver as we
def test_clamp_and_normalize_min_floor():
    """Every weight ends up at or above 0.05 and the result sums to 1.0."""
    floor = 0.05 - 1e-9  # small slack for floating-point rounding
    weights = we.clamp_and_normalize([0.01, 0.6, 0.2, 0.1, 0.09])
    for weight in weights:
        assert weight >= floor
    total = sum(weights)
    assert abs(total - 1.0) < 1e-9
def test_clamp_and_normalize_negative_becomes_floor():
    """A negative input weight comes back at or above the 0.05 floor; sum is 1.0."""
    normalized = we.clamp_and_normalize([-0.1, 0.5, 0.3, 0.2, 0.1])
    first = normalized[0]  # the entry that was negative on input
    assert first >= 0.05 - 1e-9
    total = sum(normalized)
    assert abs(total - 1.0) < 1e-9
def test_perturbation_changes_around_base():
    """Perturbing with sigma=0.05 and normalizing keeps every value in a sane range."""
    uniform_base = [0.2, 0.2, 0.2, 0.2, 0.2]
    perturbed = we.perturb_weights(uniform_base, sigma=0.05, seed=42)
    total = sum(perturbed)
    assert abs(total - 1.0) < 1e-9
    floor = 0.05 - 1e-9  # small slack for floating-point rounding
    for weight in perturbed:
        assert weight >= floor
def test_dirichlet_random_distribution():
    """Dirichlet draw with alpha=2: weights sum to 1 and each lies in [0.05, 1.0]."""
    sampled = we.dirichlet_weights(alpha=2.0, seed=42)
    total = sum(sampled)
    assert abs(total - 1.0) < 1e-9
    lower_bound = 0.05 - 1e-9  # small slack for floating-point rounding
    for weight in sampled:
        assert lower_bound <= weight <= 1.0
def test_generate_weekly_candidates_count():
    """Six candidates are generated: 4 tagged "perturb" and 2 tagged "dirichlet"."""
    uniform_base = [0.2] * 5
    candidates = we.generate_weekly_candidates(uniform_base, seed=42)
    assert len(candidates) == 6

    origin_tags = [candidate["source"] for candidate in candidates]
    for tag, expected_count in (("perturb", 4), ("dirichlet", 2)):
        assert origin_tags.count(tag) == expected_count

    # Each day_of_week value 0..5 must be used exactly once.
    weekdays = [candidate["day_of_week"] for candidate in candidates]
    weekdays.sort()
    assert weekdays == list(range(6))
def test_calc_pick_score_six_match():
    """All six numbers match → 1st prize → base 1.0 + bonus 1.0 = 2.0."""
    numbers = [1, 2, 3, 4, 5, 6]
    # Separate list objects are passed, as in a real call with two sources.
    score = we.calc_pick_score(list(numbers), list(numbers))
    expected = 2.0
    assert score == pytest.approx(expected)
def test_calc_pick_score_four_match():
    """Four numbers match → 4th prize → base 4/6 + bonus 0.3."""
    expected = 4 / 6 + 0.3
    score = we.calc_pick_score(
        [1, 2, 3, 4, 7, 8],
        [1, 2, 3, 4, 5, 6],
    )
    assert score == pytest.approx(expected)
def test_calc_pick_score_three_match():
    """Three numbers match → 5th prize → base 3/6 + bonus 0.1."""
    expected = 3 / 6 + 0.1
    score = we.calc_pick_score(
        [1, 2, 3, 7, 8, 9],
        [1, 2, 3, 4, 5, 6],
    )
    assert score == pytest.approx(expected)
def test_calc_pick_score_two_match_no_bonus():
    """Two numbers match → no prize → base 2/6 with no bonus."""
    expected = 2 / 6
    score = we.calc_pick_score(
        [1, 2, 7, 8, 9, 10],
        [1, 2, 3, 4, 5, 6],
    )
    assert score == pytest.approx(expected)
def test_decide_base_update_winner_4plus_replaces():
    """winner_max_correct >= 4 → the winner's weights replace the base."""
    base_before = [0.2, 0.2, 0.2, 0.2, 0.2]
    champion = [0.1, 0.3, 0.2, 0.3, 0.1]
    outcome = we.decide_base_update(
        winner_max_correct=4,
        winner_W=champion,
        current_base=base_before,
    )
    new_base, reason = outcome
    assert new_base == champion
    assert reason == "winner_4plus"
def test_decide_base_update_winner_3_ema_blend():
    """winner_max_correct == 3 → new base is 0.3 * winner + 0.7 * current.

    The length is asserted explicitly because ``zip`` stops at the shorter
    sequence: without it, a truncated or empty ``new_base`` would pass the
    element-wise comparison vacuously.
    """
    current = [0.2, 0.2, 0.2, 0.2, 0.2]
    winner_W = [0.1, 0.3, 0.2, 0.3, 0.1]
    new_base, reason = we.decide_base_update(
        winner_max_correct=3,
        winner_W=winner_W,
        current_base=current,
    )
    expected = [0.3 * w + 0.7 * c for w, c in zip(winner_W, current)]
    # Guard against zip truncation hiding a wrong-sized result.
    assert len(new_base) == len(expected)
    assert all(abs(a - b) < 1e-9 for a, b in zip(new_base, expected))
    assert reason == "ema_blend"
def test_decide_base_update_winner_lt3_unchanged():
    """winner_max_correct <= 2 → the previous base is kept as-is."""
    base_before = [0.2, 0.2, 0.2, 0.2, 0.2]
    challenger = [0.1, 0.3, 0.2, 0.3, 0.1]
    outcome = we.decide_base_update(
        winner_max_correct=2,
        winner_W=challenger,
        current_base=base_before,
    )
    new_base, reason = outcome
    assert new_base == base_before
    assert reason == "unchanged"
def test_decide_base_update_cold_start_returns_default():
    """Cold start (current_base=None) with a 4-match winner adopts the winner's weights.

    NOTE(review): the test name and the original docstring describe a uniform
    default being returned on the first round, but the assertions below pin
    winner replacement with reason "winner_4plus" — confirm which is intended.
    """
    champion = [0.1, 0.3, 0.2, 0.3, 0.1]
    outcome = we.decide_base_update(
        winner_max_correct=4,
        winner_W=champion,
        current_base=None,
    )
    new_base, reason = outcome
    assert new_base == champion
    assert reason == "winner_4plus"
def test_select_winner_by_lift_gating():
    """The result is gated when the best lift over the random baseline is below epsilon."""
    # Three engine weight trials scored against a random-null baseline;
    # lift = engine prize score - random prize score.
    engine_trials = [
        {
            "trial_id": 1,
            "day_of_week": 0,
            "weight": [0.2] * 5,
            "prize_score": 5.0,
        },
        {
            "trial_id": 2,
            "day_of_week": 1,
            "weight": [0.3, 0.2, 0.2, 0.2, 0.1],
            "prize_score": 9.0,
        },
        {
            "trial_id": 3,
            "day_of_week": 2,
            "weight": [0.1, 0.3, 0.2, 0.2, 0.2],
            "prize_score": 4.0,
        },
    ]

    # Random baseline 8.0 → lifts are -3, +1, -4 → best lift (+1) < epsilon (2) → gated.
    gated_result = we.select_winner_by_lift(
        engine_trials, random_score=8.0, epsilon=2.0
    )
    assert gated_result["gated"] is True

    # Random baseline 3.0 → trial 2 (prize score 9) has lift +6, which clears epsilon.
    open_result = we.select_winner_by_lift(
        engine_trials, random_score=3.0, epsilon=2.0
    )
    assert open_result["gated"] is False
    assert open_result["trial_id"] == 2
def test_prize_score_from_hist():
    """A single m6 hit must outscore ten m3 hits plus two m4 hits."""
    # Per the original note: tier weights make 1st prize very large and
    # lower tiers small, so one 1st-prize ticket beats many 5th-prize ones.
    low_tier_hist = {"m3": 10, "m4": 2, "m5": 0, "m6": 0, "bonus_hits": 0}
    jackpot_hist = {"m3": 0, "m4": 0, "m5": 0, "m6": 1, "bonus_hits": 0}
    low_tier_score = we.prize_score_from_hist(low_tier_hist)
    jackpot_score = we.prize_score_from_hist(jackpot_hist)
    assert jackpot_score > low_tier_score
def test_select_winner_by_lift_preserves_all_keys():
    """select_winner_by_lift must keep every key of the chosen per-weight entry.

    Per the original note, evaluate_weekly breaks if identity fields such as
    best_match or weight_label go missing from the result.
    """
    weaker_trial = {
        "trial_id": 10,
        "weight_label": "w0",
        "weight": [0.2] * 5,
        "prize_score": 3.0,
        "best_match": 3,
    }
    stronger_trial = {
        "trial_id": 11,
        "weight_label": "w1",
        "weight": [0.3, 0.2, 0.2, 0.2, 0.1],
        "prize_score": 20.0,
        "best_match": 4,
    }
    winner = we.select_winner_by_lift(
        [weaker_trial, stronger_trial], random_score=5.0, epsilon=2.0
    )

    assert winner["gated"] is False
    assert winner["trial_id"] == 11
    # Identity keys carried over from the input entry.
    assert winner["weight_label"] == "w1"
    assert winner["best_match"] == 4
    # A lift field is added: prize score 20.0 against random score 5.0 gives 15.0.
    assert "lift" in winner
    assert winner["lift"] == pytest.approx(15.0)
def test_gated_path_keeps_base_via_select_winner():
    """Check what select_winner_by_lift returns on the gated path.

    Per the original note, the gated branch inside evaluate_weekly relies on
    these values. evaluate_weekly itself is not called by this test.
    """
    engine_trials = [
        {
            "trial_id": 1,
            "weight_label": "w0",
            "weight": [0.2] * 5,
            "prize_score": 5.0,
            "best_match": 2,
        },
        {
            "trial_id": 2,
            "weight_label": "w1",
            "weight": [0.3, 0.2, 0.2, 0.2, 0.1],
            "prize_score": 7.0,
            "best_match": 3,
        },
    ]
    # Random baseline 8.0 → best engine lift is 7 - 8 = -1 → gated.
    outcome = we.select_winner_by_lift(
        engine_trials, random_score=8.0, epsilon=we.LIFT_EPSILON
    )
    assert outcome["gated"] is True
    assert outcome["lift"] < 0

    # When gated, evaluate_weekly is expected to keep current_base unchanged.
    # The override logic is reproduced locally here rather than exercised
    # through evaluate_weekly.
    # The original note says LIFT_EPSILON is currently 10.0, so a negative
    # lift is always gated — TODO confirm against app.weight_evolver.
    # NOTE(review): since gated is already asserted True above, the checks
    # below cannot fail; consider calling evaluate_weekly directly instead.
    base_now = [0.2, 0.2, 0.2, 0.2, 0.2]
    is_gated = outcome["gated"]
    if is_gated:
        kept_base = list(base_now)
        kept_reason = "unchanged_gated"
    else:
        kept_base = None
        kept_reason = "should_not_reach"
    assert kept_base == base_now
    assert kept_reason == "unchanged_gated"