Files
web-page-backend/lotto/app/weight_evolver.py

124 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# lotto/app/weight_evolver.py
"""5종 시뮬 점수 가중치 자율 학습 루프.
순수 함수 (clamp/perturb/Dirichlet/score/base-rule) + DB 진입점은 별도 섹션.
"""
from __future__ import annotations
import math
import random
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
# Default floor applied to every weight by clamp_and_normalize.
MIN_WEIGHT = 0.05

# Length of every weight vector handled in this module.
N_METRICS = 5

# Equal-share weights used when no base vector exists yet (cold start).
DEFAULT_UNIFORM = [0.2 for _ in range(N_METRICS)]

# Number of matched main numbers -> rank. No entry maps to rank 2.
RANK_BY_CORRECT = {
    6: 1,
    5: 3,
    4: 4,
    3: 5,
}

# Rank -> bonus that calc_pick_score adds on top of correct/6.
RANK_BONUS = {
    1: 1.0,
    2: 0.8,
    3: 0.6,
    4: 0.3,
    5: 0.1,
}
def clamp_and_normalize(W: List[float], min_w: float = MIN_WEIGHT) -> List[float]:
    """Floor every weight at ``min_w`` and rescale so the weights sum to 1.0.

    Args:
        W: Raw weights; must contain exactly ``N_METRICS`` values.
        min_w: Smallest value any returned weight should take.

    Returns:
        A new list of ``N_METRICS`` floats that sums to 1.0 (up to float
        rounding). On convergence every value is >= ``min_w - 1e-12``. The
        iteration is capped at 100 passes; if the cap is hit, the last
        iterate is returned as-is.

    Raises:
        ValueError: If ``W`` does not have ``N_METRICS`` elements, or if
            ``min_w`` is so large that ``N_METRICS`` floored weights cannot
            sum to 1.0.
    """
    if len(W) != N_METRICS:
        raise ValueError(f"W must have {N_METRICS} elements")
    # N_METRICS weights that are each >= min_w sum to at least
    # N_METRICS * min_w. If that exceeds 1.0 no valid answer exists, and the
    # loop below would otherwise silently return weights under the floor.
    if min_w * N_METRICS > 1.0 + 1e-9:
        raise ValueError(
            f"min_w={min_w} is infeasible: {N_METRICS} weights cannot each "
            f"be >= min_w and still sum to 1.0"
        )
    # Clamp, then normalize, and repeat: dividing by a total above 1.0 pushes
    # the just-floored values back under min_w, so a single pass is not
    # enough. Each pass shrinks that shortfall.
    vals = [float(w) for w in W]
    for _ in range(100):
        clamped = [max(min_w, v) for v in vals]
        total = sum(clamped)
        vals = [v / total for v in clamped]
        if all(v >= min_w - 1e-12 for v in vals):
            break
    return vals
def perturb_weights(
    base: List[float],
    sigma: float = 0.05,
    seed: Optional[int] = None,
) -> List[float]:
    """Add Gaussian noise to ``base`` and return clamped, normalized weights.

    Args:
        base: Starting weights; must contain exactly ``N_METRICS`` values.
        sigma: Standard deviation of the zero-mean normal noise added to
            each weight.
        seed: When given, the noise comes from a private generator seeded
            with this value, so the call is reproducible and NumPy's global
            random state is left untouched. When ``None``, the noise is
            drawn from NumPy's global random state.

    Returns:
        The result of ``clamp_and_normalize`` applied to ``base + noise``.

    Raises:
        ValueError: If ``base`` does not have ``N_METRICS`` elements.
    """
    # Without this check, zip() below would silently drop the extra entries
    # of an over-long base and the result would still pass the length check
    # in clamp_and_normalize.
    if len(base) != N_METRICS:
        raise ValueError(f"base must have {N_METRICS} elements")
    # RandomState(seed) yields the same stream as np.random.seed(seed)
    # followed by np.random.normal, but does not reseed the process-wide
    # generator that other code may be drawing from.
    rng = np.random if seed is None else np.random.RandomState(seed)
    noise = rng.normal(0, sigma, size=N_METRICS)
    perturbed = [b + n for b, n in zip(base, noise)]
    return clamp_and_normalize(perturbed)
def dirichlet_weights(
    alpha: float = 2.0,
    seed: Optional[int] = None,
) -> List[float]:
    """Sample weights from a symmetric Dirichlet, then clamp and normalize.

    Args:
        alpha: Concentration value used for each of the ``N_METRICS``
            components.
        seed: When given, the sample comes from a private generator seeded
            with this value, so the call is reproducible and NumPy's global
            random state is left untouched. When ``None``, the sample is
            drawn from NumPy's global random state.

    Returns:
        The result of ``clamp_and_normalize`` applied to the sample.
    """
    # RandomState(seed) yields the same stream as np.random.seed(seed)
    # followed by np.random.dirichlet, but does not reseed the process-wide
    # generator that other code may be drawing from.
    rng = np.random if seed is None else np.random.RandomState(seed)
    sample = rng.dirichlet([alpha] * N_METRICS).tolist()
    return clamp_and_normalize(sample)
def generate_weekly_candidates(
    base: Optional[List[float]] = None,
    seed: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Build six candidate weight vectors, one per ``day_of_week`` 0..5.

    Days 0-3 are perturbations of ``base`` (sigma 0.05); days 4-5 are
    Dirichlet samples (alpha 2.0).

    Args:
        base: Weights to perturb; a copy of ``DEFAULT_UNIFORM`` is used when
            ``None``.
        seed: When given, NumPy's global random state is reseeded with it
            before any candidate is drawn.

    Returns:
        [{"day_of_week": 0, "weight": [...], "source": "perturb"}, ...]
    """
    start = DEFAULT_UNIFORM[:] if base is None else base
    if seed is not None:
        # The helpers below are called without a seed and therefore draw
        # from the global state seeded here.
        np.random.seed(seed)

    # Perturbations are drawn first, then the Dirichlet samples, so the
    # sequence of random draws follows day_of_week order.
    candidates = [
        {
            "day_of_week": day,
            "weight": perturb_weights(start, sigma=0.05),
            "source": "perturb",
        }
        for day in range(4)
    ]
    candidates.extend(
        {
            "day_of_week": day,
            "weight": dirichlet_weights(alpha=2.0),
            "source": "dirichlet",
        }
        for day in range(4, 6)
    )
    return candidates
def count_match(pick: List[int], winning: List[int]) -> int:
    """Count distinct picked numbers found among the first six winning numbers.

    Anything in ``winning`` past index 5 (the bonus number) is ignored.
    """
    main_numbers = winning[:6]
    hits = set(pick).intersection(main_numbers)
    return len(hits)
def calc_pick_score(pick_numbers: List[int], winning_numbers: List[int]) -> float:
    """Score a pick as matched/6 plus the bonus for the rank it reaches.

    The bonus number is not taken into account: the match count comes from
    ``count_match``, and a count with no entry in ``RANK_BY_CORRECT`` earns
    no bonus.
    """
    hits = count_match(pick_numbers, winning_numbers)
    score = hits / 6.0
    rank = RANK_BY_CORRECT.get(hits)
    if rank:
        score += RANK_BONUS.get(rank, 0)
    return score
def decide_base_update(
    winner_max_correct: int,
    winner_W: List[float],
    current_base: Optional[List[float]],
) -> Tuple[List[float], str]:
    """Choose the next base weights from the winning candidate's result.

    Rules, checked in this order:
        * 4 or more matches: adopt a copy of ``winner_W`` ('winner_4plus').
        * no current base: a copy of ``DEFAULT_UNIFORM`` ('cold_start').
        * exactly 3 matches: 0.3 * winner + 0.7 * current base, then
          clamped and normalized ('ema_blend').
        * anything else: a copy of the current base ('unchanged').

    Returns:
        (new_base, reason) — reason ∈ {'winner_4plus','ema_blend','unchanged','cold_start'}
    """
    if winner_max_correct >= 4:
        return list(winner_W), "winner_4plus"

    if current_base is None:
        return DEFAULT_UNIFORM[:], "cold_start"

    if winner_max_correct != 3:
        return list(current_base), "unchanged"

    winner_share, base_share = 0.3, 0.7
    mixed = [
        winner_share * w + base_share * c
        for w, c in zip(winner_W, current_base)
    ]
    return clamp_and_normalize(mixed), "ema_blend"