diff --git a/signal_v2/chronos_predictor.py b/signal_v2/chronos_predictor.py index ff74375..e9ea196 100644 --- a/signal_v2/chronos_predictor.py +++ b/signal_v2/chronos_predictor.py @@ -35,7 +35,9 @@ class ChronosPredictor: pipeline_cls = ChronosPipeline self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") - dtype = torch.float16 if self._device == "cuda" else torch.float32 + # Always use float32 — Korean stock prices (e.g. 280,000원) exceed FP16 max (~65,504) + # causing inf in quantile output. FP32 is safe for typical price magnitudes. + dtype = torch.float32 logger.info("Loading Chronos pipeline: %s on %s (cls=%s)", model_name, self._device, pipeline_cls.__name__) # Try `dtype` (newer API) first, fall back to `torch_dtype` (older) @@ -76,11 +78,19 @@ class ChronosPredictor: # Modern API: predict_quantiles (ChronosBolt / Chronos-2) if hasattr(self._pipeline, "predict_quantiles"): quantile_levels = [0.1, 0.5, 0.9] - quantiles_tensor, _ = self._pipeline.predict_quantiles( - context=contexts, - prediction_length=prediction_length, - quantile_levels=quantile_levels, - ) + # ChronosBolt API: positional `inputs` (first arg). Older variants use `context`. + try: + quantiles_tensor, _ = self._pipeline.predict_quantiles( + contexts, + prediction_length=prediction_length, + quantile_levels=quantile_levels, + ) + except TypeError: + quantiles_tensor, _ = self._pipeline.predict_quantiles( + context=contexts, + prediction_length=prediction_length, + quantile_levels=quantile_levels, + ) quantiles_np = ( quantiles_tensor.cpu().numpy() if hasattr(quantiles_tensor, "cpu")