refactor(web-ai): rename signal_v2→ai_trade, deprecate signal_v1
박재오 결정 2026-05-19 — V2를 정식 명칭 ai_trade로 graduation, V1은 deprecated 마킹 (legacy 디렉토리 이동은 file lock 풀린 후 후속). 변경 사항: - signal_v2/ → ai_trade/ (git mv, import 일괄 sed: signal_v2.x → ai_trade.x) - root start.bat → legacy/start_v1.bat (V1 자동 시작 차단) - ai_trade/start.bat 내부 uvicorn target signal_v2.main → ai_trade.main - signal_v1/DEPRECATED.md 추가 (사용 금지 명시) - CLAUDE.md 디렉토리 표·서버 시작 방식 갱신 - services/ 디렉토리 미래 예정 (Plan-B-Insta 작업 시 신설) ai_trade tests 59/59 PASS 확인. signal_v1/ 디렉토리 자체 이동(legacy/signal_v1/)은 telegram_bot.log + data/news_snapshots.db file lock으로 보류. lock 해제 후 후속 커밋. 후속 작업: Plan-B-Insta (services/insta-render + NAS insta 분할) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
0
ai_trade/tests/__init__.py
Normal file
0
ai_trade/tests/__init__.py
Normal file
BIN
ai_trade/tests/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
ai_trade/tests/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ai_trade/tests/__pycache__/conftest.cpython-312-pytest-9.0.2.pyc
Normal file
BIN
ai_trade/tests/__pycache__/conftest.cpython-312-pytest-9.0.2.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
18
ai_trade/tests/conftest.py
Normal file
18
ai_trade/tests/conftest.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Pytest fixtures for ai_trade tests."""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_dedup_db(tmp_path) -> Path:
|
||||
"""SQLite 단위 테스트용 임시 DB path."""
|
||||
return tmp_path / "test_ai_trade.db"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stock_api():
|
||||
"""respx 로 stock API mock. base_url 은 테스트마다 임의."""
|
||||
with respx.mock(base_url="https://test.stock.local", assert_all_called=False) as mock:
|
||||
yield mock
|
||||
92
ai_trade/tests/test_chronos_predictor.py
Normal file
92
ai_trade/tests/test_chronos_predictor.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for ChronosPredictor (model mock)."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline():
|
||||
"""Mock BaseChronosPipeline.from_pretrained returning a mock pipeline object."""
|
||||
with patch("chronos.BaseChronosPipeline") as cls:
|
||||
cls.__name__ = "BaseChronosPipeline"
|
||||
instance = MagicMock()
|
||||
# ChronosBolt API: predict_quantiles returns (quantiles_tensor, mean_tensor)
|
||||
# Modern (predict_quantiles) branch will be used since hasattr(MagicMock, "predict_quantiles") is True.
|
||||
cls.from_pretrained.return_value = instance
|
||||
yield instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torch_cpu():
|
||||
with patch("torch.cuda.is_available", return_value=False):
|
||||
yield
|
||||
|
||||
|
||||
def _daily_ohlcv(close_seq):
|
||||
return [{"datetime": f"2026-05-{i+1:02d}", "open": c, "high": c, "low": c,
|
||||
"close": c, "volume": 1000} for i, c in enumerate(close_seq)]
|
||||
|
||||
|
||||
def _mk_quantiles_tensor(q10_price: float, q50_price: float, q90_price: float):
|
||||
"""Helper: build predict_quantiles return tensor shape [1, 1, 3]."""
|
||||
import torch
|
||||
return torch.tensor([[[q10_price, q50_price, q90_price]]], dtype=torch.float32)
|
||||
|
||||
|
||||
def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
|
||||
"""mock predict_quantiles → dict[ticker, ChronosPrediction]. last_close=100, q50=102 → median≈+2%."""
|
||||
quantiles = _mk_quantiles_tensor(101.5, 102.0, 102.5) # narrow around 102
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor, ChronosPrediction
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert "005930" in result
|
||||
pred = result["005930"]
|
||||
assert isinstance(pred, ChronosPrediction)
|
||||
assert abs(pred.median - 0.02) < 0.001
|
||||
|
||||
|
||||
def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
|
||||
"""좁은 distribution (q90-q10 작음, median 0 아님) → conf ≈ 1."""
|
||||
# last_close=100, q10=101.99, q50=102.00, q90=102.01
|
||||
# returns: q10=0.0199, q50=0.02, q90=0.0201
|
||||
# spread = (0.0201 - 0.0199) / max(0.02, 0.001) = 0.0002/0.02 = 0.01 → conf = 1 - 0.005 = 0.995
|
||||
quantiles = _mk_quantiles_tensor(101.99, 102.0, 102.01)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert result["005930"].conf > 0.8
|
||||
|
||||
|
||||
def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
|
||||
"""넓은 distribution → conf ≈ 0."""
|
||||
# last_close=100, q10=70, q50=100, q90=130
|
||||
# returns: q10=-0.3, q50=0.0, q90=0.3
|
||||
# spread = (0.3 - (-0.3)) / max(0.0, 0.001) = 0.6 / 0.001 = 600 → conf = max(0, 1 - 300) = 0
|
||||
quantiles = _mk_quantiles_tensor(70.0, 100.0, 130.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert result["005930"].conf < 0.3
|
||||
|
||||
|
||||
def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu):
|
||||
"""price 예측 → last_close 대비 return 변환. last_close=100, q50=110 → return ≈ +10%."""
|
||||
quantiles = _mk_quantiles_tensor(109.0, 110.0, 111.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
# last close = 100
|
||||
daily = {"005930": _daily_ohlcv(list(range(41, 101)))} # last = 100
|
||||
result = predictor.predict_batch(daily)
|
||||
assert abs(result["005930"].median - 0.10) < 0.001
|
||||
161
ai_trade/tests/test_kis_client.py
Normal file
161
ai_trade/tests/test_kis_client.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Tests for KISClient (REST)."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from ai_trade.kis_client import KISClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_v1_token(tmp_path):
|
||||
"""V1 토큰 파일 fixture."""
|
||||
token_file = tmp_path / "kis_token.json"
|
||||
token_file.write_text(json.dumps({
|
||||
"access_token": "test-kis-token-abc123",
|
||||
"token_expired": "2099-12-31 23:59:59",
|
||||
}))
|
||||
return token_file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kis_client_factory(fake_v1_token):
|
||||
def _make():
|
||||
return KISClient(
|
||||
app_key="test-app-key",
|
||||
app_secret="test-app-secret",
|
||||
account="50000000-01",
|
||||
is_virtual=True,
|
||||
v1_token_path=fake_v1_token,
|
||||
)
|
||||
return _make
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_minute_ohlcv_normal_returns_30_bars(kis_client_factory):
|
||||
"""정상 200 → 30개 분봉 list 반환."""
|
||||
sample_output2 = [
|
||||
{
|
||||
"stck_bsop_date": "20260518",
|
||||
"stck_cntg_hour": f"09{m:02d}00",
|
||||
"stck_oprc": "78000", "stck_hgpr": "78500",
|
||||
"stck_lwpr": "77800", "stck_prpr": "78300",
|
||||
"cntg_vol": "12345",
|
||||
}
|
||||
for m in range(30) # 9:00-9:29 = 30 bars
|
||||
]
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||
).mock(
|
||||
return_value=httpx.Response(200, json={"output2": sample_output2})
|
||||
)
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
bars = await client.get_minute_ohlcv("005930")
|
||||
assert len(bars) == 30
|
||||
assert bars[0]["close"] == 78300
|
||||
assert "datetime" in bars[0]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_minute_ohlcv_429_retry_then_success(kis_client_factory, monkeypatch):
|
||||
"""429 → exponential backoff → 200."""
|
||||
sleep_calls = []
|
||||
async def fake_sleep(s): sleep_calls.append(s)
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||
).mock(side_effect=[
|
||||
httpx.Response(429, text="rate limit"),
|
||||
httpx.Response(200, json={"output2": []}),
|
||||
])
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
result = await client.get_minute_ohlcv("005930")
|
||||
assert result == []
|
||||
assert 1 in sleep_calls
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_minute_ohlcv_uses_v1_token(kis_client_factory, fake_v1_token):
|
||||
"""KIS 호출 헤더에 V1 토큰 파일의 access_token 사용."""
|
||||
route = respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||
).mock(return_value=httpx.Response(200, json={"output2": []}))
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
await client.get_minute_ohlcv("005930")
|
||||
assert route.called
|
||||
req = route.calls.last.request
|
||||
# check authorization header contains the V1 token
|
||||
auth = req.headers.get("authorization", "")
|
||||
assert "test-kis-token-abc123" in auth
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_asking_price_computes_bid_ratio(kis_client_factory):
|
||||
"""호가 응답 → bid_total/(bid+ask) bid_ratio 계산."""
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
|
||||
).mock(return_value=httpx.Response(200, json={
|
||||
"output1": {
|
||||
"total_bidp_rsqn": "600",
|
||||
"total_askp_rsqn": "400",
|
||||
"stck_prpr": "78500",
|
||||
}
|
||||
}))
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
data = await client.get_asking_price("005930")
|
||||
assert data["bid_total"] == 600
|
||||
assert data["ask_total"] == 400
|
||||
assert abs(data["bid_ratio"] - 0.6) < 1e-9
|
||||
assert data["current_price"] == 78500
|
||||
assert "as_of" in data
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_daily_ohlcv_returns_60_bars(kis_client_factory):
|
||||
"""KIS daily endpoint returns 60 ascending bars after parsing."""
|
||||
# Build 60 KIS-format daily bars (descending dates as KIS does)
|
||||
sample_output2 = []
|
||||
for i in range(60):
|
||||
# Generate a fake date 60 days ago, descending
|
||||
day = 60 - i
|
||||
sample_output2.append({
|
||||
"stck_bsop_date": f"2026{(((day-1)//30)+1):02d}{(((day-1)%30)+1):02d}",
|
||||
"stck_oprc": "78000", "stck_hgpr": "78500",
|
||||
"stck_lwpr": "77800", "stck_clpr": str(78000 + i),
|
||||
"acml_vol": "12345",
|
||||
})
|
||||
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
|
||||
).mock(return_value=httpx.Response(200, json={"output2": sample_output2}))
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
bars = await client.get_daily_ohlcv("005930", days=60)
|
||||
# KIS returns descending; client reverses to ascending
|
||||
assert len(bars) == 60
|
||||
# Ascending order: first item has smaller datetime than last
|
||||
assert bars[0]["datetime"] < bars[-1]["datetime"]
|
||||
assert isinstance(bars[0]["open"], int)
|
||||
assert isinstance(bars[0]["close"], int)
|
||||
assert "datetime" in bars[0]
|
||||
finally:
|
||||
await client.close()
|
||||
94
ai_trade/tests/test_kis_websocket.py
Normal file
94
ai_trade/tests/test_kis_websocket.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Tests for KISWebSocket."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from ai_trade.kis_websocket import KISWebSocket
|
||||
|
||||
|
||||
BASE_REST = "https://openapivts.koreainvestment.com:29443"
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_fetch_approval_key_via_oauth_endpoint():
|
||||
"""POST /oauth2/Approval → approval_key 추출."""
|
||||
respx.post(f"{BASE_REST}/oauth2/Approval").mock(
|
||||
return_value=httpx.Response(200, json={"approval_key": "test-approval-key-xyz"})
|
||||
)
|
||||
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||
key = await ws._fetch_approval_key()
|
||||
assert key == "test-approval-key-xyz"
|
||||
assert ws._approval_key == "test-approval-key-xyz"
|
||||
|
||||
|
||||
async def test_subscribe_sends_h0stasp0_message():
|
||||
"""subscribe() → WebSocket 으로 H0STASP0 구독 메시지 전송."""
|
||||
sent_messages = []
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send = AsyncMock(side_effect=lambda m: sent_messages.append(m))
|
||||
|
||||
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||
ws._approval_key = "test-key"
|
||||
ws._ws = mock_ws
|
||||
await ws.subscribe("005930")
|
||||
assert ws._subscriptions == {"005930"}
|
||||
assert len(sent_messages) == 1
|
||||
msg = json.loads(sent_messages[0])
|
||||
assert msg["header"]["tr_type"] == "1" # subscribe
|
||||
assert msg["body"]["input"]["tr_id"] == "H0STASP0"
|
||||
assert msg["body"]["input"]["tr_key"] == "005930"
|
||||
|
||||
|
||||
def test_parse_asking_price_extracts_bid_ask_totals():
|
||||
"""KIS raw '0|H0STASP0|001|...' → (ticker, dict).
|
||||
|
||||
KIS 호가 메시지 형식 — KIS 공식 spec 의 정확한 필드 인덱스 운영 검증 필요.
|
||||
본 테스트는 implementer 의 _parse_asking_price 구현 인덱스에 맞춰서 sample 작성.
|
||||
"""
|
||||
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||
# Build a sample raw message — implementer 가 _ASKING_TOTAL_BID/ASK 인덱스에
|
||||
# 맞춰서 필드 배치하면 됨. 예: 마지막 2개 필드를 bid_total / ask_total 로.
|
||||
fields = ["005930", "091500", "78500"] # ticker, time, current_price
|
||||
fields.extend(["0"] * 40) # padding (KIS 의 실 필드 수 ~50개)
|
||||
fields.append("400") # ask_total
|
||||
fields.append("600") # bid_total
|
||||
raw = f"0|H0STASP0|001|{'^'.join(fields)}"
|
||||
|
||||
result = ws._parse_asking_price(raw)
|
||||
assert result is not None, "parse_asking_price returned None"
|
||||
ticker, data = result
|
||||
assert ticker == "005930"
|
||||
assert "bid_total" in data
|
||||
assert "ask_total" in data
|
||||
assert "bid_ratio" in data
|
||||
assert "current_price" in data
|
||||
# bid_total=600, ask_total=400, bid_ratio=0.6
|
||||
assert data["bid_total"] == 600
|
||||
assert data["ask_total"] == 400
|
||||
assert abs(data["bid_ratio"] - 0.6) < 1e-9
|
||||
|
||||
|
||||
async def test_reconnect_on_disconnect_with_backoff(monkeypatch):
|
||||
"""연결 끊김 → exponential backoff retry. _connect_with_backoff() 검증."""
|
||||
sleep_calls = []
|
||||
async def fake_sleep(s): sleep_calls.append(s)
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||
# Mock _connect to fail twice then succeed
|
||||
call_count = [0]
|
||||
async def fake_connect():
|
||||
call_count[0] += 1
|
||||
if call_count[0] < 3:
|
||||
raise ConnectionError("fake disconnect")
|
||||
return AsyncMock()
|
||||
monkeypatch.setattr(ws, "_connect", fake_connect)
|
||||
|
||||
result = await ws._connect_with_backoff()
|
||||
assert call_count[0] == 3 # 2 fails + 1 success
|
||||
# exponential 1s, 2s
|
||||
assert sleep_calls[:2] == [1, 2]
|
||||
62
ai_trade/tests/test_main.py
Normal file
62
ai_trade/tests/test_main.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Tests for FastAPI main app."""
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_health_endpoint_returns_status_online(monkeypatch):
|
||||
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||
monkeypatch.setenv("WEBAI_API_KEY", "test-secret")
|
||||
# Reload modules so they pick up the new env
|
||||
import importlib
|
||||
from ai_trade import config as cfg
|
||||
importlib.reload(cfg)
|
||||
from ai_trade import main as main_mod
|
||||
importlib.reload(main_mod)
|
||||
with TestClient(main_mod.app) as client:
|
||||
r = client.get("/health")
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["status"] == "online"
|
||||
assert body["stock_api_url"] == "https://test.stock.local"
|
||||
|
||||
|
||||
def test_startup_warns_if_webai_api_key_missing(monkeypatch, caplog):
|
||||
# Use setenv with empty string + no-op load_dotenv to defeat .env re-read on reload
|
||||
monkeypatch.setattr("ai_trade.config.load_dotenv", lambda *a, **k: None)
|
||||
monkeypatch.setenv("WEBAI_API_KEY", "")
|
||||
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||
import importlib
|
||||
from ai_trade import config as cfg
|
||||
importlib.reload(cfg)
|
||||
# After reload, load_dotenv reference is fresh — re-patch
|
||||
monkeypatch.setattr("ai_trade.config.load_dotenv", lambda *a, **k: None)
|
||||
from ai_trade import main as main_mod
|
||||
importlib.reload(main_mod)
|
||||
with caplog.at_level(logging.WARNING, logger="ai_trade.main"):
|
||||
with TestClient(main_mod.app) as client:
|
||||
client.get("/health")
|
||||
assert any("WEBAI_API_KEY" in rec.message for rec in caplog.records)
|
||||
|
||||
|
||||
def test_startup_warns_if_kis_app_key_missing(monkeypatch, caplog):
|
||||
"""KIS app_key 미설정 시 startup WARNING (KIS 호출 disabled) — V1 패턴."""
|
||||
monkeypatch.setattr("ai_trade.config.load_dotenv", lambda *a, **k: None)
|
||||
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||
monkeypatch.setenv("WEBAI_API_KEY", "test-secret")
|
||||
# V1 pattern: kis_env_type=virtual, both virtual keys empty
|
||||
monkeypatch.setenv("KIS_ENV_TYPE", "virtual")
|
||||
monkeypatch.setenv("KIS_VIRTUAL_APP_KEY", "")
|
||||
monkeypatch.setenv("KIS_REAL_APP_KEY", "")
|
||||
|
||||
import importlib
|
||||
from ai_trade import config as cfg
|
||||
importlib.reload(cfg)
|
||||
monkeypatch.setattr("ai_trade.config.load_dotenv", lambda *a, **k: None)
|
||||
from ai_trade import main as main_mod
|
||||
importlib.reload(main_mod)
|
||||
with caplog.at_level(logging.WARNING, logger="ai_trade.main"):
|
||||
with TestClient(main_mod.app) as client:
|
||||
client.get("/health")
|
||||
assert any("KIS" in rec.message and "app_key" in rec.message.lower() for rec in caplog.records)
|
||||
92
ai_trade/tests/test_momentum_classifier.py
Normal file
92
ai_trade/tests/test_momentum_classifier.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for minute momentum classifier."""
|
||||
from collections import deque
|
||||
|
||||
from ai_trade.momentum_classifier import (
|
||||
aggregate_1min_to_5min, classify_minute_momentum,
|
||||
STRONG_UP, WEAK_UP, NEUTRAL, WEAK_DOWN, STRONG_DOWN,
|
||||
)
|
||||
|
||||
|
||||
def _bar(open_, high, low, close, volume):
|
||||
return {
|
||||
"datetime": "2026-05-18T09:00:00+09:00",
|
||||
"open": open_, "high": high, "low": low, "close": close, "volume": volume,
|
||||
}
|
||||
|
||||
|
||||
def _make_chunks(num_chunks_up: int, num_chunks_total: int, base_vol: int = 1000):
|
||||
"""num_chunks_total 개의 5-bar 청크. num_chunks_up 청크는 양봉, 나머지는 음봉.
|
||||
각 청크는 5개 1분봉. 거래량 = base_vol per bar.
|
||||
"""
|
||||
bars = []
|
||||
for i in range(num_chunks_total):
|
||||
is_up = i < num_chunks_up
|
||||
o, c = (100, 110) if is_up else (110, 100)
|
||||
for j in range(5):
|
||||
bars.append(_bar(o, max(o, c) + 5, min(o, c) - 5, c, base_vol))
|
||||
return bars
|
||||
|
||||
|
||||
def test_strong_up_5_consecutive_green_with_high_volume():
|
||||
"""직전 5개 5분봉 모두 양봉 + 거래량 1.5x → STRONG_UP."""
|
||||
# 60분 (12 5분봉) 데이터: 7 normal + 5 high-vol up
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
recent = _make_chunks(num_chunks_up=5, num_chunks_total=5, base_vol=2500)
|
||||
minute_bars = deque(older + recent, maxlen=60)
|
||||
assert classify_minute_momentum(minute_bars) == STRONG_UP
|
||||
|
||||
|
||||
def test_weak_up_3of5_green_normal_volume():
|
||||
"""직전 5개 5분봉 중 3-4개 양봉 + 거래량 ≥ 1.0x → WEAK_UP."""
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
# 5 chunks: 3 up + 2 down, normal vol
|
||||
recent_up = _make_chunks(num_chunks_up=3, num_chunks_total=3, base_vol=1000)
|
||||
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=2, base_vol=1000)
|
||||
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||
assert classify_minute_momentum(minute_bars) == WEAK_UP
|
||||
|
||||
|
||||
def test_neutral_mixed():
|
||||
"""up_count=2, vol normal → NEUTRAL (rule 미해당)."""
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
recent_up = _make_chunks(num_chunks_up=2, num_chunks_total=2, base_vol=1000)
|
||||
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=3, base_vol=1000)
|
||||
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||
# up_count=2, vol_mult=1.0 → 어느 분기 조건도 만족 안 함 → NEUTRAL
|
||||
assert classify_minute_momentum(minute_bars) == NEUTRAL
|
||||
|
||||
|
||||
def test_weak_down_low_green_low_volume():
|
||||
"""up_count <= 2 + vol < 1.0 → WEAK_DOWN."""
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
recent_up = _make_chunks(num_chunks_up=1, num_chunks_total=1, base_vol=500)
|
||||
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=4, base_vol=500)
|
||||
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||
# recent 5 chunks avg vol = 500, long 12 avg ≈ (7*1000 + 5*500) / 12 ≈ 791 → vol_mult ≈ 0.63
|
||||
assert classify_minute_momentum(minute_bars) == WEAK_DOWN
|
||||
|
||||
|
||||
def test_strong_down_5_consecutive_red_high_volume():
|
||||
"""직전 5개 5분봉 모두 음봉 + 거래량 1.5x → STRONG_DOWN."""
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
recent = _make_chunks(num_chunks_up=0, num_chunks_total=5, base_vol=2500)
|
||||
minute_bars = deque(older + recent, maxlen=60)
|
||||
assert classify_minute_momentum(minute_bars) == STRONG_DOWN
|
||||
|
||||
|
||||
def test_aggregate_1min_to_5min_correctness():
|
||||
"""5 1분봉 → 1개 5분봉 — open/close/high/low/volume 정확."""
|
||||
bars = [
|
||||
_bar(100, 105, 99, 102, 1000),
|
||||
_bar(102, 108, 101, 107, 1500),
|
||||
_bar(107, 110, 105, 106, 800),
|
||||
_bar(106, 109, 104, 108, 1200),
|
||||
_bar(108, 112, 107, 111, 900),
|
||||
]
|
||||
result = aggregate_1min_to_5min(bars)
|
||||
assert len(result) == 1
|
||||
assert result[0]["open"] == 100 # 첫 bar
|
||||
assert result[0]["close"] == 111 # 마지막 bar
|
||||
assert result[0]["high"] == 112 # max
|
||||
assert result[0]["low"] == 99 # min
|
||||
assert result[0]["volume"] == 5400 # sum
|
||||
131
ai_trade/tests/test_pull_worker.py
Normal file
131
ai_trade/tests/test_pull_worker.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Tests for pull_worker (Phase 3a additions)."""
|
||||
from collections import deque
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from ai_trade.state import PollState
|
||||
|
||||
|
||||
async def test_minute_polling_cycle_updates_state_minute_bars():
|
||||
"""KIS REST mock 의 분봉 데이터가 state.minute_bars[ticker] deque 에 들어간다."""
|
||||
from ai_trade.pull_worker import _run_kis_minute_cycle
|
||||
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{"ticker": "005930"}, {"ticker": "000660"}]}
|
||||
state.screener_preview = {
|
||||
"items": [{"ticker": "005930"}, {"ticker": "035720"}]
|
||||
}
|
||||
|
||||
kis_client_mock = MagicMock()
|
||||
kis_client_mock.get_minute_ohlcv = AsyncMock(side_effect=[
|
||||
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 78000,
|
||||
"high": 78500, "low": 77900, "close": 78300, "volume": 12345}],
|
||||
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 180000,
|
||||
"high": 181000, "low": 179800, "close": 180500, "volume": 5000}],
|
||||
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 51000,
|
||||
"high": 51200, "low": 50800, "close": 51100, "volume": 8000}],
|
||||
])
|
||||
kis_client_mock.get_asking_price = AsyncMock(return_value={
|
||||
"bid_total": 600, "ask_total": 400, "bid_ratio": 0.6,
|
||||
"current_price": 51100, "as_of": "2026-05-18T09:00:30+09:00",
|
||||
})
|
||||
|
||||
await _run_kis_minute_cycle(kis_client_mock, state)
|
||||
|
||||
# 3 unique tickers (005930, 000660, 035720)
|
||||
assert "005930" in state.minute_bars
|
||||
assert "000660" in state.minute_bars
|
||||
assert "035720" in state.minute_bars
|
||||
assert len(state.minute_bars["005930"]) >= 1
|
||||
# asking_price 만 screener-only ticker (035720) 에 들어가야 함
|
||||
# (portfolio = 005930, 000660 는 WebSocket 으로 들어옴)
|
||||
assert "035720" in state.asking_price
|
||||
|
||||
|
||||
def test_websocket_message_updates_state_asking_price():
|
||||
"""WebSocket callback factory → state.asking_price 갱신."""
|
||||
from ai_trade.pull_worker import make_asking_price_callback
|
||||
|
||||
state = PollState()
|
||||
cb = make_asking_price_callback(state)
|
||||
cb("005930", {"bid_total": 1000, "ask_total": 800, "bid_ratio": 0.555,
|
||||
"current_price": 78500, "as_of": "2026-05-18T10:00:00+09:00"})
|
||||
assert state.asking_price["005930"]["bid_total"] == 1000
|
||||
assert "asking_price/005930" in state.last_updated
|
||||
|
||||
|
||||
async def test_post_close_cycle_updates_chronos_predictions():
|
||||
"""mock kis + mock chronos → state.chronos_predictions + state.daily_ohlcv 갱신."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from ai_trade.pull_worker import _run_post_close_cycle
|
||||
from ai_trade.chronos_predictor import ChronosPrediction
|
||||
from ai_trade.state import PollState
|
||||
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{"ticker": "005930"}]}
|
||||
state.screener_preview = {"items": [{"ticker": "000660"}]}
|
||||
|
||||
kis_mock = MagicMock()
|
||||
daily_005930 = [{"datetime": f"2026-05-{i+1:02d}", "open": 100, "high": 105,
|
||||
"low": 95, "close": 100 + i, "volume": 1000} for i in range(60)]
|
||||
daily_000660 = [{"datetime": f"2026-05-{i+1:02d}", "open": 200, "high": 210,
|
||||
"low": 190, "close": 200 + i, "volume": 2000} for i in range(60)]
|
||||
# _run_post_close_cycle iterates tickers and calls get_daily_ohlcv per ticker.
|
||||
# Order depends on set() so use side_effect mapping if possible, otherwise list.
|
||||
async def fake_daily(ticker, days=60):
|
||||
if ticker == "005930":
|
||||
return daily_005930
|
||||
if ticker == "000660":
|
||||
return daily_000660
|
||||
return []
|
||||
kis_mock.get_daily_ohlcv = AsyncMock(side_effect=fake_daily)
|
||||
|
||||
chronos_mock = MagicMock()
|
||||
chronos_mock.predict_batch = MagicMock(return_value={
|
||||
"005930": ChronosPrediction(0.02, -0.01, 0.04, 0.85, "2026-05-18T16:00:00+09:00"),
|
||||
"000660": ChronosPrediction(0.03, -0.02, 0.06, 0.75, "2026-05-18T16:00:00+09:00"),
|
||||
})
|
||||
|
||||
await _run_post_close_cycle(kis_mock, chronos_mock, state)
|
||||
|
||||
assert "005930" in state.chronos_predictions
|
||||
assert "000660" in state.chronos_predictions
|
||||
assert state.chronos_predictions["005930"]["median"] == 0.02
|
||||
assert state.chronos_predictions["005930"]["conf"] == 0.85
|
||||
assert "005930" in state.daily_ohlcv
|
||||
assert "chronos/005930" in state.last_updated
|
||||
|
||||
|
||||
def test_poll_loop_calls_generate_signals_after_cycle(monkeypatch):
|
||||
"""Phase 4: generate_signals 가 cycle 후 state.signals 를 갱신한다."""
|
||||
from unittest.mock import MagicMock
|
||||
from ai_trade.state import PollState
|
||||
from ai_trade.signal_generator import generate_signals
|
||||
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{
|
||||
"ticker": "005930", "name": "삼성전자",
|
||||
"avg_price": 75000, "current_price": 69000,
|
||||
"pnl_pct": -0.08, "profit_rate": -8.0,
|
||||
"quantity": 100, "broker": "키움",
|
||||
}]}
|
||||
state.screener_preview = {"items": []}
|
||||
|
||||
dedup = MagicMock()
|
||||
dedup.is_recent.return_value = False
|
||||
|
||||
settings = MagicMock()
|
||||
settings.stop_loss_pct = -0.07
|
||||
settings.take_profit_pct = 0.15
|
||||
settings.chronos_spread_threshold = 0.6
|
||||
settings.asking_bid_ratio_threshold = 0.6
|
||||
settings.confidence_threshold = 0.7
|
||||
settings.min_momentum_for_buy = "strong_up"
|
||||
|
||||
generate_signals(state, dedup, settings)
|
||||
|
||||
assert "005930" in state.signals
|
||||
assert state.signals["005930"]["action"] == "sell"
|
||||
assert state.signals["005930"]["confidence_webai"] == 1.0
|
||||
dedup.record.assert_called_with("005930", "sell", confidence=1.0)
|
||||
34
ai_trade/tests/test_rate_limit.py
Normal file
34
ai_trade/tests/test_rate_limit.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Tests for SignalDedup."""
|
||||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from ai_trade.rate_limit import SignalDedup
|
||||
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
|
||||
def test_is_recent_returns_false_for_new_ticker_action(tmp_dedup_db):
|
||||
dedup = SignalDedup(tmp_dedup_db)
|
||||
assert dedup.is_recent("005930", "buy") is False
|
||||
|
||||
|
||||
def test_is_recent_returns_true_within_24h(tmp_dedup_db):
|
||||
dedup = SignalDedup(tmp_dedup_db)
|
||||
dedup.record("005930", "buy", confidence=0.82)
|
||||
assert dedup.is_recent("005930", "buy") is True
|
||||
|
||||
|
||||
def test_is_recent_returns_false_after_24h(tmp_dedup_db, monkeypatch):
|
||||
dedup = SignalDedup(tmp_dedup_db)
|
||||
# Record with a timestamp 25 hours ago
|
||||
now = datetime.now(KST)
|
||||
fake_now = now - timedelta(hours=25)
|
||||
monkeypatch.setattr(
|
||||
"ai_trade.rate_limit._now_iso", lambda: fake_now.isoformat()
|
||||
)
|
||||
dedup.record("005930", "buy", confidence=0.82)
|
||||
# Reset to real now for is_recent check
|
||||
monkeypatch.setattr(
|
||||
"ai_trade.rate_limit._now_iso", lambda: now.isoformat()
|
||||
)
|
||||
assert dedup.is_recent("005930", "buy", within_hours=24) is False
|
||||
81
ai_trade/tests/test_scheduler.py
Normal file
81
ai_trade/tests/test_scheduler.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for scheduler interval logic."""
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from ai_trade.scheduler import _next_interval, _is_market_day, KST
|
||||
|
||||
|
||||
def _kst(year, month, day, hour, minute=0):
|
||||
return datetime(year, month, day, hour, minute, tzinfo=KST)
|
||||
|
||||
|
||||
def test_next_interval_pre_market_5min():
|
||||
now = _kst(2026, 5, 18, 8, 30) # Monday 08:30
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_market_open_1min():
|
||||
now = _kst(2026, 5, 18, 10, 0) # Monday 10:00
|
||||
assert _next_interval(now) == 60
|
||||
|
||||
|
||||
def test_next_interval_post_market_5min():
|
||||
now = _kst(2026, 5, 18, 17, 0) # Monday 17:00
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_overnight_skip_to_next_morning():
|
||||
now = _kst(2026, 5, 18, 2, 30) # Monday 02:30 (dead zone, not NXT window)
|
||||
interval = _next_interval(now)
|
||||
# Dead zone 23:30-04:30 → next 04:30 is ~2h away
|
||||
assert 2 * 3600 - 60 < interval < 2 * 3600 + 60
|
||||
|
||||
|
||||
def test_next_interval_holiday_skip():
|
||||
# 2026-05-05 어린이날 (Tuesday holiday)
|
||||
now = _kst(2026, 5, 5, 10, 0)
|
||||
assert _is_market_day(now) is False
|
||||
interval = _next_interval(now)
|
||||
# Next: 2026-05-06 (Wed) 07:00, ~21h away
|
||||
assert 20 * 3600 < interval < 22 * 3600
|
||||
|
||||
|
||||
def test_next_interval_at_market_open_boundary():
|
||||
"""09:00:00 정확 second → 60초 (market 구간 진입)."""
|
||||
now = _kst(2026, 5, 18, 9, 0) # Monday 09:00:00
|
||||
assert _next_interval(now) == 60
|
||||
|
||||
|
||||
def test_next_interval_at_market_close_boundary():
|
||||
"""15:30:00 정확 second → 300초 (post-market 구간 진입)."""
|
||||
now = _kst(2026, 5, 18, 15, 30) # Monday 15:30:00
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_at_polling_window_end_boundary():
|
||||
"""23:30:00 정확 second → dead zone skip (다음 04:30 까지)."""
|
||||
now = _kst(2026, 5, 18, 23, 30) # Monday 23:30:00 (NXT_PRE_END boundary)
|
||||
interval = _next_interval(now)
|
||||
# Dead zone 23:30-04:30 → next 04:30 is ~5h away
|
||||
assert 5 * 3600 - 60 < interval < 5 * 3600 + 60
|
||||
|
||||
|
||||
def test_next_interval_nxt_evening_5min():
|
||||
"""22:00 평일 (NXT 야간) → 300 (5분)."""
|
||||
now = _kst(2026, 5, 18, 22, 0)
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_nxt_dawn_5min():
|
||||
"""05:30 평일 (NXT 새벽) → 300 (5분)."""
|
||||
now = _kst(2026, 5, 18, 5, 30)
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_dead_zone_skip():
|
||||
"""02:00 평일 (dead zone 23:30-04:30) → 다음 04:30 까지 (~9000s)."""
|
||||
now = _kst(2026, 5, 18, 2, 0)
|
||||
interval = _next_interval(now)
|
||||
# 02:00 → 04:30 = 2.5h = 9000s
|
||||
assert 9000 - 60 < interval < 9000 + 60
|
||||
172
ai_trade/tests/test_signal_generator.py
Normal file
172
ai_trade/tests/test_signal_generator.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Tests for signal_generator."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from ai_trade.signal_generator import generate_signals
|
||||
from ai_trade.state import PollState
|
||||
|
||||
|
||||
def _settings(**overrides):
|
||||
"""Build a Settings-like object for tests (avoid env)."""
|
||||
defaults = dict(
|
||||
stop_loss_pct=-0.07,
|
||||
take_profit_pct=0.15,
|
||||
chronos_spread_threshold=0.6,
|
||||
asking_bid_ratio_threshold=0.6,
|
||||
confidence_threshold=0.7,
|
||||
min_momentum_for_buy="strong_up",
|
||||
)
|
||||
defaults.update(overrides)
|
||||
m = MagicMock()
|
||||
for k, v in defaults.items():
|
||||
setattr(m, k, v)
|
||||
return m
|
||||
|
||||
|
||||
def _make_state_with_buy_candidate(
|
||||
ticker="005930", name="삼성전자",
|
||||
chronos_median=0.02, chronos_q10=-0.01, chronos_q90=0.04, chronos_conf=0.85,
|
||||
momentum="strong_up", bid_ratio=0.7, current_price=78500,
|
||||
):
|
||||
state = PollState()
|
||||
state.screener_preview = {"items": [{"ticker": ticker, "name": name}]}
|
||||
state.chronos_predictions[ticker] = {
|
||||
"median": chronos_median, "q10": chronos_q10, "q90": chronos_q90,
|
||||
"conf": chronos_conf, "as_of": "2026-05-17T16:00:00+09:00",
|
||||
}
|
||||
state.minute_momentum[ticker] = momentum
|
||||
state.asking_price[ticker] = {
|
||||
"bid_total": int(bid_ratio * 1000),
|
||||
"ask_total": int((1 - bid_ratio) * 1000),
|
||||
"bid_ratio": bid_ratio,
|
||||
"current_price": current_price,
|
||||
"as_of": "2026-05-17T16:00:01+09:00",
|
||||
}
|
||||
return state
|
||||
|
||||
|
||||
def _make_state_with_holding(
|
||||
ticker="005930", name="삼성전자",
|
||||
pnl_pct=0.0, avg_price=75000, current_price=75000,
|
||||
):
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{
|
||||
"ticker": ticker, "name": name,
|
||||
"avg_price": avg_price, "current_price": current_price,
|
||||
"pnl_pct": pnl_pct, "profit_rate": pnl_pct * 100,
|
||||
"quantity": 100, "broker": "키움",
|
||||
}]}
|
||||
state.screener_preview = {"items": []}
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dedup_mock():
|
||||
d = MagicMock()
|
||||
d.is_recent.return_value = False
|
||||
return d
|
||||
|
||||
|
||||
def test_buy_signal_when_all_conditions_pass_and_confidence_high(dedup_mock):
|
||||
state = _make_state_with_buy_candidate()
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "buy"
|
||||
assert sig["confidence_webai"] > 0.7
|
||||
dedup_mock.record.assert_called()
|
||||
|
||||
|
||||
def test_silent_when_chronos_median_negative(dedup_mock):
|
||||
state = _make_state_with_buy_candidate(chronos_median=-0.01)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_silent_when_distribution_spread_too_wide(dedup_mock):
|
||||
# spread = q90 - q10 = 0.5 - (-0.5) = 1.0 > 0.6 → hard gate fails
|
||||
state = _make_state_with_buy_candidate(
|
||||
chronos_median=0.001, chronos_q10=-0.5, chronos_q90=0.5,
|
||||
)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_silent_when_momentum_not_strong_up(dedup_mock):
|
||||
state = _make_state_with_buy_candidate(momentum="weak_up")
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_silent_when_bid_ratio_below_threshold(dedup_mock):
|
||||
state = _make_state_with_buy_candidate(bid_ratio=0.5)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_silent_when_confidence_below_threshold(dedup_mock):
|
||||
# chronos_conf low + rank=20 → confidence < 0.7
|
||||
state = _make_state_with_buy_candidate(chronos_conf=0.3)
|
||||
# add 19 fake items to push 005930 rank to 20
|
||||
state.screener_preview["items"] = (
|
||||
[{"ticker": f"FAKE{i:03d}"} for i in range(19)]
|
||||
+ [{"ticker": "005930", "name": "삼성전자"}]
|
||||
)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
# confidence_webai = 0.3*0.5 + 1.0*0.3 + 0.05*0.2 = 0.46 < 0.7
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_sell_signal_when_stop_loss_triggered(dedup_mock):
|
||||
state = _make_state_with_holding(pnl_pct=-0.08, current_price=69000, avg_price=75000)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "sell"
|
||||
assert sig["confidence_webai"] == 1.0
|
||||
assert sig["pnl_pct"] == -0.08
|
||||
|
||||
|
||||
def test_sell_signal_when_take_profit_triggered(dedup_mock):
|
||||
state = _make_state_with_holding(pnl_pct=0.16, current_price=87000, avg_price=75000)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "sell"
|
||||
assert sig["confidence_webai"] == 0.6
|
||||
|
||||
|
||||
def test_silent_when_dedup_recently_sent(dedup_mock):
|
||||
state = _make_state_with_buy_candidate()
|
||||
dedup_mock.is_recent.return_value = True
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
dedup_mock.record.assert_not_called()
|
||||
|
||||
|
||||
def test_sell_signal_triggers_on_anomaly_path(dedup_mock):
|
||||
"""Anomaly sell: median < -1%, momentum strong_down, low bid_ratio, confidence > threshold."""
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{
|
||||
"ticker": "005930", "name": "삼성전자",
|
||||
"avg_price": 75000, "current_price": 70000,
|
||||
"pnl_pct": -0.067, # within stop_loss tolerance (default -0.07): NOT triggering stop_loss
|
||||
"quantity": 100, "broker": "키움",
|
||||
}]}
|
||||
state.screener_preview = {"items": []}
|
||||
state.chronos_predictions["005930"] = {
|
||||
"median": -0.025, "q10": -0.05, "q90": 0.005, "conf": 0.85,
|
||||
}
|
||||
state.minute_momentum["005930"] = "strong_down"
|
||||
state.asking_price["005930"] = {"current_price": 70000, "bid_ratio": 0.30}
|
||||
# bid_ratio 0.30 < (1 - 0.6) = 0.4 → anomaly bid_ratio gate passes
|
||||
# confidence = 0.85*0.5 + 1.0*0.3 + 1.0*0.2 = 0.425 + 0.3 + 0.2 = 0.925 > 0.7
|
||||
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "sell"
|
||||
assert sig["context"]["sell_reason"] == "anomaly"
|
||||
assert sig["confidence_webai"] > 0.7
|
||||
168
ai_trade/tests/test_stock_client.py
Normal file
168
ai_trade/tests/test_stock_client.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Tests for stock_client.StockClient."""
|
||||
import asyncio
|
||||
import logging
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
from ai_trade.stock_client import StockClient
|
||||
|
||||
|
||||
BASE_URL = "https://test.stock.local"
|
||||
API_KEY = "test-secret"
|
||||
|
||||
|
||||
async def test_get_portfolio_normal_returns_dict_with_pnl_pct(mock_stock_api):
|
||||
"""정상 200 응답 + cache 저장."""
|
||||
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"holdings": [{"ticker": "005930", "pnl_pct": 0.047}],
|
||||
"cash": [],
|
||||
"summary": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
result = await client.get_portfolio()
|
||||
assert result["holdings"][0]["pnl_pct"] == 0.047
|
||||
# Cache populated
|
||||
assert len(client._cache) >= 1
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_uses_cache_within_ttl(mock_stock_api):
|
||||
"""180s TTL 내 두번째 호출 = mock 콜 1회."""
|
||||
route = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
return_value=httpx.Response(
|
||||
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||
)
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
await client.get_portfolio()
|
||||
await client.get_portfolio() # second call within TTL
|
||||
assert route.call_count == 1
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_refetches_after_ttl_expiry(mock_stock_api, monkeypatch):
|
||||
"""TTL 만료 후 재호출 = mock 콜 2회. time.monotonic 모킹."""
|
||||
route = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
return_value=httpx.Response(
|
||||
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||
)
|
||||
)
|
||||
# Fake clock: starts at 0, jumps past portfolio TTL (180s) between calls
|
||||
fake_time = [0.0]
|
||||
monkeypatch.setattr(
|
||||
"ai_trade.stock_client.time.monotonic", lambda: fake_time[0]
|
||||
)
|
||||
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
await client.get_portfolio()
|
||||
fake_time[0] = 181.0 # 180s TTL 만료
|
||||
await client.get_portfolio()
|
||||
assert route.call_count == 2
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_retries_3_times_on_timeout(mock_stock_api, monkeypatch):
|
||||
"""timeout 2번 + 200 1번 → 최종 성공. exponential sleep 호출 검증."""
|
||||
sleep_calls = []
|
||||
|
||||
async def fake_sleep(s):
|
||||
sleep_calls.append(s)
|
||||
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
side_effect=[
|
||||
httpx.TimeoutException("timeout 1"),
|
||||
httpx.TimeoutException("timeout 2"),
|
||||
httpx.Response(
|
||||
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||
),
|
||||
]
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
result = await client.get_portfolio()
|
||||
assert result["holdings"] == []
|
||||
assert sleep_calls == [1, 2] # exponential 1s, 2s
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_429_triggers_backoff(mock_stock_api, monkeypatch):
|
||||
"""429 → 1s backoff → 200."""
|
||||
sleep_calls = []
|
||||
|
||||
async def fake_sleep(s):
|
||||
sleep_calls.append(s)
|
||||
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
side_effect=[
|
||||
httpx.Response(429, text="rate limit"),
|
||||
httpx.Response(
|
||||
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||
),
|
||||
]
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
result = await client.get_portfolio()
|
||||
assert result["holdings"] == []
|
||||
assert sleep_calls == [1]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_falls_back_to_stale_on_all_failures(
|
||||
mock_stock_api, monkeypatch, caplog
|
||||
):
|
||||
"""cache 에 이전 성공 응답 + 모든 retry 5xx → stale 반환 + logger.warning."""
|
||||
# No-op sleep for fast test
|
||||
async def fake_sleep(s):
|
||||
return None
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
# Patch time.monotonic BEFORE first call so cached timestamp uses fake clock
|
||||
fake_time = [0.0]
|
||||
monkeypatch.setattr(
|
||||
"ai_trade.stock_client.time.monotonic", lambda: fake_time[0]
|
||||
)
|
||||
|
||||
# First call succeeds
|
||||
route1 = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={"holdings": [{"ticker": "005930"}], "cash": [], "summary": {}},
|
||||
)
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
first = await client.get_portfolio()
|
||||
assert first["holdings"][0]["ticker"] == "005930"
|
||||
|
||||
# Advance fake clock past TTL (180s) so cache is stale
|
||||
fake_time[0] = 181.0
|
||||
|
||||
# Now mock to return 500s persistently
|
||||
route1.mock(return_value=httpx.Response(500, text="server error"))
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="ai_trade.stock_client"):
|
||||
result = await client.get_portfolio()
|
||||
assert result["holdings"][0]["ticker"] == "005930" # stale data returned
|
||||
assert any(
|
||||
"stale" in rec.message.lower() for rec in caplog.records
|
||||
)
|
||||
finally:
|
||||
await client.close()
|
||||
18
ai_trade/tests/test_stock_client_ttl.py
Normal file
18
ai_trade/tests/test_stock_client_ttl.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# tests/test_stock_client_ttl.py
|
||||
"""SP-A1 회귀 — _TTL이 NAS 부담 완화를 위한 값으로 설정되어 있어야 함."""
|
||||
from ai_trade.stock_client import _TTL
|
||||
|
||||
|
||||
def test_portfolio_ttl_is_180s():
|
||||
"""portfolio TTL은 180초 이상 (3분 폴링에서 1회 fetch가 3 폴링 커버)."""
|
||||
assert _TTL["portfolio"] >= 180.0
|
||||
|
||||
|
||||
def test_news_sentiment_ttl_is_600s():
|
||||
"""news-sentiment TTL은 600초 이상 (10분, 뉴스 sentiment는 자주 안 바뀜)."""
|
||||
assert _TTL["news-sentiment"] >= 600.0
|
||||
|
||||
|
||||
def test_screener_preview_ttl_is_300s():
|
||||
"""screener-preview TTL은 300초 이상 (5분, Top-20은 분 단위로 거의 안 바뀜)."""
|
||||
assert _TTL["screener-preview"] >= 300.0
|
||||
Reference in New Issue
Block a user