fix(signal_v2-phase3b): use BaseChronosPipeline for new model architectures
ChronosPipeline (legacy T5) does not support amazon/chronos-2 or chronos-bolt-* (input_patch_size). Switch to BaseChronosPipeline which auto-detects variant and returns the appropriate sub-pipeline (ChronosBoltPipeline / Chronos2Pipeline / ChronosPipeline). Also handle the dtype kwarg deprecation: try newer `dtype=` first, fall back to `torch_dtype=` for older versions. Test mock_pipeline fixture updated to patch BaseChronosPipeline. 45/45 tests pass. Verified amazon/chronos-bolt-base loads on CUDA. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,16 +24,29 @@ class ChronosPredictor:
|
||||
"""HuggingFace Chronos-2 zero-shot forecaster."""
|
||||
|
||||
def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
|
||||
from chronos import ChronosPipeline
|
||||
# BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2)
|
||||
# and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5.
|
||||
import torch
|
||||
try:
|
||||
from chronos import BaseChronosPipeline
|
||||
pipeline_cls = BaseChronosPipeline
|
||||
except ImportError:
|
||||
from chronos import ChronosPipeline
|
||||
pipeline_cls = ChronosPipeline
|
||||
|
||||
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info("Loading Chronos pipeline: %s on %s", model_name, self._device)
|
||||
self._pipeline = ChronosPipeline.from_pretrained(
|
||||
model_name,
|
||||
device_map=self._device,
|
||||
torch_dtype=torch.float16 if self._device == "cuda" else torch.float32,
|
||||
)
|
||||
dtype = torch.float16 if self._device == "cuda" else torch.float32
|
||||
logger.info("Loading Chronos pipeline: %s on %s (cls=%s)",
|
||||
model_name, self._device, pipeline_cls.__name__)
|
||||
# Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
|
||||
try:
|
||||
self._pipeline = pipeline_cls.from_pretrained(
|
||||
model_name, device_map=self._device, dtype=dtype,
|
||||
)
|
||||
except TypeError:
|
||||
self._pipeline = pipeline_cls.from_pretrained(
|
||||
model_name, device_map=self._device, torch_dtype=dtype,
|
||||
)
|
||||
logger.info("Chronos pipeline loaded.")
|
||||
|
||||
def predict_batch(
|
||||
|
||||
Reference in New Issue
Block a user