ChronosBoltPipeline.predict_quantiles takes `inputs` positional, not `context` keyword. Use positional with TypeError fallback for older chronos versions. FP16 caused inf overflow on Korean stock prices (e.g. 280,000원 > FP16 max 65,504). Force FP32 for prices to avoid this. Chronos model itself handles internal scaling. Verified end-to-end: 60-day daily fetch → Chronos predict → quantile output. Example 005930: median=-0.59%, q10=-8.9%, q90=+6.4%, conf=0.0 (low conf is mathematically correct when median is near zero relative to distribution width). 45/45 tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
133 lines
5.2 KiB
Python
133 lines
5.2 KiB
Python
"""Chronos-2 zero-shot forecaster wrapper."""
|
|
from __future__ import annotations
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from zoneinfo import ZoneInfo
|
|
|
|
import numpy as np
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
# Asia/Seoul time zone; used to timestamp predictions (`as_of`).
KST = ZoneInfo("Asia/Seoul")
|
|
|
|
|
|
@dataclass
class ChronosPrediction:
    """Forecast summary for a single ticker.

    ``median``, ``q10`` and ``q90`` are fractional returns relative to the
    last observed close (0.01 means +1%). ``conf`` is a heuristic score
    clamped to [0, 1]. ``as_of`` is the ISO-8601 timestamp at which the
    prediction was produced.
    """

    # 50th-percentile forecast return.
    median: float
    # 10th-percentile forecast return.
    q10: float
    # 90th-percentile forecast return.
    q90: float
    # Heuristic confidence in [0, 1]; higher means a tighter distribution
    # relative to the size of the median move.
    conf: float
    # ISO-8601 creation timestamp.
    as_of: str
|
|
|
|
|
|
class ChronosPredictor:
    """Zero-shot next-day return forecaster built on a HuggingFace Chronos pipeline.

    The pipeline is loaded once in ``__init__``. ``predict_batch`` turns daily
    close histories into per-ticker return quantiles (q10 / median / q90) plus
    a heuristic confidence score.
    """

    # Floor applied to |median| when normalising the q10..q90 width, so a
    # near-zero median cannot cause a division by zero.
    _MIN_ABS_MEDIAN = 0.001

    def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
        """Load the pretrained Chronos pipeline.

        Args:
            model_name: HuggingFace model id handed to ``from_pretrained``.
            device: Target device string. When omitted, ``"cuda"`` is used if
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
        """
        import torch

        # BaseChronosPipeline auto-detects the model variant (Chronos /
        # ChronosBolt / Chronos-2) and returns the matching sub-pipeline.
        # ChronosPipeline only supports the legacy T5 models, so it is used
        # solely when BaseChronosPipeline cannot be imported.
        try:
            from chronos import BaseChronosPipeline

            pipeline_cls = BaseChronosPipeline
        except ImportError:
            from chronos import ChronosPipeline

            pipeline_cls = ChronosPipeline

        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Always use float32: Korean stock prices (e.g. 280,000 KRW) exceed the
        # FP16 maximum (~65,504), which produced inf in the quantile output.
        dtype = torch.float32
        logger.info(
            "Loading Chronos pipeline: %s on %s (cls=%s)",
            model_name, self._device, pipeline_cls.__name__,
        )
        # Try the newer `dtype` keyword first; fall back to the older
        # `torch_dtype` keyword when the call rejects it with TypeError.
        try:
            self._pipeline = pipeline_cls.from_pretrained(
                model_name, device_map=self._device, dtype=dtype,
            )
        except TypeError:
            self._pipeline = pipeline_cls.from_pretrained(
                model_name, device_map=self._device, torch_dtype=dtype,
            )
        logger.info("Chronos pipeline loaded.")

    @staticmethod
    def _to_numpy(values):
        """Convert a tensor-like or array-like object to a NumPy array."""
        if hasattr(values, "cpu"):
            # Detach and move to host memory first: Tensor.numpy() is only
            # valid for CPU tensors that do not require grad.
            return values.detach().cpu().numpy()
        return np.asarray(values)

    @classmethod
    def _make_prediction(cls, median, q10, q90, as_of: str) -> ChronosPrediction:
        """Build a ChronosPrediction from return quantiles and score its confidence."""
        median, q10, q90 = float(median), float(q10), float(q90)
        if np.isfinite([median, q10, q90]).all():
            # Width of the q10..q90 band measured in units of |median|;
            # a band twice as wide as the median move (or wider) scores 0.
            spread = (q90 - q10) / max(abs(median), cls._MIN_ABS_MEDIAN)
            conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
        else:
            # NaN/inf would otherwise slip through the min/max clamp and come
            # out as full confidence (min(1.0, nan) evaluates to 1.0).
            conf = 0.0
        return ChronosPrediction(median=median, q10=q10, q90=q90, conf=conf, as_of=as_of)

    def predict_batch(
        self,
        daily_ohlcv_dict: dict[str, list[dict]],
        prediction_length: int = 1,
        num_samples: int = 100,
    ) -> dict[str, ChronosPrediction]:
        """Predict the 1-day return distribution for each ticker.

        Newer pipelines (ChronosBolt / Chronos-2) expose ``predict_quantiles``
        and are asked for the 0.1 / 0.5 / 0.9 quantiles directly
        (deterministic). The legacy T5 ``ChronosPipeline`` is sample-based:
        ``predict`` is called and the quantiles are taken over the samples.

        Args:
            daily_ohlcv_dict: Maps ticker to its list of daily bars; every bar
                must provide a ``"close"`` value. The last bar is used as the
                reference close, so bars are presumably ordered oldest to
                newest -- verify against the caller.
            prediction_length: Forecast horizon passed to the pipeline. Only
                the first forecast step is converted into returns.
            num_samples: Number of sample paths; used by the legacy path only.

        Returns:
            Ticker -> ChronosPrediction holding returns relative to the last
            close. Tickers with no bars or a zero last close are skipped with
            a warning, since no return can be computed for them. Predictions
            containing non-finite values get ``conf == 0.0``.
        """
        import torch

        # Keep only tickers for which a return can actually be computed, so a
        # single bad series does not fail the whole batch.
        last_closes: dict[str, float] = {}
        for ticker, bars in daily_ohlcv_dict.items():
            if not bars:
                logger.warning("Skipping %s: empty price history", ticker)
                continue
            last_close = bars[-1]["close"]
            if last_close == 0:
                logger.warning("Skipping %s: last close is zero", ticker)
                continue
            last_closes[ticker] = last_close
        if not last_closes:
            return {}
        tickers = list(last_closes)

        contexts = [
            torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32)
            for t in tickers
        ]
        now_iso = datetime.now(KST).isoformat()
        results: dict[str, ChronosPrediction] = {}

        # Modern API: predict_quantiles (ChronosBolt / Chronos-2).
        if hasattr(self._pipeline, "predict_quantiles"):
            quantile_levels = [0.1, 0.5, 0.9]
            # ChronosBolt takes the series as the first positional argument
            # (`inputs`); older variants name it `context`.
            try:
                raw_quantiles, _ = self._pipeline.predict_quantiles(
                    contexts,
                    prediction_length=prediction_length,
                    quantile_levels=quantile_levels,
                )
            except TypeError:
                raw_quantiles, _ = self._pipeline.predict_quantiles(
                    context=contexts,
                    prediction_length=prediction_length,
                    quantile_levels=quantile_levels,
                )
            # A single stacked tensor is [num_series, prediction_length, 3].
            # NOTE(review): Chronos-2 reportedly returns a list of per-series
            # tensors shaped (n_variates, prediction_length, 3) instead --
            # confirm against the installed chronos version. Both layouts are
            # handled by flattening each series to rows of 3 quantiles below.
            if isinstance(raw_quantiles, (list, tuple)):
                per_series = [self._to_numpy(q) for q in raw_quantiles]
            else:
                per_series = self._to_numpy(raw_quantiles)
            for i, ticker in enumerate(tickers):
                # Row 0 is the first forecast step (of the first variate).
                first_step = np.asarray(per_series[i]).reshape(-1, len(quantile_levels))[0]
                q10_price, q50_price, q90_price = first_step
                last_close = last_closes[ticker]
                results[ticker] = self._make_prediction(
                    median=(q50_price - last_close) / last_close,
                    q10=(q10_price - last_close) / last_close,
                    q90=(q90_price - last_close) / last_close,
                    as_of=now_iso,
                )
            return results

        # Legacy API: sample-based predict (ChronosPipeline T5).
        forecasts = self._pipeline.predict(
            context=contexts,
            prediction_length=prediction_length,
            num_samples=num_samples,
        )
        forecasts_np = self._to_numpy(forecasts)
        for i, ticker in enumerate(tickers):
            # Indexed as [series, sample, step]: all samples of step 0.
            samples = forecasts_np[i, :, 0]
            last_close = last_closes[ticker]
            returns = (samples - last_close) / last_close
            q10, median, q90 = np.quantile(returns, [0.1, 0.5, 0.9])
            results[ticker] = self._make_prediction(
                median=median, q10=q10, q90=q90, as_of=now_iso,
            )
        return results
|