diff --git a/signal_v2/kis_websocket.py b/signal_v2/kis_websocket.py new file mode 100644 index 0000000..edc81ff --- /dev/null +++ b/signal_v2/kis_websocket.py @@ -0,0 +1,186 @@ +"""KIS WebSocket — approval_key + 실시간 호가 구독.""" +from __future__ import annotations +import asyncio +import json +import logging +from datetime import datetime +from typing import Callable +from zoneinfo import ZoneInfo + +import httpx +import websockets + +logger = logging.getLogger(__name__) +KST = ZoneInfo("Asia/Seoul") + +# KIS 호가 메시지 필드 인덱스 (운영 환경 검증 필요) +# H0STASP0 응답: ticker | time | current_price | ... | ask_total | bid_total +# 본 spec/plan 의 가정: 마지막 2개 필드가 ask_total / bid_total +_ASKING_TICKER_IDX = 0 +_ASKING_TIME_IDX = 1 +_ASKING_CURRENT_PRICE_IDX = 2 +_ASKING_TOTAL_ASK_IDX = -2 +_ASKING_TOTAL_BID_IDX = -1 + + +class KISWebSocket: + """KIS WebSocket client. approval_key 발급 + 호가 실시간.""" + + def __init__(self, app_key: str, app_secret: str, is_virtual: bool): + self._app_key = app_key + self._app_secret = app_secret + self._is_virtual = is_virtual + self._base_rest = ( + "https://openapivts.koreainvestment.com:29443" if is_virtual + else "https://openapi.koreainvestment.com:9443" + ) + self._ws_url = ( + "ws://ops.koreainvestment.com:31000" if is_virtual + else "ws://ops.koreainvestment.com:21000" + ) + self._approval_key: str | None = None + self._ws = None + self._subscriptions: set[str] = set() + self._on_asking_price: Callable[[str, dict], None] | None = None + self._recv_task: asyncio.Task | None = None + self._shutdown = asyncio.Event() + + async def _fetch_approval_key(self) -> str: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.post( + f"{self._base_rest}/oauth2/Approval", + json={ + "grant_type": "client_credentials", + "appkey": self._app_key, + "secretkey": self._app_secret, + }, + ) + response.raise_for_status() + data = response.json() + self._approval_key = data["approval_key"] + return self._approval_key + + async def _connect(self): + return await websockets.connect(self._ws_url) + + async def _connect_with_backoff(self): + """연결 시도 with exponential backoff (1s → 2s → 4s → max 30s).""" + for attempt in range(10): + try: + ws = await self._connect() + return ws + except Exception as e: + wait = min(2**attempt, 30) + logger.warning( + "KIS WebSocket connect failed (attempt %d): %r — retrying in %ds", + attempt + 1, e, wait, + ) + await asyncio.sleep(wait) + raise RuntimeError("KIS WebSocket connect exhausted retries") + + async def start( + self, tickers: list[str], + on_asking_price: Callable[[str, dict], None], + ) -> None: + if self._approval_key is None: + await self._fetch_approval_key() + self._on_asking_price = on_asking_price + self._ws = await self._connect_with_backoff() + for ticker in tickers: + await self.subscribe(ticker) + self._recv_task = asyncio.create_task(self._receive_loop()) + + async def subscribe(self, ticker: str) -> None: + if self._ws is None or self._approval_key is None: + raise RuntimeError("KIS WebSocket not started") + msg = json.dumps({ + "header": { + "approval_key": self._approval_key, + "custtype": "P", + "tr_type": "1", + "content-type": "utf-8", + }, + "body": { + "input": {"tr_id": "H0STASP0", "tr_key": ticker}, + }, + }) + await self._ws.send(msg) + self._subscriptions.add(ticker) + + async def unsubscribe(self, ticker: str) -> None: + if self._ws is None or self._approval_key is None: + return + msg = json.dumps({ + "header": { + "approval_key": self._approval_key, + "custtype": "P", + "tr_type": "2", + "content-type": "utf-8", + }, + "body": { + "input": {"tr_id": "H0STASP0", "tr_key": ticker}, + }, + }) + await self._ws.send(msg) + self._subscriptions.discard(ticker) + + async def close(self) -> None: + self._shutdown.set() + if self._recv_task is not None: + self._recv_task.cancel() + try: + await self._recv_task + except asyncio.CancelledError: + pass + if self._ws is not None: + await self._ws.close() + + async def _receive_loop(self) -> None: + while not self._shutdown.is_set(): + try: + raw = await self._ws.recv() + except websockets.ConnectionClosed: + logger.warning("KIS WebSocket closed — reconnecting") + self._ws = await self._connect_with_backoff() + for ticker in list(self._subscriptions): + await self.subscribe(ticker) + continue + if not isinstance(raw, str): + continue + parsed = self._parse_asking_price(raw) + if parsed is not None and self._on_asking_price is not None: + ticker, data = parsed + try: + self._on_asking_price(ticker, data) + except Exception: + logger.exception("on_asking_price callback failed") + + def _parse_asking_price(self, raw: str) -> tuple[str, dict] | None: + """KIS H0STASP0 raw → (ticker, asking_price dict). + + Raw format: '0|H0STASP0||' where data = '^'-joined fields. + Field indices (운영 검증 필요): 마지막 2개 가정 (ask, bid). + """ + try: + parts = raw.split("|") + if len(parts) < 4 or parts[1] != "H0STASP0": + return None + fields = parts[3].split("^") + ticker = fields[_ASKING_TICKER_IDX] + current_price_str = fields[_ASKING_CURRENT_PRICE_IDX] + current_price = int(current_price_str) if current_price_str.lstrip("-").isdigit() else 0 + ask_str = fields[_ASKING_TOTAL_ASK_IDX] + bid_str = fields[_ASKING_TOTAL_BID_IDX] + ask_total = int(ask_str) if ask_str.lstrip("-").isdigit() else 0 + bid_total = int(bid_str) if bid_str.lstrip("-").isdigit() else 0 + total = bid_total + ask_total + return ticker, { + "bid_total": bid_total, + "ask_total": ask_total, + "bid_ratio": bid_total / total if total > 0 else 0.0, + "current_price": current_price, + "as_of": datetime.now(KST).isoformat(), + } + except (IndexError, ValueError) as e: + logger.warning("parse_asking_price failed: %r", e) + return None diff --git a/signal_v2/tests/test_kis_websocket.py b/signal_v2/tests/test_kis_websocket.py new file mode 100644 index 0000000..82f18b8 --- /dev/null +++ b/signal_v2/tests/test_kis_websocket.py @@ -0,0 +1,94 @@ +"""Tests for KISWebSocket.""" +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest +import respx + +from signal_v2.kis_websocket import KISWebSocket + + +BASE_REST = "https://openapivts.koreainvestment.com:29443" + + +@respx.mock +async def test_fetch_approval_key_via_oauth_endpoint(): + """POST /oauth2/Approval → approval_key 추출.""" + respx.post(f"{BASE_REST}/oauth2/Approval").mock( + return_value=httpx.Response(200, json={"approval_key": "test-approval-key-xyz"}) + ) + ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True) + key = await ws._fetch_approval_key() + assert key == "test-approval-key-xyz" + assert ws._approval_key == "test-approval-key-xyz" + + +async def test_subscribe_sends_h0stasp0_message(): + """subscribe() → WebSocket 으로 H0STASP0 구독 메시지 전송.""" + sent_messages = [] + mock_ws = AsyncMock() + mock_ws.send = AsyncMock(side_effect=lambda m: sent_messages.append(m)) + + ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True) + ws._approval_key = "test-key" + ws._ws = mock_ws + await ws.subscribe("005930") + assert ws._subscriptions == {"005930"} + assert len(sent_messages) == 1 + msg = json.loads(sent_messages[0]) + assert msg["header"]["tr_type"] == "1" # subscribe + assert msg["body"]["input"]["tr_id"] == "H0STASP0" + assert msg["body"]["input"]["tr_key"] == "005930" + + +def test_parse_asking_price_extracts_bid_ask_totals(): + """KIS raw '0|H0STASP0|001|...' → (ticker, dict). + + KIS 호가 메시지 형식 — KIS 공식 spec 의 정확한 필드 인덱스 운영 검증 필요. + 본 테스트는 implementer 의 _parse_asking_price 구현 인덱스에 맞춰서 sample 작성. + """ + ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True) + # Build a sample raw message — implementer 가 _ASKING_TOTAL_BID/ASK 인덱스에 + # 맞춰서 필드 배치하면 됨. 예: 마지막 2개 필드를 bid_total / ask_total 로. + fields = ["005930", "091500", "78500"] # ticker, time, current_price + fields.extend(["0"] * 40) # padding (KIS 의 실 필드 수 ~50개) + fields.append("400") # ask_total + fields.append("600") # bid_total + raw = f"0|H0STASP0|001|{'^'.join(fields)}" + + result = ws._parse_asking_price(raw) + assert result is not None, "parse_asking_price returned None" + ticker, data = result + assert ticker == "005930" + assert "bid_total" in data + assert "ask_total" in data + assert "bid_ratio" in data + assert "current_price" in data + # bid_total=600, ask_total=400, bid_ratio=0.6 + assert data["bid_total"] == 600 + assert data["ask_total"] == 400 + assert abs(data["bid_ratio"] - 0.6) < 1e-9 + + +async def test_reconnect_on_disconnect_with_backoff(monkeypatch): + """연결 끊김 → exponential backoff retry. _connect_with_backoff() 검증.""" + sleep_calls = [] + async def fake_sleep(s): sleep_calls.append(s) + monkeypatch.setattr("asyncio.sleep", fake_sleep) + + ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True) + # Mock _connect to fail twice then succeed + call_count = [0] + async def fake_connect(): + call_count[0] += 1 + if call_count[0] < 3: + raise ConnectionError("fake disconnect") + return AsyncMock() + monkeypatch.setattr(ws, "_connect", fake_connect) + + result = await ws._connect_with_backoff() + assert call_count[0] == 3 # 2 fails + 1 success + # exponential 1s, 2s + assert sleep_calls[:2] == [1, 2]