From 28f9c8c3a6dba6c148fee0c07a862f8d66b3b135 Mon Sep 17 00:00:00 2001
From: gahusb
Date: Sat, 16 May 2026 18:00:46 +0900
Subject: [PATCH] feat(signal_v2-phase3b): chronos_predictor + 4 mock tests

ChronosPredictor wraps HuggingFace ChronosPipeline. Batch predict returns
ChronosPrediction(median, q10, q90, conf, as_of) per ticker.
Confidence = 1 - clamp(spread/2, 0, 1) where spread = (q90-q10) / |median|.
Lazy import of chronos lib (heavy). GPU auto-detect with FP16.
44 tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 signal_v2/chronos_predictor.py            | 91 ++++++++++++++++++++++
 signal_v2/tests/test_chronos_predictor.py | 85 +++++++++++++++++++++++
 2 files changed, 176 insertions(+)
 create mode 100644 signal_v2/chronos_predictor.py
 create mode 100644 signal_v2/tests/test_chronos_predictor.py

diff --git a/signal_v2/chronos_predictor.py b/signal_v2/chronos_predictor.py
new file mode 100644
index 0000000..5ec73ab
--- /dev/null
+++ b/signal_v2/chronos_predictor.py
@@ -0,0 +1,91 @@
"""Chronos-2 zero-shot forecaster wrapper."""
from __future__ import annotations

import logging
from dataclasses import dataclass
from datetime import datetime
from zoneinfo import ZoneInfo

import numpy as np

logger = logging.getLogger(__name__)
KST = ZoneInfo("Asia/Seoul")


@dataclass
class ChronosPrediction:
    """1-day return forecast for one ticker; quantiles are fractional
    returns vs the last close, conf is in [0, 1], as_of is ISO-8601 KST."""

    median: float
    q10: float
    q90: float
    conf: float
    as_of: str


class ChronosPredictor:
    """HuggingFace Chronos-2 zero-shot forecaster."""

    def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
        # Lazy import: chronos/torch are heavy, load only on construction.
        from chronos import ChronosPipeline
        import torch

        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Loading Chronos pipeline: %s on %s", model_name, self._device)
        self._pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map=self._device,
            # FP16 halves GPU memory; CPU inference stays in FP32.
            torch_dtype=torch.float16 if self._device == "cuda" else torch.float32,
        )
        logger.info("Chronos pipeline loaded.")

    def predict_batch(
        self,
        daily_ohlcv_dict: dict[str, list[dict]],
        prediction_length: int = 1,
        num_samples: int = 100,
    ) -> dict[str, ChronosPrediction]:
        """Predict the 1-day return distribution per ticker.

        daily_ohlcv_dict maps ticker -> daily OHLCV bars (oldest first,
        each with a "close" key). Returns ticker -> ChronosPrediction.
        """
        import torch

        tickers = list(daily_ohlcv_dict.keys())
        if not tickers:
            return {}

        contexts = [
            torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32)
            for t in tickers
        ]
        forecasts = self._pipeline.predict(
            context=contexts,
            prediction_length=prediction_length,
            num_samples=num_samples,
        )
        # Detach and move to host first: Tensor.numpy() raises on CUDA
        # tensors, so plain .numpy() would break the GPU/FP16 path.
        forecasts_np = forecasts.detach().cpu().numpy() if hasattr(forecasts, "detach") else np.asarray(forecasts)

        now_iso = datetime.now(KST).isoformat()
        results: dict[str, ChronosPrediction] = {}
        for i, ticker in enumerate(tickers):
            # Forecast shape: (series, num_samples, horizon); take step 0.
            samples = forecasts_np[i, :, 0]
            last_close = daily_ohlcv_dict[ticker][-1]["close"]
            returns = (samples - last_close) / last_close
            median = float(np.quantile(returns, 0.5))
            q10 = float(np.quantile(returns, 0.1))
            q90 = float(np.quantile(returns, 0.9))
            # spread = (q90-q10)/|median| (floored to avoid div-by-zero);
            # conf = 1 - clamp(spread/2, 0, 1).
            spread = (q90 - q10) / max(abs(median), 0.001)
            conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
            results[ticker] = ChronosPrediction(
                median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
            )
        return results
diff --git a/signal_v2/tests/test_chronos_predictor.py b/signal_v2/tests/test_chronos_predictor.py
new file mode 100644
index 0000000..42729a4
--- /dev/null
+++ b/signal_v2/tests/test_chronos_predictor.py
@@ -0,0 +1,85 @@
"""Tests for ChronosPredictor (model mock)."""
from unittest.mock import MagicMock, patch

import numpy as np
import pytest


@pytest.fixture
def mock_pipeline():
    """Mock ChronosPipeline.from_pretrained returning a mock pipeline object."""
    with patch("chronos.ChronosPipeline") as cls:
        instance = MagicMock()
        cls.from_pretrained.return_value = instance
        yield instance


@pytest.fixture
def mock_torch_cpu():
    """Force the CPU code path regardless of host hardware."""
    with patch("torch.cuda.is_available", return_value=False):
        yield


def _daily_ohlcv(close_seq):
    """Build flat OHLCV bars (open=high=low=close) from a close-price sequence."""
    return [{"datetime": f"2026-05-{i+1:02d}", "open": c, "high": c, "low": c,
             "close": c, "volume": 1000} for i, c in enumerate(close_seq)]


def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
    """mock pipeline -> dict[ticker, ChronosPrediction]; last_close=100, samples=102 -> ~+2% return."""
    import torch
    samples = np.full((100,), 102.0)
    mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))

    from signal_v2.chronos_predictor import ChronosPredictor, ChronosPrediction
    predictor = ChronosPredictor(model_name="mock-model")
    daily = {"005930": _daily_ohlcv([100] * 60)}
    result = predictor.predict_batch(daily)
    assert "005930" in result
    pred = result["005930"]
    assert isinstance(pred, ChronosPrediction)
    assert abs(pred.median - 0.02) < 0.001


def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
    """Narrow forecast distribution -> conf close to 1."""
    import torch
    np.random.seed(42)
    samples = np.random.normal(102.0, 0.1, 100)
    mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))

    from signal_v2.chronos_predictor import ChronosPredictor
    predictor = ChronosPredictor(model_name="mock-model")
    daily = {"005930": _daily_ohlcv([100] * 60)}
    result = predictor.predict_batch(daily)
    assert result["005930"].conf > 0.8


def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
    """Wide forecast distribution -> conf close to 0."""
    import torch
    np.random.seed(42)
    samples = np.random.normal(100.0, 30.0, 100)
    mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))

    from signal_v2.chronos_predictor import ChronosPredictor
    predictor = ChronosPredictor(model_name="mock-model")
    daily = {"005930": _daily_ohlcv([100] * 60)}
    result = predictor.predict_batch(daily)
    assert result["005930"].conf < 0.3


def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu):
    """Price forecast is converted to a return vs last_close; last_close=100, samples=110 -> ~+10%."""
    import torch
    samples = np.full((100,), 110.0)
    mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))

    from signal_v2.chronos_predictor import ChronosPredictor
    predictor = ChronosPredictor(model_name="mock-model")
    # last close in the seq = 100
    daily = {"005930": _daily_ohlcv(list(range(41, 101)))}  # last = 100
    result = predictor.predict_batch(daily)
    assert abs(result["005930"].median - 0.10) < 0.001