diff --git a/ai_trade/kis_client.py b/ai_trade/kis_client.py index 5492360..84edb61 100644 --- a/ai_trade/kis_client.py +++ b/ai_trade/kis_client.py @@ -38,6 +38,7 @@ class KISClient: self._client = httpx.AsyncClient(timeout=timeout) self._token_cache: tuple[str, float] | None = None # (token, file_mtime) self._last_throttle_at = 0.0 + self._throttle_lock = asyncio.Lock() async def close(self) -> None: await self._client.aclose() @@ -56,10 +57,13 @@ class KISClient: return token async def _throttle(self) -> None: - elapsed = time.monotonic() - self._last_throttle_at - if elapsed < _THROTTLE_INTERVAL: - await asyncio.sleep(_THROTTLE_INTERVAL - elapsed) - self._last_throttle_at = time.monotonic() + # F2: Lock으로 직렬화. 없으면 asyncio.gather 동시 호출 시 race로 + # 같은 elapsed 계산 후 동시에 깨어나 KIS 초당 2회(EGW00201) 위반. + async with self._throttle_lock: + elapsed = time.monotonic() - self._last_throttle_at + if elapsed < _THROTTLE_INTERVAL: + await asyncio.sleep(_THROTTLE_INTERVAL - elapsed) + self._last_throttle_at = time.monotonic() def _common_headers(self, tr_id: str) -> dict[str, str]: token = self._read_v1_token() diff --git a/ai_trade/tests/test_kis_client.py b/ai_trade/tests/test_kis_client.py index 125af6e..7d17bff 100644 --- a/ai_trade/tests/test_kis_client.py +++ b/ai_trade/tests/test_kis_client.py @@ -1,5 +1,7 @@ """Tests for KISClient (REST).""" +import asyncio import json +import time as time_module from pathlib import Path import httpx @@ -159,3 +161,30 @@ async def test_get_daily_ohlcv_returns_60_bars(kis_client_factory): assert "datetime" in bars[0] finally: await client.close() + + +@respx.mock +async def test_throttle_serializes_concurrent_gather(kis_client_factory): + """F2 — 5개 동시 요청이 asyncio.gather로 들어와도 0.5초 간격으로 직렬화. + + 초당 2회 = 0.5초 간격. 5개 요청 시 최소 (5-1)*0.5 = 2.0초. + Lock 없으면 race condition으로 거의 동시에 나가 0.5초대로 끝남. + """ + sample = {"output2": []} + respx.get( + "https://openapivts.koreainvestment.com:29443" + "/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice" + ).mock(return_value=httpx.Response(200, json=sample)) + + client = kis_client_factory() + try: + start = time_module.monotonic() + await asyncio.gather(*[client.get_minute_ohlcv(f"00593{i}") for i in range(5)]) + elapsed = time_module.monotonic() - start + # 5 throttle = 최소 (5-1)*0.5 = 2.0s, tolerance 0.3s + assert elapsed >= 1.7, ( + f"throttle race condition: 5 concurrent calls took only {elapsed:.2f}s, " + f"expected >=1.7s (0.5s * 4 inter-call gaps)" + ) + finally: + await client.close()