fix(signal_v2-phase3b): force FP32 + predict_quantiles positional args
ChronosBoltPipeline.predict_quantiles takes `inputs` as its first positional argument, not a `context` keyword. Pass the contexts positionally, with a TypeError fallback to `context=` for older chronos versions.

FP16 caused inf overflow on Korean stock prices (e.g. 280,000원 > FP16 max 65,504). Force FP32 for prices to avoid this; the Chronos model itself handles internal scaling.

Verified end-to-end: 60-day daily fetch → Chronos predict → quantile output. Example 005930: median=-0.59%, q10=-8.9%, q90=+6.4%, conf=0.0 (a low conf is mathematically correct when the median is near zero relative to the distribution width).

45/45 tests still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -35,7 +35,9 @@ class ChronosPredictor:
|
|||||||
pipeline_cls = ChronosPipeline
|
pipeline_cls = ChronosPipeline
|
||||||
|
|
||||||
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
dtype = torch.float16 if self._device == "cuda" else torch.float32
|
# Always use float32 — Korean stock prices (e.g. 280,000원) exceed FP16 max (~65,504)
|
||||||
|
# causing inf in quantile output. FP32 is safe for typical price magnitudes.
|
||||||
|
dtype = torch.float32
|
||||||
logger.info("Loading Chronos pipeline: %s on %s (cls=%s)",
|
logger.info("Loading Chronos pipeline: %s on %s (cls=%s)",
|
||||||
model_name, self._device, pipeline_cls.__name__)
|
model_name, self._device, pipeline_cls.__name__)
|
||||||
# Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
|
# Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
|
||||||
@@ -76,6 +78,14 @@ class ChronosPredictor:
|
|||||||
# Modern API: predict_quantiles (ChronosBolt / Chronos-2)
|
# Modern API: predict_quantiles (ChronosBolt / Chronos-2)
|
||||||
if hasattr(self._pipeline, "predict_quantiles"):
|
if hasattr(self._pipeline, "predict_quantiles"):
|
||||||
quantile_levels = [0.1, 0.5, 0.9]
|
quantile_levels = [0.1, 0.5, 0.9]
|
||||||
|
# ChronosBolt API: positional `inputs` (first arg). Older variants use `context`.
|
||||||
|
try:
|
||||||
|
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
||||||
|
contexts,
|
||||||
|
prediction_length=prediction_length,
|
||||||
|
quantile_levels=quantile_levels,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
||||||
context=contexts,
|
context=contexts,
|
||||||
prediction_length=prediction_length,
|
prediction_length=prediction_length,
|
||||||
|
|||||||
Reference in New Issue
Block a user