"""Chronos-2 zero-shot forecaster wrapper.""" from __future__ import annotations import logging from dataclasses import dataclass from datetime import datetime from zoneinfo import ZoneInfo import numpy as np logger = logging.getLogger(__name__) KST = ZoneInfo("Asia/Seoul") @dataclass class ChronosPrediction: median: float q10: float q90: float conf: float as_of: str class ChronosPredictor: """HuggingFace Chronos-2 zero-shot forecaster.""" def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None): # BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2) # and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5. import torch try: from chronos import BaseChronosPipeline pipeline_cls = BaseChronosPipeline except ImportError: from chronos import ChronosPipeline pipeline_cls = ChronosPipeline self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") # Always use float32 — Korean stock prices (e.g. 280,000원) exceed FP16 max (~65,504) # causing inf in quantile output. FP32 is safe for typical price magnitudes. dtype = torch.float32 logger.info("Loading Chronos pipeline: %s on %s (cls=%s)", model_name, self._device, pipeline_cls.__name__) # Try `dtype` (newer API) first, fall back to `torch_dtype` (older) try: self._pipeline = pipeline_cls.from_pretrained( model_name, device_map=self._device, dtype=dtype, ) except TypeError: self._pipeline = pipeline_cls.from_pretrained( model_name, device_map=self._device, torch_dtype=dtype, ) logger.info("Chronos pipeline loaded.") def predict_batch( self, daily_ohlcv_dict: dict[str, list[dict]], prediction_length: int = 1, num_samples: int = 100, ) -> dict[str, ChronosPrediction]: """종목별 1-day return 분포 예측. ChronosBolt / Chronos-2 등 신모델은 predict_quantiles 사용 (deterministic). Legacy ChronosPipeline (T5) 는 sample-based predict. """ import torch tickers = list(daily_ohlcv_dict.keys()) if not tickers: return {} contexts = [ torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32) for t in tickers ] now_iso = datetime.now(KST).isoformat() results: dict[str, ChronosPrediction] = {} # Modern API: predict_quantiles (ChronosBolt / Chronos-2) if hasattr(self._pipeline, "predict_quantiles"): quantile_levels = [0.1, 0.5, 0.9] # ChronosBolt API: positional `inputs` (first arg). Older variants use `context`. try: quantiles_tensor, _ = self._pipeline.predict_quantiles( contexts, prediction_length=prediction_length, quantile_levels=quantile_levels, ) except TypeError: quantiles_tensor, _ = self._pipeline.predict_quantiles( context=contexts, prediction_length=prediction_length, quantile_levels=quantile_levels, ) quantiles_np = ( quantiles_tensor.cpu().numpy() if hasattr(quantiles_tensor, "cpu") else np.asarray(quantiles_tensor) ) # shape: [num_series, prediction_length, 3] for i, ticker in enumerate(tickers): q10_price, q50_price, q90_price = quantiles_np[i, 0, :] last_close = daily_ohlcv_dict[ticker][-1]["close"] median = float((q50_price - last_close) / last_close) q10 = float((q10_price - last_close) / last_close) q90 = float((q90_price - last_close) / last_close) spread = (q90 - q10) / max(abs(median), 0.001) conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0))) results[ticker] = ChronosPrediction( median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso, ) return results # Legacy API: sample-based predict (ChronosPipeline T5) forecasts = self._pipeline.predict( context=contexts, prediction_length=prediction_length, num_samples=num_samples, ) forecasts_np = forecasts.numpy() if hasattr(forecasts, "numpy") else np.asarray(forecasts) for i, ticker in enumerate(tickers): samples = forecasts_np[i, :, 0] last_close = daily_ohlcv_dict[ticker][-1]["close"] returns = (samples - last_close) / last_close median = float(np.quantile(returns, 0.5)) q10 = float(np.quantile(returns, 0.1)) q90 = float(np.quantile(returns, 0.9)) spread = (q90 - q10) / max(abs(median), 0.001) conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0))) results[ticker] = ChronosPrediction( median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso, ) return results