Files
ai-trade/ai_trade/tests/test_kis_client.py
gahusb 39adfc5fc5 fix(ai_trade): KIS throttle을 asyncio.Lock으로 직렬화 (F2)
코드 리뷰 F2: pull_worker.py가 asyncio.gather로 종목별 분봉/호가를 동시 호출하는데
_throttle()이 lock 없이 _last_throttle_at만 갱신해 race condition. 여러 coroutine이
같은 elapsed 계산 후 동시에 깨어나 KIS 초당 2회 한도(EGW00201) 위반 위험.

테스트로 5 concurrent gather 측정: 수정 전 0.51s → 수정 후 2.0s+ 직렬화 확인.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 19:32:50 +09:00

191 lines
6.3 KiB
Python

"""Tests for KISClient (REST)."""
import asyncio
import json
import time as time_module
from pathlib import Path
import httpx
import pytest
import respx
from ai_trade.kis_client import KISClient
@pytest.fixture
def fake_v1_token(tmp_path):
    """Write a fake V1 token file under ``tmp_path`` and return its path.

    The JSON payload holds a fixed ``access_token`` and a ``token_expired``
    value of ``2099-12-31 23:59:59``.
    """
    payload = {
        "access_token": "test-kis-token-abc123",
        "token_expired": "2099-12-31 23:59:59",
    }
    path = tmp_path / "kis_token.json"
    path.write_text(json.dumps(payload))
    return path
@pytest.fixture
def kis_client_factory(fake_v1_token):
    """Return a zero-argument callable that builds a new ``KISClient``.

    Each client is constructed with dummy credentials, ``is_virtual=True``,
    and ``v1_token_path`` set to the file produced by ``fake_v1_token``.
    """
    settings = {
        "app_key": "test-app-key",
        "app_secret": "test-app-secret",
        "account": "50000000-01",
        "is_virtual": True,
        "v1_token_path": fake_v1_token,
    }

    def _build():
        return KISClient(**settings)

    return _build
@respx.mock
async def test_get_minute_ohlcv_normal_returns_30_bars(kis_client_factory):
    """A 200 response carrying 30 rows yields a list of 30 minute bars."""
    url = "https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"

    def _row(minute):
        # One raw row as sent in ``output2``; only the time field varies.
        return {
            "stck_bsop_date": "20260518",
            "stck_cntg_hour": f"09{minute:02d}00",
            "stck_oprc": "78000", "stck_hgpr": "78500",
            "stck_lwpr": "77800", "stck_prpr": "78300",
            "cntg_vol": "12345",
        }

    # 09:00 through 09:29 -> 30 rows.
    rows = [_row(minute) for minute in range(30)]
    respx.get(url).mock(
        return_value=httpx.Response(200, json={"output2": rows})
    )

    client = kis_client_factory()
    try:
        bars = await client.get_minute_ohlcv("005930")
        assert len(bars) == 30
        first = bars[0]
        assert first["close"] == 78300
        assert "datetime" in first
    finally:
        await client.close()
@respx.mock
async def test_get_minute_ohlcv_429_retry_then_success(kis_client_factory, monkeypatch):
    """A 429 followed by a 200 returns the 200 payload after a recorded sleep of 1.

    ``asyncio.sleep`` is swapped for a recorder, so the test only observes the
    requested delays instead of actually waiting.
    """
    recorded_delays = []

    async def _record_sleep(seconds):
        recorded_delays.append(seconds)

    monkeypatch.setattr("asyncio.sleep", _record_sleep)

    url = "https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
    # First attempt is rate limited, the second one succeeds with no rows.
    responses = [
        httpx.Response(429, text="rate limit"),
        httpx.Response(200, json={"output2": []}),
    ]
    respx.get(url).mock(side_effect=responses)

    client = kis_client_factory()
    try:
        result = await client.get_minute_ohlcv("005930")
        assert result == []
        assert 1 in recorded_delays
    finally:
        await client.close()
@respx.mock
async def test_get_minute_ohlcv_uses_v1_token(kis_client_factory, fake_v1_token):
    """The outgoing request's authorization header carries the V1 token file's access_token."""
    url = "https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
    empty_ok = httpx.Response(200, json={"output2": []})
    endpoint = respx.get(url).mock(return_value=empty_ok)

    client = kis_client_factory()
    try:
        await client.get_minute_ohlcv("005930")
        assert endpoint.called
        sent_headers = endpoint.calls.last.request.headers
        # Default to "" so a missing header fails the assertion instead of raising.
        assert "test-kis-token-abc123" in sent_headers.get("authorization", "")
    finally:
        await client.close()
@respx.mock
async def test_get_asking_price_computes_bid_ratio(kis_client_factory):
    """An asking-price response is parsed and bid_ratio = bid_total / (bid_total + ask_total)."""
    url = "https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
    quote_payload = {
        "output1": {
            "total_bidp_rsqn": "600",
            "total_askp_rsqn": "400",
            "stck_prpr": "78500",
        }
    }
    respx.get(url).mock(return_value=httpx.Response(200, json=quote_payload))

    client = kis_client_factory()
    try:
        quote = await client.get_asking_price("005930")
        assert quote["bid_total"] == 600
        assert quote["ask_total"] == 400
        # 600 / (600 + 400) = 0.6, compared with a small absolute tolerance.
        assert abs(quote["bid_ratio"] - 0.6) < 1e-9
        assert quote["current_price"] == 78500
        assert "as_of" in quote
    finally:
        await client.close()
@respx.mock
async def test_get_daily_ohlcv_returns_60_bars(kis_client_factory):
    """The daily endpoint's 60 descending rows come back as 60 ascending bars.

    The mocked payload lists 60 consecutive calendar dates, newest first
    (2026-03-01 down to 2026-01-01). The client is expected to return them in
    ascending ``datetime`` order with integer ``open``/``close`` values.
    """
    # Function-scope import: the module's import block does not provide datetime.
    from datetime import date, timedelta

    # Real calendar dates only. A "30 days per month" formula would emit
    # impossible dates such as 20260229 / 20260230, which breaks any client
    # that validates the date string.
    oldest = date(2026, 1, 1)
    sample_output2 = [
        {
            "stck_bsop_date": (oldest + timedelta(days=59 - i)).strftime("%Y%m%d"),
            "stck_oprc": "78000", "stck_hgpr": "78500",
            "stck_lwpr": "77800", "stck_clpr": str(78000 + i),
            "acml_vol": "12345",
        }
        for i in range(60)  # i == 0 is the newest row, i == 59 the oldest
    ]
    respx.get(
        "https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
    ).mock(return_value=httpx.Response(200, json={"output2": sample_output2}))

    client = kis_client_factory()
    try:
        bars = await client.get_daily_ohlcv("005930", days=60)
        assert len(bars) == 60
        # Input was newest-first; output must be oldest-first.
        assert bars[0]["datetime"] < bars[-1]["datetime"]
        assert isinstance(bars[0]["open"], int)
        assert isinstance(bars[0]["close"], int)
        assert "datetime" in bars[0]
    finally:
        await client.close()
@respx.mock
async def test_throttle_serializes_concurrent_gather(kis_client_factory):
    """F2: five requests launched together via ``asyncio.gather`` are serialized 0.5 s apart.

    Two calls per second means a 0.5 s spacing, so five requests need at least
    (5 - 1) * 0.5 = 2.0 s. Without a lock the throttle races, the requests
    leave almost simultaneously, and the batch finishes in roughly 0.5 s.
    """
    empty_ok = httpx.Response(200, json={"output2": []})
    respx.get(
        "https://openapivts.koreainvestment.com:29443"
        "/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
    ).mock(return_value=empty_ok)

    client = kis_client_factory()
    try:
        began = time_module.monotonic()
        pending = [client.get_minute_ohlcv(f"00593{i}") for i in range(5)]
        await asyncio.gather(*pending)
        elapsed = time_module.monotonic() - began
        # Four inter-call gaps of 0.5 s = 2.0 s minimum; allow 0.3 s of tolerance.
        assert elapsed >= 1.7, (
            f"throttle race condition: 5 concurrent calls took only {elapsed:.2f}s, "
            f"expected >=1.7s (0.5s * 4 inter-call gaps)"
        )
    finally:
        await client.close()