Files
ai-trade/ai_trade/tests/test_chronos_predictor.py
gahusb c2e77a7310 fix(ai_trade): Chronos confidence를 absolute spread 기반으로 통일 (F4)
코드 리뷰 F4: signal_generator의 hard gate(L79)는 absolute spread(0.6 threshold)를
쓰지만 chronos_predictor:106의 confidence는 relative spread (q90-q10)/max(|median|, 0.001).
zero-shot median≈0 케이스에서 spread가 폭증하여 conf=0으로 눌리고 결국 모든
매수 신호가 confidence_threshold(0.7)를 못 넘김.

산식 통일: conf = max(0, min(1, 1 - spread/_SPREAD_THRESHOLD)). _SPREAD_THRESHOLD=0.6
은 signal_generator hard gate와 동일.

- spread≈0 → conf≈1 (확신)
- spread=0.3 → conf=0.5 (중간)
- spread≥0.6 → conf=0 (거부)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 19:39:15 +09:00

144 lines
6.2 KiB
Python

"""Tests for ChronosPredictor (model mock)."""
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
@pytest.fixture
def mock_pipeline():
    """Patch ``chronos.BaseChronosPipeline`` and yield the fake pipeline instance.

    ``from_pretrained`` on the patched class hands back a single ``MagicMock``;
    individual tests configure ``predict_quantiles`` on it. Per the ChronosBolt
    API, ``predict_quantiles`` returns ``(quantiles_tensor, mean_tensor)``.
    Because ``hasattr(MagicMock(), "predict_quantiles")`` is True, the predictor's
    modern (``predict_quantiles``) branch is the one exercised.
    """
    fake_pipeline = MagicMock()
    with patch("chronos.BaseChronosPipeline") as fake_cls:
        # Give the patched class a realistic name in case the code under test
        # inspects it (NOTE(review): confirm whether that is still needed).
        fake_cls.__name__ = "BaseChronosPipeline"
        fake_cls.from_pretrained.return_value = fake_pipeline
        yield fake_pipeline
@pytest.fixture
def mock_torch_cpu():
    """Force the CPU path by making ``torch.cuda.is_available()`` return False."""
    cuda_disabled = patch("torch.cuda.is_available", return_value=False)
    with cuda_disabled:
        yield
def _daily_ohlcv(close_seq):
return [{"datetime": f"2026-05-{i+1:02d}", "open": c, "high": c, "low": c,
"close": c, "volume": 1000} for i, c in enumerate(close_seq)]
def _mk_quantiles_tensor(q10_price: float, q50_price: float, q90_price: float):
"""Helper: build predict_quantiles return tensor shape [1, 1, 3]."""
import torch
return torch.tensor([[[q10_price, q50_price, q90_price]]], dtype=torch.float32)
def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
    """predict_batch maps each ticker to a ChronosPrediction.

    Flat history at last_close=100 with quantile prices (101.5, 102.0, 102.5)
    must give a median return of about +2%.
    """
    from ai_trade.chronos_predictor import ChronosPrediction, ChronosPredictor

    mock_pipeline.predict_quantiles.return_value = (
        _mk_quantiles_tensor(101.5, 102.0, 102.5),  # tight band centred on 102
        None,
    )
    history = {"005930": _daily_ohlcv([100] * 60)}

    predictions = ChronosPredictor(model_name="mock-model").predict_batch(history)

    assert "005930" in predictions
    prediction = predictions["005930"]
    assert isinstance(prediction, ChronosPrediction)
    assert abs(prediction.median - 0.02) < 0.001
def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
    """A narrow distribution (small q90 - q10, non-zero median) gives conf ≈ 1."""
    # last_close=100, q10=101.99, q50=102.00, q90=102.01
    # returns: q10=0.0199, q50=0.02, q90=0.0201
    # F4 absolute spread = 0.0201 - 0.0199 = 0.0002
    # conf = 1 - spread / _SPREAD_THRESHOLD = 1 - 0.0002 / 0.6 ≈ 0.9997
    quantiles = _mk_quantiles_tensor(101.99, 102.0, 102.01)
    mock_pipeline.predict_quantiles.return_value = (quantiles, None)
    from ai_trade.chronos_predictor import ChronosPredictor
    predictor = ChronosPredictor(model_name="mock-model")
    daily = {"005930": _daily_ohlcv([100] * 60)}
    result = predictor.predict_batch(daily)
    assert result["005930"].conf > 0.99
def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
    """A distribution wider than the spread threshold gives conf clamped to 0.

    The exact-threshold case (spread = 0.6) is covered separately; this one goes
    beyond it so that the ``max(0, ...)`` clamp is exercised. Without the clamp,
    conf would come out negative.
    """
    # last_close=100, q10=60, q50=100, q90=140
    # returns: q10=-0.4, q50=0.0, q90=0.4
    # F4 absolute spread = 0.4 - (-0.4) = 0.8 > _SPREAD_THRESHOLD (0.6)
    # unclamped: 1 - 0.8 / 0.6 ≈ -0.33  ->  clamped conf = 0
    quantiles = _mk_quantiles_tensor(60.0, 100.0, 140.0)
    mock_pipeline.predict_quantiles.return_value = (quantiles, None)
    from ai_trade.chronos_predictor import ChronosPredictor
    predictor = ChronosPredictor(model_name="mock-model")
    daily = {"005930": _daily_ohlcv([100] * 60)}
    result = predictor.predict_batch(daily)
    conf = result["005930"].conf
    assert 0.0 <= conf < 0.05, f"spread=0.8 (> threshold) gave conf={conf} (expected 0)"
def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu):
    """Predicted prices are converted to returns relative to the last close.

    The history ramps 41..100, so last_close=100, and a q50 price of 110 must
    map to a median return of about +10%.
    """
    from ai_trade.chronos_predictor import ChronosPredictor

    mock_pipeline.predict_quantiles.return_value = (
        _mk_quantiles_tensor(109.0, 110.0, 111.0),
        None,
    )
    closes = list(range(41, 101))  # 60 closes, the last one is 100
    predictor = ChronosPredictor(model_name="mock-model")

    median = predictor.predict_batch({"005930": _daily_ohlcv(closes)})["005930"].median

    assert abs(median - 0.10) < 0.001
# ----- F4: absolute spread 기반 confidence -----
def test_confidence_high_when_spread_near_zero(mock_pipeline, mock_torch_cpu):
    """F4 regression: median ≈ 0 with a tiny non-zero spread must give conf ≈ 1.

    Korean stock at 100000 KRW with q10=99900, q50=100000, q90=100100
    -> returns q10=-0.001, median=0.0, q90=+0.001, absolute spread=0.002.

    * Old relative formula: spread / max(|median|, 0.001) = 0.002 / 0.001 = 2,
      which blew up and drove conf to 0.
    * F4 absolute formula: conf = 1 - 0.002 / 0.6 ≈ 0.997.

    A perfectly flat distribution (spread = 0) would NOT catch this regression,
    because the relative formula also yields conf = 1 there (0 / 0.001 = 0).
    """
    quantiles = _mk_quantiles_tensor(99900.0, 100000.0, 100100.0)
    mock_pipeline.predict_quantiles.return_value = (quantiles, None)
    from ai_trade.chronos_predictor import ChronosPredictor
    predictor = ChronosPredictor(model_name="mock-model")
    daily = {"005930": _daily_ohlcv([100000] * 60)}
    result = predictor.predict_batch(daily)
    assert result["005930"].conf > 0.95, (
        f"median≈0 + spread≈0인데 conf={result['005930'].conf} (F4 회귀)"
    )
def test_confidence_half_at_spread_03(mock_pipeline, mock_torch_cpu):
    """F4: an absolute spread of 0.30 gives conf ≈ 1 - 0.30 / 0.60 = 0.5."""
    from ai_trade.chronos_predictor import ChronosPredictor

    # Against last_close=100000: q10=85000 -> -0.15, q50=100000 -> 0.0,
    # q90=115000 -> +0.15, so the spread q90 - q10 is 0.30.
    mock_pipeline.predict_quantiles.return_value = (
        _mk_quantiles_tensor(85000.0, 100000.0, 115000.0),
        None,
    )
    predictor = ChronosPredictor(model_name="mock-model")
    history = {"005930": _daily_ohlcv([100000] * 60)}

    conf = predictor.predict_batch(history)["005930"].conf

    assert 0.45 < conf < 0.55, f"spread=0.30에서 conf={conf} (expected ≈0.5)"
def test_confidence_zero_at_threshold_spread(mock_pipeline, mock_torch_cpu):
    """F4: when the spread equals _SPREAD_THRESHOLD (0.6), conf drops to 0."""
    from ai_trade.chronos_predictor import ChronosPredictor

    # 70000 / 130000 around a 100000 close -> returns -0.30 / +0.30, spread 0.60.
    mock_pipeline.predict_quantiles.return_value = (
        _mk_quantiles_tensor(70000.0, 100000.0, 130000.0),
        None,
    )
    predictor = ChronosPredictor(model_name="mock-model")
    history = {"005930": _daily_ohlcv([100000] * 60)}

    prediction = predictor.predict_batch(history)["005930"]

    assert prediction.conf < 0.05, (
        f"spread=threshold에서 conf={prediction.conf} (expected ≈0)"
    )