feat(signal_v2-phase3b): chronos_predictor + 4 mock tests
ChronosPredictor wraps HuggingFace ChronosPipeline. Batch predict returns ChronosPrediction(median, q10, q90, conf, as_of) per ticker. Confidence = 1 - clamp(spread/2, 0, 1) where spread = (q90-q10) / |median|. Lazy import of chronos lib (heavy). GPU auto-detect with FP16. 44 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
78
signal_v2/chronos_predictor.py
Normal file
78
signal_v2/chronos_predictor.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Chronos-2 zero-shot forecaster wrapper."""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronosPrediction:
|
||||
median: float
|
||||
q10: float
|
||||
q90: float
|
||||
conf: float
|
||||
as_of: str
|
||||
|
||||
|
||||
class ChronosPredictor:
|
||||
"""HuggingFace Chronos-2 zero-shot forecaster."""
|
||||
|
||||
def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
|
||||
from chronos import ChronosPipeline
|
||||
import torch
|
||||
|
||||
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info("Loading Chronos pipeline: %s on %s", model_name, self._device)
|
||||
self._pipeline = ChronosPipeline.from_pretrained(
|
||||
model_name,
|
||||
device_map=self._device,
|
||||
torch_dtype=torch.float16 if self._device == "cuda" else torch.float32,
|
||||
)
|
||||
logger.info("Chronos pipeline loaded.")
|
||||
|
||||
def predict_batch(
|
||||
self,
|
||||
daily_ohlcv_dict: dict[str, list[dict]],
|
||||
prediction_length: int = 1,
|
||||
num_samples: int = 100,
|
||||
) -> dict[str, ChronosPrediction]:
|
||||
"""종목별 1-day return 분포 예측."""
|
||||
import torch
|
||||
|
||||
tickers = list(daily_ohlcv_dict.keys())
|
||||
if not tickers:
|
||||
return {}
|
||||
|
||||
contexts = [
|
||||
torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32)
|
||||
for t in tickers
|
||||
]
|
||||
forecasts = self._pipeline.predict(
|
||||
context=contexts,
|
||||
prediction_length=prediction_length,
|
||||
num_samples=num_samples,
|
||||
)
|
||||
# Convert to numpy if tensor
|
||||
forecasts_np = forecasts.numpy() if hasattr(forecasts, "numpy") else np.asarray(forecasts)
|
||||
|
||||
now_iso = datetime.now(KST).isoformat()
|
||||
results: dict[str, ChronosPrediction] = {}
|
||||
for i, ticker in enumerate(tickers):
|
||||
samples = forecasts_np[i, :, 0]
|
||||
last_close = daily_ohlcv_dict[ticker][-1]["close"]
|
||||
returns = (samples - last_close) / last_close
|
||||
median = float(np.quantile(returns, 0.5))
|
||||
q10 = float(np.quantile(returns, 0.1))
|
||||
q90 = float(np.quantile(returns, 0.9))
|
||||
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||
results[ticker] = ChronosPrediction(
|
||||
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||
)
|
||||
return results
|
||||
83
signal_v2/tests/test_chronos_predictor.py
Normal file
83
signal_v2/tests/test_chronos_predictor.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Tests for ChronosPredictor (model mock)."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline():
|
||||
"""Mock ChronosPipeline.from_pretrained returning a mock pipeline object."""
|
||||
with patch("chronos.ChronosPipeline") as cls:
|
||||
instance = MagicMock()
|
||||
cls.from_pretrained.return_value = instance
|
||||
yield instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torch_cpu():
|
||||
with patch("torch.cuda.is_available", return_value=False):
|
||||
yield
|
||||
|
||||
|
||||
def _daily_ohlcv(close_seq):
|
||||
return [{"datetime": f"2026-05-{i+1:02d}", "open": c, "high": c, "low": c,
|
||||
"close": c, "volume": 1000} for i, c in enumerate(close_seq)]
|
||||
|
||||
|
||||
def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
|
||||
"""mock pipeline → dict[ticker, ChronosPrediction]. last_close=100, samples=102 → ~+2% return."""
|
||||
import torch
|
||||
samples = np.full((100,), 102.0)
|
||||
mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))
|
||||
|
||||
from signal_v2.chronos_predictor import ChronosPredictor, ChronosPrediction
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert "005930" in result
|
||||
pred = result["005930"]
|
||||
assert isinstance(pred, ChronosPrediction)
|
||||
assert abs(pred.median - 0.02) < 0.001
|
||||
|
||||
|
||||
def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
|
||||
"""좁은 distribution → conf ≈ 1."""
|
||||
import torch
|
||||
np.random.seed(42)
|
||||
samples = np.random.normal(102.0, 0.1, 100)
|
||||
mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))
|
||||
|
||||
from signal_v2.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert result["005930"].conf > 0.8
|
||||
|
||||
|
||||
def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
|
||||
"""넓은 distribution → conf ≈ 0."""
|
||||
import torch
|
||||
np.random.seed(42)
|
||||
samples = np.random.normal(100.0, 30.0, 100)
|
||||
mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))
|
||||
|
||||
from signal_v2.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert result["005930"].conf < 0.3
|
||||
|
||||
|
||||
def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu):
|
||||
"""price 예측 → last_close 대비 return 변환. last_close=100, samples=110 → return ~+10%."""
|
||||
import torch
|
||||
samples = np.full((100,), 110.0)
|
||||
mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))
|
||||
|
||||
from signal_v2.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
# last close in the seq = 100
|
||||
daily = {"005930": _daily_ohlcv(list(range(41, 101)))} # last = 100
|
||||
result = predictor.predict_batch(daily)
|
||||
assert abs(result["005930"].median - 0.10) < 0.001
|
||||
Reference in New Issue
Block a user