From 91de16675b732fbb14513b0cacfb7675180d7115 Mon Sep 17 00:00:00 2001 From: gahusb Date: Sun, 17 May 2026 08:57:22 +0900 Subject: [PATCH] fix(signal_v2-phase3b): use BaseChronosPipeline for new model architectures ChronosPipeline (legacy T5) does not support amazon/chronos-2 or chronos-bolt-* (input_patch_size). Switch to BaseChronosPipeline which auto-detects variant and returns the appropriate sub-pipeline (ChronosBoltPipeline / Chronos2Pipeline / ChronosPipeline). Also handle the dtype kwarg deprecation: try newer `dtype=` first, fall back to `torch_dtype=` for older versions. Test mock_pipeline fixture updated to patch BaseChronosPipeline. 45/45 tests pass. Verified amazon/chronos-bolt-base loads on CUDA. Co-Authored-By: Claude Opus 4.7 (1M context) --- signal_v2/chronos_predictor.py | 27 +++++++++++++++++------ signal_v2/tests/test_chronos_predictor.py | 5 +++-- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/signal_v2/chronos_predictor.py b/signal_v2/chronos_predictor.py index 5ec73ab..f754b46 100644 --- a/signal_v2/chronos_predictor.py +++ b/signal_v2/chronos_predictor.py @@ -24,16 +24,29 @@ class ChronosPredictor: """HuggingFace Chronos-2 zero-shot forecaster.""" def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None): - from chronos import ChronosPipeline + # BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2) + # and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5. import torch + try: + from chronos import BaseChronosPipeline + pipeline_cls = BaseChronosPipeline + except ImportError: + from chronos import ChronosPipeline + pipeline_cls = ChronosPipeline self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") - logger.info("Loading Chronos pipeline: %s on %s", model_name, self._device) - self._pipeline = ChronosPipeline.from_pretrained( - model_name, - device_map=self._device, - torch_dtype=torch.float16 if self._device == "cuda" else torch.float32, - ) + dtype = torch.float16 if self._device == "cuda" else torch.float32 + logger.info("Loading Chronos pipeline: %s on %s (cls=%s)", + model_name, self._device, pipeline_cls.__name__) + # Try `dtype` (newer API) first, fall back to `torch_dtype` (older) + try: + self._pipeline = pipeline_cls.from_pretrained( + model_name, device_map=self._device, dtype=dtype, + ) + except TypeError: + self._pipeline = pipeline_cls.from_pretrained( + model_name, device_map=self._device, torch_dtype=dtype, + ) logger.info("Chronos pipeline loaded.") def predict_batch( diff --git a/signal_v2/tests/test_chronos_predictor.py b/signal_v2/tests/test_chronos_predictor.py index 42729a4..43e5358 100644 --- a/signal_v2/tests/test_chronos_predictor.py +++ b/signal_v2/tests/test_chronos_predictor.py @@ -7,8 +7,9 @@ import pytest @pytest.fixture def mock_pipeline(): - """Mock ChronosPipeline.from_pretrained returning a mock pipeline object.""" - with patch("chronos.ChronosPipeline") as cls: + """Mock BaseChronosPipeline.from_pretrained returning a mock pipeline object.""" + with patch("chronos.BaseChronosPipeline") as cls: + cls.__name__ = "BaseChronosPipeline" instance = MagicMock() cls.from_pretrained.return_value = instance yield instance