"""Chronos-2 zero-shot forecaster wrapper."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from datetime import datetime
from zoneinfo import ZoneInfo

import numpy as np

logger = logging.getLogger(__name__)

KST = ZoneInfo("Asia/Seoul")


@dataclass
class ChronosPrediction:
    """Summary of one ticker's forecast return distribution.

    Attributes:
        median: 0.5 quantile of the forecast return relative to the last close.
        q10: 0.1 quantile of the forecast return.
        q90: 0.9 quantile of the forecast return.
        conf: Confidence score clamped to ``[0.0, 1.0]``; ``0.0`` when any of
            the quantiles is not finite.
        as_of: ISO-8601 timestamp (Asia/Seoul) of when the forecast was made.
    """

    median: float
    q10: float
    q90: float
    conf: float
    as_of: str


class ChronosPredictor:
    """HuggingFace Chronos-2 zero-shot forecaster."""

    def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
        """Load the Chronos pipeline for ``model_name``.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            device: Target device; when falsy, ``"cuda"`` is used if
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
        """
        # BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2)
        # and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5.
        import torch

        try:
            from chronos import BaseChronosPipeline

            pipeline_cls = BaseChronosPipeline
        except ImportError:
            from chronos import ChronosPipeline

            pipeline_cls = ChronosPipeline

        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision on GPU, full precision on CPU.
        dtype = torch.float16 if self._device == "cuda" else torch.float32
        logger.info(
            "Loading Chronos pipeline: %s on %s (cls=%s)",
            model_name,
            self._device,
            pipeline_cls.__name__,
        )
        # Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
        try:
            self._pipeline = pipeline_cls.from_pretrained(
                model_name,
                device_map=self._device,
                dtype=dtype,
            )
        except TypeError:
            self._pipeline = pipeline_cls.from_pretrained(
                model_name,
                device_map=self._device,
                torch_dtype=dtype,
            )
        logger.info("Chronos pipeline loaded.")

    @staticmethod
    def _has_usable_last_close(bars: list[dict]) -> bool:
        """Return True when ``bars`` is non-empty and ends on a finite, positive close."""
        if not bars:
            return False
        last_close = bars[-1]["close"]
        return bool(np.isfinite(last_close)) and last_close > 0

    @staticmethod
    def _to_numpy(forecasts):
        """Convert the pipeline output to a NumPy array."""
        # Tensor.numpy() raises for tensors on an accelerator or tracking
        # gradients, so move to host memory first when the object allows it.
        if hasattr(forecasts, "detach") and hasattr(forecasts, "cpu"):
            forecasts = forecasts.detach().cpu()
        if hasattr(forecasts, "numpy"):
            return forecasts.numpy()
        return np.asarray(forecasts)

    @staticmethod
    def _confidence(median: float, q10: float, q90: float) -> float:
        """Map the q10-q90 band width relative to the median onto ``[0.0, 1.0]``."""
        # NaN compares false against everything, so the clamp below would turn
        # a NaN spread into 1.0; report no confidence for non-finite input.
        if not (np.isfinite(median) and np.isfinite(q10) and np.isfinite(q90)):
            return 0.0
        # Floor the denominator at 0.001 so a near-zero median cannot blow up the ratio.
        spread = (q90 - q10) / max(abs(median), 0.001)
        return float(max(0.0, min(1.0, 1.0 - spread / 2.0)))

    def predict_batch(
        self,
        daily_ohlcv_dict: dict[str, list[dict]],
        prediction_length: int = 1,
        num_samples: int = 100,
    ) -> dict[str, ChronosPrediction]:
        """Predict the 1-day return distribution per ticker.

        Args:
            daily_ohlcv_dict: Maps ticker to its list of bars; each bar is a
                dict read through its ``"close"`` key. The last bar's close is
                the reference price for returns.
            prediction_length: Forecast horizon forwarded to the pipeline.
                Only the first forecast step is summarised.
            num_samples: Number of samples forwarded to the pipeline.

        Returns:
            Mapping of ticker to ``ChronosPrediction``. Tickers whose history
            is empty, or whose last close is not a finite positive number, are
            skipped with a warning because no return can be computed for them.
            An empty dict is returned when no ticker is usable.
        """
        import torch

        tickers: list[str] = []
        for ticker, bars in daily_ohlcv_dict.items():
            if self._has_usable_last_close(bars):
                tickers.append(ticker)
            else:
                logger.warning(
                    "Skipping %s: empty history or non-positive/non-finite last close",
                    ticker,
                )
        if not tickers:
            return {}

        contexts = [
            torch.tensor(
                [bar["close"] for bar in daily_ohlcv_dict[t]],
                dtype=torch.float32,
            )
            for t in tickers
        ]
        # NOTE(review): assumes the loaded pipeline's predict() accepts the
        # `context` and `num_samples` keywords -- confirm for every model
        # variant BaseChronosPipeline may resolve to.
        forecasts = self._pipeline.predict(
            context=contexts,
            prediction_length=prediction_length,
            num_samples=num_samples,
        )
        forecasts_np = self._to_numpy(forecasts)

        now_iso = datetime.now(KST).isoformat()
        results: dict[str, ChronosPrediction] = {}
        for i, ticker in enumerate(tickers):
            # First forecast step only. Assumes axis 1 holds sample paths --
            # TODO confirm for pipelines that emit quantiles instead of samples.
            samples = forecasts_np[i, :, 0]
            last_close = daily_ohlcv_dict[ticker][-1]["close"]
            returns = (samples - last_close) / last_close
            median = float(np.quantile(returns, 0.5))
            q10 = float(np.quantile(returns, 0.1))
            q90 = float(np.quantile(returns, 0.9))
            results[ticker] = ChronosPrediction(
                median=median,
                q10=q10,
                q90=q90,
                conf=self._confidence(median, q10, q90),
                as_of=now_iso,
            )
        return results