From 8eefe9d79daec81a6fe6cd0b4bc582e4fcc39d35 Mon Sep 17 00:00:00 2001 From: gahusb Date: Sun, 17 May 2026 09:07:11 +0900 Subject: [PATCH] fix(signal_v2-phase3b): ChronosBolt predict_quantiles API support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ChronosBoltPipeline.predict() does not accept `context` kwarg; it uses positional-only and is deterministic (no num_samples). Switch to predict_quantiles(context, prediction_length, quantile_levels) which returns (quantiles_tensor, mean_tensor). Implementation: if hasattr(pipeline, "predict_quantiles") → modern quantile branch. Else fall back to legacy sample-based predict (T5). Tests: switch to predict_quantiles mock returning (quantiles, None) with shape [1, 1, 3] for q10/q50/q90 directly. 45/45 tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- signal_v2/chronos_predictor.py | 41 ++++++++++++++++++--- signal_v2/tests/test_chronos_predictor.py | 44 +++++++++++++---------- 2 files changed, 62 insertions(+), 23 deletions(-) diff --git a/signal_v2/chronos_predictor.py b/signal_v2/chronos_predictor.py index f754b46..ff74375 100644 --- a/signal_v2/chronos_predictor.py +++ b/signal_v2/chronos_predictor.py @@ -55,7 +55,11 @@ class ChronosPredictor: prediction_length: int = 1, num_samples: int = 100, ) -> dict[str, ChronosPrediction]: - """종목별 1-day return 분포 예측.""" + """종목별 1-day return 분포 예측. + + ChronosBolt / Chronos-2 등 신모델은 predict_quantiles 사용 (deterministic). + Legacy ChronosPipeline (T5) 는 sample-based predict. + """ import torch tickers = list(daily_ohlcv_dict.keys()) @@ -66,16 +70,43 @@ class ChronosPredictor: torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32) for t in tickers ] + now_iso = datetime.now(KST).isoformat() + results: dict[str, ChronosPrediction] = {} + + # Modern API: predict_quantiles (ChronosBolt / Chronos-2) + if hasattr(self._pipeline, "predict_quantiles"): + quantile_levels = [0.1, 0.5, 0.9] + quantiles_tensor, _ = self._pipeline.predict_quantiles( + context=contexts, + prediction_length=prediction_length, + quantile_levels=quantile_levels, + ) + quantiles_np = ( + quantiles_tensor.cpu().numpy() + if hasattr(quantiles_tensor, "cpu") + else np.asarray(quantiles_tensor) + ) + # shape: [num_series, prediction_length, 3] + for i, ticker in enumerate(tickers): + q10_price, q50_price, q90_price = quantiles_np[i, 0, :] + last_close = daily_ohlcv_dict[ticker][-1]["close"] + median = float((q50_price - last_close) / last_close) + q10 = float((q10_price - last_close) / last_close) + q90 = float((q90_price - last_close) / last_close) + spread = (q90 - q10) / max(abs(median), 0.001) + conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0))) + results[ticker] = ChronosPrediction( + median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso, + ) + return results + + # Legacy API: sample-based predict (ChronosPipeline T5) forecasts = self._pipeline.predict( context=contexts, prediction_length=prediction_length, num_samples=num_samples, ) - # Convert to numpy if tensor forecasts_np = forecasts.numpy() if hasattr(forecasts, "numpy") else np.asarray(forecasts) - - now_iso = datetime.now(KST).isoformat() - results: dict[str, ChronosPrediction] = {} for i, ticker in enumerate(tickers): samples = forecasts_np[i, :, 0] last_close = daily_ohlcv_dict[ticker][-1]["close"] diff --git a/signal_v2/tests/test_chronos_predictor.py b/signal_v2/tests/test_chronos_predictor.py index 43e5358..d8b6210 100644 --- a/signal_v2/tests/test_chronos_predictor.py +++ b/signal_v2/tests/test_chronos_predictor.py @@ -11,6 +11,8 @@ def mock_pipeline(): with patch("chronos.BaseChronosPipeline") as cls: cls.__name__ = "BaseChronosPipeline" instance = MagicMock() + # ChronosBolt API: predict_quantiles returns (quantiles_tensor, mean_tensor) + # Modern (predict_quantiles) branch will be used since hasattr(MagicMock, "predict_quantiles") is True. cls.from_pretrained.return_value = instance yield instance @@ -26,11 +28,16 @@ def _daily_ohlcv(close_seq): "close": c, "volume": 1000} for i, c in enumerate(close_seq)] -def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu): - """mock pipeline → dict[ticker, ChronosPrediction]. last_close=100, samples=102 → ~+2% return.""" +def _mk_quantiles_tensor(q10_price: float, q50_price: float, q90_price: float): + """Helper: build predict_quantiles return tensor shape [1, 1, 3].""" import torch - samples = np.full((100,), 102.0) - mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1)) + return torch.tensor([[[q10_price, q50_price, q90_price]]], dtype=torch.float32) + + +def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu): + """mock predict_quantiles → dict[ticker, ChronosPrediction]. last_close=100, q50=102 → median≈+2%.""" + quantiles = _mk_quantiles_tensor(101.5, 102.0, 102.5) # narrow around 102 + mock_pipeline.predict_quantiles.return_value = (quantiles, None) from signal_v2.chronos_predictor import ChronosPredictor, ChronosPrediction predictor = ChronosPredictor(model_name="mock-model") @@ -43,11 +50,12 @@ def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu): def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu): - """좁은 distribution → conf ≈ 1.""" - import torch - np.random.seed(42) - samples = np.random.normal(102.0, 0.1, 100) - mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1)) + """좁은 distribution (q90-q10 작음, median 0 아님) → conf ≈ 1.""" + # last_close=100, q10=101.99, q50=102.00, q90=102.01 + # returns: q10=0.0199, q50=0.02, q90=0.0201 + # spread = (0.0201 - 0.0199) / max(0.02, 0.001) = 0.0002/0.02 = 0.01 → conf = 1 - 0.005 = 0.995 + quantiles = _mk_quantiles_tensor(101.99, 102.0, 102.01) + mock_pipeline.predict_quantiles.return_value = (quantiles, None) from signal_v2.chronos_predictor import ChronosPredictor predictor = ChronosPredictor(model_name="mock-model") @@ -58,10 +66,11 @@ def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu): def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu): """넓은 distribution → conf ≈ 0.""" - import torch - np.random.seed(42) - samples = np.random.normal(100.0, 30.0, 100) - mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1)) + # last_close=100, q10=70, q50=100, q90=130 + # returns: q10=-0.3, q50=0.0, q90=0.3 + # spread = (0.3 - (-0.3)) / max(0.0, 0.001) = 0.6 / 0.001 = 600 → conf = max(0, 1 - 300) = 0 + quantiles = _mk_quantiles_tensor(70.0, 100.0, 130.0) + mock_pipeline.predict_quantiles.return_value = (quantiles, None) from signal_v2.chronos_predictor import ChronosPredictor predictor = ChronosPredictor(model_name="mock-model") @@ -71,14 +80,13 @@ def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu): def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu): - """price 예측 → last_close 대비 return 변환. last_close=100, samples=110 → return ~+10%.""" - import torch - samples = np.full((100,), 110.0) - mock_pipeline.predict.return_value = torch.tensor(samples.reshape(1, 100, 1)) + """price 예측 → last_close 대비 return 변환. last_close=100, q50=110 → return ≈ +10%.""" + quantiles = _mk_quantiles_tensor(109.0, 110.0, 111.0) + mock_pipeline.predict_quantiles.return_value = (quantiles, None) from signal_v2.chronos_predictor import ChronosPredictor predictor = ChronosPredictor(model_name="mock-model") - # last close in the seq = 100 + # last close = 100 daily = {"005930": _daily_ohlcv(list(range(41, 101)))} # last = 100 result = predictor.predict_batch(daily) assert abs(result["005930"].median - 0.10) < 0.001