"""KIS WebSocket — approval_key + 실시간 호가 구독.""" from __future__ import annotations import asyncio import json import logging from datetime import datetime from typing import Callable from zoneinfo import ZoneInfo import httpx import websockets logger = logging.getLogger(__name__) KST = ZoneInfo("Asia/Seoul") # KIS 호가 메시지 필드 인덱스 (운영 환경 검증 필요) # H0STASP0 응답: ticker | time | current_price | ... | ask_total | bid_total # 본 spec/plan 의 가정: 마지막 2개 필드가 ask_total / bid_total _ASKING_TICKER_IDX = 0 _ASKING_TIME_IDX = 1 _ASKING_CURRENT_PRICE_IDX = 2 _ASKING_TOTAL_ASK_IDX = -2 _ASKING_TOTAL_BID_IDX = -1 class KISWebSocket: """KIS WebSocket client. approval_key 발급 + 호가 실시간.""" def __init__(self, app_key: str, app_secret: str, is_virtual: bool): self._app_key = app_key self._app_secret = app_secret self._is_virtual = is_virtual self._base_rest = ( "https://openapivts.koreainvestment.com:29443" if is_virtual else "https://openapi.koreainvestment.com:9443" ) self._ws_url = ( "ws://ops.koreainvestment.com:31000" if is_virtual else "ws://ops.koreainvestment.com:21000" ) self._approval_key: str | None = None self._ws = None self._subscriptions: set[str] = set() self._on_asking_price: Callable[[str, dict], None] | None = None self._recv_task: asyncio.Task | None = None self._shutdown = asyncio.Event() async def _fetch_approval_key(self) -> str: async with httpx.AsyncClient(timeout=10.0) as client: response = await client.post( f"{self._base_rest}/oauth2/Approval", json={ "grant_type": "client_credentials", "appkey": self._app_key, "secretkey": self._app_secret, }, ) response.raise_for_status() data = response.json() self._approval_key = data["approval_key"] return self._approval_key async def _connect(self): return await websockets.connect(self._ws_url) async def _connect_with_backoff(self): """연결 시도 with exponential backoff (1s → 2s → 4s → max 30s).""" for attempt in range(10): try: ws = await self._connect() return ws except Exception as e: wait = min(2**attempt, 30) logger.warning( "KIS WebSocket connect failed (attempt %d): %r — retrying in %ds", attempt + 1, e, wait, ) await asyncio.sleep(wait) raise RuntimeError("KIS WebSocket connect exhausted retries") async def start( self, tickers: list[str], on_asking_price: Callable[[str, dict], None], ) -> None: if self._approval_key is None: await self._fetch_approval_key() self._on_asking_price = on_asking_price self._ws = await self._connect_with_backoff() for ticker in tickers: await self.subscribe(ticker) self._recv_task = asyncio.create_task(self._receive_loop()) async def subscribe(self, ticker: str) -> None: if self._ws is None or self._approval_key is None: raise RuntimeError("KIS WebSocket not started") msg = json.dumps({ "header": { "approval_key": self._approval_key, "custtype": "P", "tr_type": "1", "content-type": "utf-8", }, "body": { "input": {"tr_id": "H0STASP0", "tr_key": ticker}, }, }) await self._ws.send(msg) self._subscriptions.add(ticker) async def unsubscribe(self, ticker: str) -> None: if self._ws is None or self._approval_key is None: return msg = json.dumps({ "header": { "approval_key": self._approval_key, "custtype": "P", "tr_type": "2", "content-type": "utf-8", }, "body": { "input": {"tr_id": "H0STASP0", "tr_key": ticker}, }, }) await self._ws.send(msg) self._subscriptions.discard(ticker) async def close(self) -> None: self._shutdown.set() if self._recv_task is not None: self._recv_task.cancel() try: await self._recv_task except asyncio.CancelledError: pass if self._ws is not None: await self._ws.close() async def _receive_loop(self) -> None: while not self._shutdown.is_set(): try: raw = await self._ws.recv() except websockets.ConnectionClosed: logger.warning("KIS WebSocket closed — reconnecting") self._ws = await self._connect_with_backoff() for ticker in list(self._subscriptions): await self.subscribe(ticker) continue if not isinstance(raw, str): continue parsed = self._parse_asking_price(raw) if parsed is not None and self._on_asking_price is not None: ticker, data = parsed try: self._on_asking_price(ticker, data) except Exception: logger.exception("on_asking_price callback failed") def _parse_asking_price(self, raw: str) -> tuple[str, dict] | None: """KIS H0STASP0 raw → (ticker, asking_price dict). Raw format: '0|H0STASP0||' where data = '^'-joined fields. Field indices (운영 검증 필요): 마지막 2개 가정 (ask, bid). """ try: parts = raw.split("|") if len(parts) < 4 or parts[1] != "H0STASP0": return None fields = parts[3].split("^") ticker = fields[_ASKING_TICKER_IDX] current_price_str = fields[_ASKING_CURRENT_PRICE_IDX] current_price = int(current_price_str) if current_price_str.lstrip("-").isdigit() else 0 ask_str = fields[_ASKING_TOTAL_ASK_IDX] bid_str = fields[_ASKING_TOTAL_BID_IDX] ask_total = int(ask_str) if ask_str.lstrip("-").isdigit() else 0 bid_total = int(bid_str) if bid_str.lstrip("-").isdigit() else 0 total = bid_total + ask_total return ticker, { "bid_total": bid_total, "ask_total": ask_total, "bid_ratio": bid_total / total if total > 0 else 0.0, "current_price": current_price, "as_of": datetime.now(KST).isoformat(), } except (IndexError, ValueError) as e: logger.warning("parse_asking_price failed: %r", e) return None