fix(signal_v2-phase3b): use BaseChronosPipeline for new model architectures
ChronosPipeline (legacy T5) does not support amazon/chronos-2 or chronos-bolt-* (input_patch_size). Switch to BaseChronosPipeline which auto-detects variant and returns the appropriate sub-pipeline (ChronosBoltPipeline / Chronos2Pipeline / ChronosPipeline). Also handle the dtype kwarg deprecation: try newer `dtype=` first, fall back to `torch_dtype=` for older versions. Test mock_pipeline fixture updated to patch BaseChronosPipeline. 45/45 tests pass. Verified amazon/chronos-bolt-base loads on CUDA. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,15 +24,28 @@ class ChronosPredictor:
|
|||||||
"""HuggingFace Chronos-2 zero-shot forecaster."""
|
"""HuggingFace Chronos-2 zero-shot forecaster."""
|
||||||
|
|
||||||
def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
|
def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
|
||||||
from chronos import ChronosPipeline
|
# BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2)
|
||||||
|
# and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5.
|
||||||
import torch
|
import torch
|
||||||
|
try:
|
||||||
|
from chronos import BaseChronosPipeline
|
||||||
|
pipeline_cls = BaseChronosPipeline
|
||||||
|
except ImportError:
|
||||||
|
from chronos import ChronosPipeline
|
||||||
|
pipeline_cls = ChronosPipeline
|
||||||
|
|
||||||
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
logger.info("Loading Chronos pipeline: %s on %s", model_name, self._device)
|
dtype = torch.float16 if self._device == "cuda" else torch.float32
|
||||||
self._pipeline = ChronosPipeline.from_pretrained(
|
logger.info("Loading Chronos pipeline: %s on %s (cls=%s)",
|
||||||
model_name,
|
model_name, self._device, pipeline_cls.__name__)
|
||||||
device_map=self._device,
|
# Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
|
||||||
torch_dtype=torch.float16 if self._device == "cuda" else torch.float32,
|
try:
|
||||||
|
self._pipeline = pipeline_cls.from_pretrained(
|
||||||
|
model_name, device_map=self._device, dtype=dtype,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
self._pipeline = pipeline_cls.from_pretrained(
|
||||||
|
model_name, device_map=self._device, torch_dtype=dtype,
|
||||||
)
|
)
|
||||||
logger.info("Chronos pipeline loaded.")
|
logger.info("Chronos pipeline loaded.")
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ import pytest
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_pipeline():
|
def mock_pipeline():
|
||||||
"""Mock ChronosPipeline.from_pretrained returning a mock pipeline object."""
|
"""Mock BaseChronosPipeline.from_pretrained returning a mock pipeline object."""
|
||||||
with patch("chronos.ChronosPipeline") as cls:
|
with patch("chronos.BaseChronosPipeline") as cls:
|
||||||
|
cls.__name__ = "BaseChronosPipeline"
|
||||||
instance = MagicMock()
|
instance = MagicMock()
|
||||||
cls.from_pretrained.return_value = instance
|
cls.from_pretrained.return_value = instance
|
||||||
yield instance
|
yield instance
|
||||||
|
|||||||
Reference in New Issue
Block a user