fix(signal_v2-phase3b): use BaseChronosPipeline for new model architectures

ChronosPipeline (the legacy T5 implementation) does not support
amazon/chronos-2 or chronos-bolt-* models (which rely on
input_patch_size). Switch to BaseChronosPipeline
which auto-detects variant and returns the appropriate sub-pipeline
(ChronosBoltPipeline / Chronos2Pipeline / ChronosPipeline).

Also handle the dtype kwarg deprecation: try newer `dtype=` first,
fall back to `torch_dtype=` for older versions.

Test mock_pipeline fixture updated to patch BaseChronosPipeline.

45/45 tests pass. Verified amazon/chronos-bolt-base loads on CUDA.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-17 08:57:22 +09:00
parent 44888d6ede
commit 91de16675b
2 changed files with 23 additions and 9 deletions

View File

@@ -24,15 +24,28 @@ class ChronosPredictor:
"""HuggingFace Chronos-2 zero-shot forecaster.""" """HuggingFace Chronos-2 zero-shot forecaster."""
def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
    """Load a Chronos forecasting pipeline for *model_name* onto *device*.

    BaseChronosPipeline (when the installed ``chronos`` package provides it)
    auto-detects the model variant (Chronos / ChronosBolt / Chronos-2) and
    returns the matching sub-pipeline; the legacy ChronosPipeline only
    handles T5-based checkpoints, so it is used purely as a fallback.

    Args:
        model_name: HuggingFace model id of the Chronos checkpoint.
        device: Target device string; defaults to "cuda" when available,
            otherwise "cpu".
    """
    import torch

    # Prefer the variant-agnostic entry point; older chronos releases
    # only ship the legacy T5 pipeline.
    try:
        from chronos import BaseChronosPipeline as pipeline_cls
    except ImportError:
        from chronos import ChronosPipeline as pipeline_cls

    resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    self._device = resolved_device
    dtype = torch.float16 if resolved_device == "cuda" else torch.float32
    logger.info(
        "Loading Chronos pipeline: %s on %s (cls=%s)",
        model_name,
        resolved_device,
        pipeline_cls.__name__,
    )
    # Newer chronos versions accept `dtype=`; older ones only `torch_dtype=`.
    try:
        self._pipeline = pipeline_cls.from_pretrained(
            model_name, device_map=resolved_device, dtype=dtype,
        )
    except TypeError:
        self._pipeline = pipeline_cls.from_pretrained(
            model_name, device_map=resolved_device, torch_dtype=dtype,
        )
    logger.info("Chronos pipeline loaded.")

View File

@@ -7,8 +7,9 @@ import pytest
@pytest.fixture
def mock_pipeline():
    """Yield a mock pipeline object, patching chronos.BaseChronosPipeline.

    The patched class reports a `__name__` so code that logs the pipeline
    class name keeps working; `from_pretrained` returns the yielded mock.
    """
    with patch("chronos.BaseChronosPipeline") as pipeline_cls:
        pipeline_cls.__name__ = "BaseChronosPipeline"
        mock_instance = MagicMock()
        pipeline_cls.from_pretrained.return_value = mock_instance
        yield mock_instance