Files
ai-trade/signal_v2/chronos_predictor.py
gahusb 8eefe9d79d fix(signal_v2-phase3b): ChronosBolt predict_quantiles API support
ChronosBoltPipeline.predict() does not accept a `context` keyword argument;
it takes its input positionally and is deterministic (no `num_samples`).
Switch to predict_quantiles(context, prediction_length, quantile_levels),
which returns a (quantiles_tensor, mean_tensor) tuple.

Implementation: if hasattr(pipeline, "predict_quantiles") → modern
quantile branch. Else fall back to legacy sample-based predict (T5).

Tests: switch to predict_quantiles mock returning (quantiles, None)
with shape [1, 1, 3] for q10/q50/q90 directly.

45/45 tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-17 09:07:11 +09:00

123 lines
4.7 KiB
Python

"""Chronos-2 zero-shot forecaster wrapper."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime
from zoneinfo import ZoneInfo
import numpy as np
logger = logging.getLogger(__name__)
KST = ZoneInfo("Asia/Seoul")
@dataclass
class ChronosPrediction:
    """Summary of one ticker's forecast, expressed as relative returns.

    Every return field is ``(forecast_price - last_close) / last_close``,
    where ``last_close`` is the close of the most recent bar supplied for
    that ticker.
    """

    # Median (0.5 quantile) forecast return.
    median: float
    # Lower band: 0.1 quantile of the forecast return.
    q10: float
    # Upper band: 0.9 quantile of the forecast return.
    q90: float
    # Confidence score clipped to [0, 1]; shrinks as the q10..q90 band widens
    # relative to the magnitude of the median.
    conf: float
    # ISO-8601 timestamp (Asia/Seoul) of when the prediction was produced.
    as_of: str
class ChronosPredictor:
    """HuggingFace Chronos zero-shot forecaster for next-step close returns.

    Wraps whichever Chronos pipeline class the installed ``chronos`` package
    exposes and converts its price forecasts into returns relative to each
    series' last observed close.
    """

    # Quantile levels reported for every ticker: lower band, median, upper band.
    _QUANTILE_LEVELS = (0.1, 0.5, 0.9)

    def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
        """Load the pretrained pipeline.

        Args:
            model_name: Model identifier handed to ``from_pretrained``.
            device: Target device string. When omitted, ``"cuda"`` is used if
                torch reports it available, otherwise ``"cpu"``.
        """
        # BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2)
        # and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5.
        import torch

        try:
            from chronos import BaseChronosPipeline
            pipeline_cls = BaseChronosPipeline
        except ImportError:
            from chronos import ChronosPipeline
            pipeline_cls = ChronosPipeline

        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision on GPU, full precision on CPU.
        dtype = torch.float16 if self._device == "cuda" else torch.float32
        logger.info("Loading Chronos pipeline: %s on %s (cls=%s)",
                    model_name, self._device, pipeline_cls.__name__)
        # Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
        try:
            self._pipeline = pipeline_cls.from_pretrained(
                model_name, device_map=self._device, dtype=dtype,
            )
        except TypeError:
            self._pipeline = pipeline_cls.from_pretrained(
                model_name, device_map=self._device, torch_dtype=dtype,
            )
        logger.info("Chronos pipeline loaded.")

    @staticmethod
    def _to_numpy(values) -> np.ndarray:
        """Convert a torch tensor (on any device) or an array-like to NumPy."""
        if hasattr(values, "detach"):
            # Tensors must be detached and moved to host memory before .numpy().
            return values.detach().cpu().numpy()
        return np.asarray(values)

    @staticmethod
    def _build_prediction(q10: float, median: float, q90: float, as_of: str) -> ChronosPrediction:
        """Derive the confidence score and package one ticker's result."""
        if np.isfinite([q10, median, q90]).all():
            # Band width relative to the median's magnitude; the 0.001 floor
            # keeps a near-zero median from blowing the ratio up.
            spread = (q90 - q10) / max(abs(median), 0.001)
            conf = max(0.0, min(1.0, 1.0 - spread / 2.0))
        else:
            # min(1.0, nan) evaluates to 1.0, so a non-finite forecast would
            # otherwise be reported with full confidence.
            conf = 0.0
        return ChronosPrediction(median=median, q10=q10, q90=q90, conf=conf, as_of=as_of)

    def predict_batch(
        self,
        daily_ohlcv_dict: dict[str, list[dict]],
        prediction_length: int = 1,
        num_samples: int = 100,
    ) -> dict[str, ChronosPrediction]:
        """Predict the 1-day return distribution for each ticker.

        Pipelines exposing ``predict_quantiles`` (ChronosBolt / Chronos-2 and
        other newer models) are queried for quantiles directly, which is
        deterministic. Otherwise the legacy ChronosPipeline (T5) sample-based
        ``predict`` is used and quantiles are taken over the samples.

        Args:
            daily_ohlcv_dict: Ticker -> list of bar dicts; only each bar's
                ``"close"`` value is read. The last bar's close is the
                reference price for the returns.
            prediction_length: Forecast horizon passed to the pipeline. Only
                the first forecast step is used.
            num_samples: Sample count for the legacy sample-based path;
                ignored when ``predict_quantiles`` is available.

        Returns:
            Ticker -> ChronosPrediction. Tickers with no bars, or whose last
            close is not a finite positive number, are skipped with a warning.
        """
        import torch

        # A return needs a usable reference price: drop series that have no
        # history or whose last close would cause a division by zero / NaN.
        last_closes: dict[str, float] = {}
        for ticker, bars in daily_ohlcv_dict.items():
            if not bars:
                logger.warning("Skipping %s: no OHLCV bars", ticker)
                continue
            last_close = float(bars[-1]["close"])
            if not (np.isfinite(last_close) and last_close > 0):
                logger.warning("Skipping %s: unusable last close %r", ticker, last_close)
                continue
            last_closes[ticker] = last_close
        if not last_closes:
            return {}

        tickers = list(last_closes)
        contexts = [
            torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32)
            for t in tickers
        ]
        now_iso = datetime.now(KST).isoformat()
        levels = list(self._QUANTILE_LEVELS)
        results: dict[str, ChronosPrediction] = {}

        # Modern API: predict_quantiles (ChronosBolt / Chronos-2)
        if hasattr(self._pipeline, "predict_quantiles"):
            # The series batch is passed positionally: some pipelines reject it
            # as a `context=` keyword, while positional works for all of them.
            quantiles, _ = self._pipeline.predict_quantiles(
                contexts,
                prediction_length=prediction_length,
                quantile_levels=levels,
            )
            # Usual shape: [num_series, prediction_length, 3]. A list of
            # per-series outputs is also accepted.
            # NOTE(review): assumes list outputs (possibly Chronos-2) only add
            # leading axes in front of [prediction_length, 3] — confirm against
            # the installed chronos version.
            if isinstance(quantiles, (list, tuple)):
                per_series = [self._to_numpy(q) for q in quantiles]
            else:
                per_series = self._to_numpy(quantiles)
            for i, ticker in enumerate(tickers):
                # Flatten any leading axes; row 0 is the first forecast step.
                first_step = np.asarray(per_series[i]).reshape(-1, len(levels))[0]
                q10_price, q50_price, q90_price = (float(v) for v in first_step)
                last_close = last_closes[ticker]
                results[ticker] = self._build_prediction(
                    q10=(q10_price - last_close) / last_close,
                    median=(q50_price - last_close) / last_close,
                    q90=(q90_price - last_close) / last_close,
                    as_of=now_iso,
                )
            return results

        # Legacy API: sample-based predict (ChronosPipeline T5)
        forecasts = self._pipeline.predict(
            contexts,
            prediction_length=prediction_length,
            num_samples=num_samples,
        )
        # Indexed below as [series, sample, step].
        forecasts_np = self._to_numpy(forecasts)
        for i, ticker in enumerate(tickers):
            samples = forecasts_np[i, :, 0]
            last_close = last_closes[ticker]
            returns = (samples - last_close) / last_close
            q10, median, q90 = (float(v) for v in np.quantile(returns, levels))
            results[ticker] = self._build_prediction(
                q10=q10, median=median, q90=q90, as_of=now_iso,
            )
        return results