diff --git a/signal_v2/chronos_predictor.py b/signal_v2/chronos_predictor.py index 5ec73ab..f754b46 100644 --- a/signal_v2/chronos_predictor.py +++ b/signal_v2/chronos_predictor.py @@ -24,16 +24,29 @@ class ChronosPredictor: """HuggingFace Chronos-2 zero-shot forecaster.""" def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None): - from chronos import ChronosPipeline + # BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2) + # and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5. import torch + try: + from chronos import BaseChronosPipeline + pipeline_cls = BaseChronosPipeline + except ImportError: + from chronos import ChronosPipeline + pipeline_cls = ChronosPipeline self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") - logger.info("Loading Chronos pipeline: %s on %s", model_name, self._device) - self._pipeline = ChronosPipeline.from_pretrained( - model_name, - device_map=self._device, - torch_dtype=torch.float16 if self._device == "cuda" else torch.float32, - ) + dtype = torch.float16 if self._device == "cuda" else torch.float32 + logger.info("Loading Chronos pipeline: %s on %s (cls=%s)", + model_name, self._device, pipeline_cls.__name__) + # Try `dtype` (newer API) first, fall back to `torch_dtype` (older) + try: + self._pipeline = pipeline_cls.from_pretrained( + model_name, device_map=self._device, dtype=dtype, + ) + except TypeError: + self._pipeline = pipeline_cls.from_pretrained( + model_name, device_map=self._device, torch_dtype=dtype, + ) logger.info("Chronos pipeline loaded.") def predict_batch( diff --git a/signal_v2/tests/test_chronos_predictor.py b/signal_v2/tests/test_chronos_predictor.py index 42729a4..43e5358 100644 --- a/signal_v2/tests/test_chronos_predictor.py +++ b/signal_v2/tests/test_chronos_predictor.py @@ -7,8 +7,9 @@ import pytest @pytest.fixture def mock_pipeline(): - """Mock ChronosPipeline.from_pretrained returning a mock pipeline object.""" - with patch("chronos.ChronosPipeline") as cls: + """Mock BaseChronosPipeline.from_pretrained returning a mock pipeline object.""" + with patch("chronos.BaseChronosPipeline") as cls: + cls.__name__ = "BaseChronosPipeline" instance = MagicMock() cls.from_pretrained.return_value = instance yield instance