fix(signal_v2-phase3b): ChronosBolt predict_quantiles API support
ChronosBoltPipeline.predict() does not accept `context` kwarg; it uses positional-only and is deterministic (no num_samples). Switch to predict_quantiles(context, prediction_length, quantile_levels) which returns (quantiles_tensor, mean_tensor). Implementation: if hasattr(pipeline, "predict_quantiles") → modern quantile branch. Else fall back to legacy sample-based predict (T5). Tests: switch to predict_quantiles mock returning (quantiles, None) with shape [1, 1, 3] for q10/q50/q90 directly. 45/45 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -11,6 +11,8 @@ def mock_pipeline():
|
||||
with patch("chronos.BaseChronosPipeline") as cls:
|
||||
cls.__name__ = "BaseChronosPipeline"
|
||||
instance = MagicMock()
|
||||
# ChronosBolt API: predict_quantiles returns (quantiles_tensor, mean_tensor)
|
||||
# Modern (predict_quantiles) branch will be used since hasattr(MagicMock, "predict_quantiles") is True.
|
||||
cls.from_pretrained.return_value = instance
|
||||
yield instance
|
||||
|
||||
@@ -26,11 +28,16 @@ def _daily_ohlcv(close_seq):
|
||||
"close": c, "volume": 1000} for i, c in enumerate(close_seq)]
|
||||
|
||||
|
||||
def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
|
||||
"""mock pipeline → dict[ticker, ChronosPrediction]. last_close=100, samples=102 → ~+2% return."""
|
||||
def _mk_quantiles_tensor(q10_price: float, q50_price: float, q90_price: float):
|
||||
"""Helper: build predict_quantiles return tensor shape [1, 1, 3]."""
|
||||
import torch
|
||||
samples = np.full((100,), 102.0)
|
||||
mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))
|
||||
return torch.tensor([[[q10_price, q50_price, q90_price]]], dtype=torch.float32)
|
||||
|
||||
|
||||
def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
|
||||
"""mock predict_quantiles → dict[ticker, ChronosPrediction]. last_close=100, q50=102 → median≈+2%."""
|
||||
quantiles = _mk_quantiles_tensor(101.5, 102.0, 102.5) # narrow around 102
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from signal_v2.chronos_predictor import ChronosPredictor, ChronosPrediction
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
@@ -43,11 +50,12 @@ def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
|
||||
|
||||
|
||||
def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
|
||||
"""좁은 distribution → conf ≈ 1."""
|
||||
import torch
|
||||
np.random.seed(42)
|
||||
samples = np.random.normal(102.0, 0.1, 100)
|
||||
mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))
|
||||
"""좁은 distribution (q90-q10 작음, median 0 아님) → conf ≈ 1."""
|
||||
# last_close=100, q10=101.99, q50=102.00, q90=102.01
|
||||
# returns: q10=0.0199, q50=0.02, q90=0.0201
|
||||
# spread = (0.0201 - 0.0199) / max(0.02, 0.001) = 0.0002/0.02 = 0.01 → conf = 1 - 0.005 = 0.995
|
||||
quantiles = _mk_quantiles_tensor(101.99, 102.0, 102.01)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from signal_v2.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
@@ -58,10 +66,11 @@ def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
|
||||
|
||||
def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
|
||||
"""넓은 distribution → conf ≈ 0."""
|
||||
import torch
|
||||
np.random.seed(42)
|
||||
samples = np.random.normal(100.0, 30.0, 100)
|
||||
mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))
|
||||
# last_close=100, q10=70, q50=100, q90=130
|
||||
# returns: q10=-0.3, q50=0.0, q90=0.3
|
||||
# spread = (0.3 - (-0.3)) / max(0.0, 0.001) = 0.6 / 0.001 = 600 → conf = max(0, 1 - 300) = 0
|
||||
quantiles = _mk_quantiles_tensor(70.0, 100.0, 130.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from signal_v2.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
@@ -71,14 +80,13 @@ def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
|
||||
|
||||
|
||||
def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu):
|
||||
"""price 예측 → last_close 대비 return 변환. last_close=100, samples=110 → return ~+10%."""
|
||||
import torch
|
||||
samples = np.full((100,), 110.0)
|
||||
mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1))
|
||||
"""price 예측 → last_close 대비 return 변환. last_close=100, q50=110 → return ≈ +10%."""
|
||||
quantiles = _mk_quantiles_tensor(109.0, 110.0, 111.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from signal_v2.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
# last close in the seq = 100
|
||||
# last close = 100
|
||||
daily = {"005930": _daily_ohlcv(list(range(41, 101)))} # last = 100
|
||||
result = predictor.predict_batch(daily)
|
||||
assert abs(result["005930"].median - 0.10) < 0.001
|
||||
|
||||
Reference in New Issue
Block a user