Compare commits
31 Commits
ad2c65c2b2
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 71ef959310 | |||
| 2aa9f48ea3 | |||
| cc6310d72f | |||
| e574074ca8 | |||
| b9def06993 | |||
| 05ab2846bb | |||
| 760f914d3b | |||
| 8eefe9d79d | |||
| 91de16675b | |||
| 44888d6ede | |||
| 9e5fecb369 | |||
| 28f9c8c3a6 | |||
| c5a88fab66 | |||
| 7056cf2fa6 | |||
| 4ac7da8670 | |||
| b690900cfc | |||
| d85512d036 | |||
| 3ebe95ba29 | |||
| 163c9fb690 | |||
| 27bf360b01 | |||
| eafa73edb1 | |||
| 68eb7b073c | |||
| 8342d38935 | |||
| e47947fb69 | |||
| 94c684bab8 | |||
| 1a6d9fcb39 | |||
| 6cb5085118 | |||
| fdabc69004 | |||
| 90235497ae | |||
| 8469bf7ffa | |||
| 8a2fac03a6 |
8
.gitignore
vendored
8
.gitignore
vendored
@@ -47,9 +47,11 @@ daily_trade_history.json
|
|||||||
watchlist.json
|
watchlist.json
|
||||||
bot_ipc.json
|
bot_ipc.json
|
||||||
|
|
||||||
# Test
|
# Test (top-level only; signal_v2/tests tracked separately)
|
||||||
tests/
|
tests/
|
||||||
tests/*
|
tests/*
|
||||||
|
!signal_v2/tests/
|
||||||
|
!signal_v2/tests/**
|
||||||
|
|
||||||
# System
|
# System
|
||||||
Thumbs.db
|
Thumbs.db
|
||||||
@@ -59,3 +61,7 @@ Desktop.ini
|
|||||||
KIS_SETUP.md
|
KIS_SETUP.md
|
||||||
# Claude Code subagent state
|
# Claude Code subagent state
|
||||||
.claude/
|
.claude/
|
||||||
|
|
||||||
|
# Signal V2 runtime data
|
||||||
|
signal_v2/data/*.db
|
||||||
|
signal_v2/data/*.db-*
|
||||||
|
|||||||
143
CLAUDE.md
143
CLAUDE.md
@@ -1,24 +1,141 @@
|
|||||||
# web-ai — Workspace 가이드
|
# web-ai — Workspace 가이드
|
||||||
|
|
||||||
Windows AI 머신 (AMD 9800X3D + RTX 5070 Ti) 의 두 시그널 파이프라인 컨테이너.
|
Windows AI 머신 (AMD 9800X3D + RTX 5070 Ti 16GB) 의 두 신호 파이프라인.
|
||||||
|
**Confidence Signal Pipeline V2 의 Windows-side 구현체** (NAS stock 백엔드와 HTTP 연동).
|
||||||
|
|
||||||
|
상위 워크스페이스 컨텍스트는 `../CLAUDE.md` 참조.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 디렉토리 구조
|
## 디렉토리 구조
|
||||||
|
|
||||||
| 경로 | 역할 | 상태 |
|
| 경로 | 역할 | 포트 | 상태 |
|
||||||
|------|------|------|
|
|------|------|------|------|
|
||||||
| `signal_v1/` | V1 자체 자동매매 시스템 (main_server.py + Trading Bot + Telegram Bot + LSTM + Ollama + KIS 자동주문) | 운영 중. Confidence Signal Pipeline V2 Phase 6 에서 deprecation 예정 |
|
| `signal_v1/` | 레거시 자동매매 시스템 (LSTM 7-features + Gemini Flash + Telegram Bot + KIS 자동주문) | `:8000` | 운영 중. **V2 Phase 6 에서 deprecation 예정** |
|
||||||
| `signal_v2/` | V2 신호 파이프라인 (stock pull worker + Chronos-2 + signal API client) | Phase 2 에서 신설 |
|
| `signal_v2/` | Confidence Signal Pipeline V2 (Chronos-bolt + 분봉 모멘텀 + KIS WebSocket + 신호 생성) | `:8001` | **Phase 4 완료 (2026-05-17)**, Phase 5 대기 |
|
||||||
| `.env` | V1 + V2 환경변수 공유 | KIS_*, TELEGRAM_*, STOCK_API_URL, WEBAI_API_KEY 등 |
|
| `.env` | V1 + V2 환경변수 공유 | — | `KIS_REAL_*`, `TELEGRAM_*`, `STOCK_API_URL`, `WEBAI_API_KEY`, `LOG_LEVEL` |
|
||||||
| `start.bat` | V1 진입 (signal_v1 디렉토리 안 main_server.py 실행) | V2 별도 start 스크립트는 signal_v2/start.bat |
|
| `start.bat` | V1 진입점 | — | `signal_v1/main_server.py` 실행 |
|
||||||
|
| `signal_v2/start.bat` | V2 진입점 | — | `signal_v2/main.py` uvicorn 실행 |
|
||||||
|
| `requirements.txt` | 공용 의존성 | — | torch, chronos-forecasting, fastapi, httpx, websockets 등 |
|
||||||
|
|
||||||
## 운영 가이드
|
`.venv` 는 **구조적으로 깨짐**: `pyvenv.cfg` 가 한글 사용자 경로(`C:\Users\박재오\...`) 를 포함하여 콘솔 코드페이지가 roundtrip 못함. 테스트는 시스템 Python 으로 실행: `C:\Users\jaeoh\AppData\Local\Programs\Python\Python312\python.exe -m pytest signal_v2/tests -q`.
|
||||||
|
|
||||||
- V1 시작: `start.bat` 또는 `cd signal_v1 && python main_server.py`
|
---
|
||||||
- V2 시작 (Phase 2 이후): `cd signal_v2 && python -m uvicorn main:app --port 8001`
|
|
||||||
- 둘 다 동시 실행 가능 (포트 분리: V1=8000, V2=8001)
|
## 서버 시작 방식
|
||||||
|
|
||||||
|
### V1 단독 (운영 기본)
|
||||||
|
```bat
|
||||||
|
cd C:\Users\jaeoh\Desktop\workspace\web-ai
|
||||||
|
.\start.bat
|
||||||
|
```
|
||||||
|
기대 로그: `[Bot] Cycle Start ...`, `[AI] 005930: NN epochs ...`, `[Ensemble] tech=... news=... lstm=...`, `Score: 0.xx [HOLD]`
|
||||||
|
|
||||||
|
### V2 단독 (smoke/검증)
|
||||||
|
```bat
|
||||||
|
cd C:\Users\jaeoh\Desktop\workspace\web-ai\signal_v2
|
||||||
|
.\start.bat
|
||||||
|
```
|
||||||
|
기대 로그: `Uvicorn running on http://0.0.0.0:8001`, `poll_loop started`, `[KIS] minute bars ... OK`, `[Chronos] predicted N tickers`, `signal emit XXXXXX buy conf=0.xxx`.
|
||||||
|
|
||||||
|
휴장일/장 외 시간엔 `poll_loop` 만 idle. `Application startup complete` 만 보이면 정상.
|
||||||
|
|
||||||
|
### V1 + V2 동시 실행 — **권장 안 함**
|
||||||
|
**KIS app_key 초당 2회 한도 (EGW00201)** 충돌. V1 cycle + V2 분봉 cron 이 같은 KIS app_key 로 동시 호출하면 rate limit. 채택 해결책: V2 임시 종료 (Phase 3a 결정), Phase 6 V1 deprecation 시 자연 해소. 별도 app_key 발급은 옵션 B.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Phase 진행 상태 (Confidence Signal Pipeline V2)
|
## Phase 진행 상태 (Confidence Signal Pipeline V2)
|
||||||
|
|
||||||
`web-ui/docs/superpowers/specs/2026-05-15-confidence-signal-pipeline-v2-architecture.md` 참조.
|
| Phase | 내용 | 상태 |
|
||||||
|
|-------|------|------|
|
||||||
|
| 0 | Architecture & contract spec | ✅ Chronos-2 + Qwen3 14B 채택 |
|
||||||
|
| 1 | stock 백엔드 WebAI API 보강 (NAS) | ✅ 102/102 tests, 운영 배포 |
|
||||||
|
| 1.5 | V1 → `signal_v1/` rename | ✅ V1 정상 기동 |
|
||||||
|
| 2 | signal_v2 pull worker + signal API client + scheduler | ✅ 19/19 tests, `:8001` 기동 |
|
||||||
|
| 3a | KIS REST 분봉 + WebSocket 호가 + NXT 스케줄 | ✅ 33/33 tests |
|
||||||
|
| 3b | Chronos-bolt-base 추론 + 5분봉 모멘텀 분류기 | ✅ 45/45 tests, 실 KIS+Chronos chain 검증 |
|
||||||
|
| 4 | Signal Generator (매수/매도 룰) + pull_worker 통합 + 로깅 | ✅ **2026-05-17 완료, 56/56 tests, push 완료** |
|
||||||
|
| 5 | agent-office `/signal` + Ollama Qwen3 14B + 이중 텔레그램 | ⏳ 2주 예상 |
|
||||||
|
| 6 | signal_v1 deprecation | ⏳ 1주 |
|
||||||
|
| 7 | 운영 모니터링 + 4주 IC 검증 | ⏳ 1주 + 4주 |
|
||||||
|
|
||||||
자세한 V1 가이드는 `signal_v1/CLAUDE.md` 참조.
|
상세 spec/plan: `../web-ui/docs/superpowers/specs/` 및 `../web-ui/docs/superpowers/plans/` (web-ui repo 안에 보관됨 — V2 자체 코드와 분리 보관).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## signal_v2 디렉토리 내부
|
||||||
|
|
||||||
|
| 파일 | 역할 |
|
||||||
|
|------|------|
|
||||||
|
| `main.py` | FastAPI app + lifespan (StockClient + KISClient + KISWebSocket + ChronosPredictor + SignalDedup 초기화). poll_loop task 생성 |
|
||||||
|
| `config.py` | Settings dataclass — 환경변수 로드. Phase 4 추가 6 필드: `stop_loss_pct`, `take_profit_pct`, `chronos_spread_threshold`, `asking_bid_ratio_threshold`, `confidence_threshold`, `min_momentum_for_buy` |
|
||||||
|
| `state.py` | PollState (process-wide singleton) — portfolio, screener_preview, news_sentiment, chronos_predictions, minute_bars, asking_price, **signals** (Phase 4) |
|
||||||
|
| `stock_client.py` | NAS stock 백엔드 pull (X-WebAI-Key + 메모리 cache 60s/300s/60s + retry) |
|
||||||
|
| `kis_client.py` | KIS REST 분봉/호가 — V1 토큰 read-only 공유 (mtime cache) + 초당 2회 throttle + 지수 backoff |
|
||||||
|
| `kis_websocket.py` | KIS WebSocket H0STASP0 호가 + approval_key + 재연결 (1→2→4→max 30s) |
|
||||||
|
| `chronos_predictor.py` | `amazon/chronos-bolt-base` zero-shot quantile (FP32 강제 — FP16 overflow 회피) |
|
||||||
|
| `minute_momentum.py` | 5분봉 → strong_up/weak_up/neutral/weak_down/strong_down 5단계 분류 |
|
||||||
|
| `signal_generator.py` | **Phase 4 — 매수/매도 룰 엔진**. `generate_signals(state, dedup, settings)` 진입. sell-first → buy 순서. 신호 emit/skip INFO/DEBUG 로그 |
|
||||||
|
| `pull_worker.py` | asyncio cron — 장전 5분 / 장중 1분 / 장후 5분 / NXT / dead zone skip. cycle 끝에 `generate_signals` 호출 |
|
||||||
|
| `scheduler.py` | polling window 판정 (KST 캘린더 + 휴장일) |
|
||||||
|
| `rate_limit.py` | 초당 N회 token bucket |
|
||||||
|
| `dedup.py` | SignalDedup SQLite WAL — `(ticker, action)` PK 24h |
|
||||||
|
| `tests/` | 56 tests (pytest + respx HTTP mock + monkeypatch) |
|
||||||
|
| `data/` | dedup.db (SQLite WAL) + `holidays.json` (NAS stock 에서 manual copy) |
|
||||||
|
| `start.bat` | V2 진입 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 신호 룰 요약 (Phase 4)
|
||||||
|
|
||||||
|
### 매수 (screener Top-N + portfolio, sell 신호 받은 종목은 skip)
|
||||||
|
모두 충족:
|
||||||
|
1. `chronos.median > 0`
|
||||||
|
2. **`chronos.q90 - chronos.q10 < 0.6`** (absolute spread — 2026-05-17 spec amend, 기존 relative formula 가 zero-shot median≈0 빈번에서 모든 신호 거부)
|
||||||
|
3. `minute_momentum == strong_up` (env 로 조정 가능)
|
||||||
|
4. `asking_price.bid_ratio >= 0.6`
|
||||||
|
|
||||||
|
종합 confidence = `chronos_conf * 0.5 + minute_score * 0.3 + screener_norm * 0.2`. `> 0.7` 시 emit.
|
||||||
|
|
||||||
|
### 매도 (portfolio only, 우선순위 stop_loss → anomaly → take_profit)
|
||||||
|
- **stop_loss**: `pnl_pct < -7%` 즉시 (confidence=1.0)
|
||||||
|
- **anomaly**: `chronos.median < -1%` + `strong_down` + `bid_ratio < 0.4` + 종합 conf > 0.7
|
||||||
|
- **take_profit**: `pnl_pct > 15%` 검토 (confidence=0.6)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 알려진 함정 / Phase 7 백로그
|
||||||
|
|
||||||
|
1. **KIS rate limit (EGW00201)** — V1+V2 동시 실행 시 충돌. Phase 6 자연 해소
|
||||||
|
2. **`.venv` 한글 경로 깨짐** — 시스템 Python 사용
|
||||||
|
3. **Chronos FP16 overflow** — 한국 주가 5만+ 시 inf. FP32 강제 (`chronos_predictor.py:39-41`)
|
||||||
|
4. **`predict_quantiles` positional `inputs`** — ChronosBolt API 새 변경. `try/except TypeError` fallback 처리됨
|
||||||
|
5. **`state.signals` consumer-drain protocol 미정의** — Phase 5 prereq. dict 무한 누적 위험 (실제로는 bounded by unique ticker count)
|
||||||
|
6. **integration test 가 poll_loop 실제 호출 안 함** — `test_pull_worker.py:test_poll_loop_calls_generate_signals_after_cycle` 가 `generate_signals` 직접 호출. Phase 7 hardening 시 mock-iteration 으로 강화
|
||||||
|
7. **KIS WebSocket URL `ws://ops.koreainvestment.com:21000/31000`** — 첫 운영 시 실제 KIS API docs 와 대조 필요
|
||||||
|
8. **`_parse_asking_price` 필드 인덱스** — 마지막 2 필드 가정. 실 운영 raw 메시지 캡처 후 매핑 검증 필요
|
||||||
|
9. **`holidays.json` 자동 동기화 부재** — NAS stock 의 `holidays.json` 을 수동 copy
|
||||||
|
10. **schema rename** — Phase 0 §5.2 의 `lstm_pred_*`, `news_top[]` 는 `chronos_pred_*`, `news_reason(string)` 으로 변경됨. Phase 5 prompt 작성 시 반영
|
||||||
|
11. **6개 env 필드가 `.env` 에 미기재** — 기본값으로 동작 가능하나 discoverability 위해 `.env.example` 또는 commented block 추가 권장
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 다음 단계 (Phase 5 진입 시 brainstorming 주제)
|
||||||
|
|
||||||
|
- `state.signals` consumer 패턴: pop vs leave + Phase 5 자체 dedup
|
||||||
|
- agent-office 의 `/signal` endpoint 설계 — POST 페이로드 schema
|
||||||
|
- Ollama Qwen3 14B Q4 로컬 호출 — 타임아웃, retry, VRAM 공존 (Chronos + Qwen3 동시 메모리 9.3GB / 15.5GB 가용)
|
||||||
|
- 이중 텔레그램 (본인 풀 / 아내 lite) — context augmentation 단일 호출에서 양쪽 메시지 생성
|
||||||
|
- LLM 비용: ₩0 목표 유지 (로컬)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 양쪽 디렉토리 (web-ui ↔ web-ai) 작업 시 주의
|
||||||
|
|
||||||
|
- **코드**: signal_v2 는 web-ai/, spec/plan/메모리는 web-ui/
|
||||||
|
- **커밋**: `web-ai` 와 `web-ui` 는 **별도 Gitea 저장소**. 각각 경로에서만 `git add/commit/push`
|
||||||
|
- **메모리**: Claude Code 의 auto-memory 는 디렉토리별 격리. 핵심 reference 는 양쪽에 미러됨 (`./memory-mirror/` 또는 `~/.claude/projects/C--Users-jaeoh-Desktop-workspace-web-ai/memory/`)
|
||||||
|
- **spec amendment 발생 시**: 코드는 `web-ai` 에 commit, spec 갱신은 `web-ui/docs/superpowers/specs/` 에 commit (Phase 4 spread formula 변경 사례 = web-ui commit `534ded5`)
|
||||||
|
|
||||||
|
자세한 V1 가이드는 `signal_v1/CLAUDE.md` 참조 (있다면).
|
||||||
|
|||||||
13
requirements.txt
Normal file
13
requirements.txt
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Signal V2 dependencies (added 2026-05-16, Phase 2)
|
||||||
|
httpx>=0.27
|
||||||
|
fastapi>=0.110
|
||||||
|
uvicorn>=0.27
|
||||||
|
python-dotenv>=1.0
|
||||||
|
pytest>=8.0
|
||||||
|
pytest-asyncio>=0.23
|
||||||
|
respx>=0.21
|
||||||
|
websockets>=12
|
||||||
|
# Phase 3b dependencies (Chronos-2 + ML)
|
||||||
|
transformers>=4.40
|
||||||
|
chronos-forecasting>=1.4
|
||||||
|
# torch: typically already installed via V1 venv; if not, install with CUDA support manually
|
||||||
0
signal_v2/__init__.py
Normal file
0
signal_v2/__init__.py
Normal file
132
signal_v2/chronos_predictor.py
Normal file
132
signal_v2/chronos_predictor.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Chronos-2 zero-shot forecaster wrapper."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
KST = ZoneInfo("Asia/Seoul")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChronosPrediction:
|
||||||
|
median: float
|
||||||
|
q10: float
|
||||||
|
q90: float
|
||||||
|
conf: float
|
||||||
|
as_of: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChronosPredictor:
|
||||||
|
"""HuggingFace Chronos-2 zero-shot forecaster."""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
|
||||||
|
# BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2)
|
||||||
|
# and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5.
|
||||||
|
import torch
|
||||||
|
try:
|
||||||
|
from chronos import BaseChronosPipeline
|
||||||
|
pipeline_cls = BaseChronosPipeline
|
||||||
|
except ImportError:
|
||||||
|
from chronos import ChronosPipeline
|
||||||
|
pipeline_cls = ChronosPipeline
|
||||||
|
|
||||||
|
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
# Always use float32 — Korean stock prices (e.g. 280,000원) exceed FP16 max (~65,504)
|
||||||
|
# causing inf in quantile output. FP32 is safe for typical price magnitudes.
|
||||||
|
dtype = torch.float32
|
||||||
|
logger.info("Loading Chronos pipeline: %s on %s (cls=%s)",
|
||||||
|
model_name, self._device, pipeline_cls.__name__)
|
||||||
|
# Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
|
||||||
|
try:
|
||||||
|
self._pipeline = pipeline_cls.from_pretrained(
|
||||||
|
model_name, device_map=self._device, dtype=dtype,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
self._pipeline = pipeline_cls.from_pretrained(
|
||||||
|
model_name, device_map=self._device, torch_dtype=dtype,
|
||||||
|
)
|
||||||
|
logger.info("Chronos pipeline loaded.")
|
||||||
|
|
||||||
|
def predict_batch(
|
||||||
|
self,
|
||||||
|
daily_ohlcv_dict: dict[str, list[dict]],
|
||||||
|
prediction_length: int = 1,
|
||||||
|
num_samples: int = 100,
|
||||||
|
) -> dict[str, ChronosPrediction]:
|
||||||
|
"""종목별 1-day return 분포 예측.
|
||||||
|
|
||||||
|
ChronosBolt / Chronos-2 등 신모델은 predict_quantiles 사용 (deterministic).
|
||||||
|
Legacy ChronosPipeline (T5) 는 sample-based predict.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
tickers = list(daily_ohlcv_dict.keys())
|
||||||
|
if not tickers:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32)
|
||||||
|
for t in tickers
|
||||||
|
]
|
||||||
|
now_iso = datetime.now(KST).isoformat()
|
||||||
|
results: dict[str, ChronosPrediction] = {}
|
||||||
|
|
||||||
|
# Modern API: predict_quantiles (ChronosBolt / Chronos-2)
|
||||||
|
if hasattr(self._pipeline, "predict_quantiles"):
|
||||||
|
quantile_levels = [0.1, 0.5, 0.9]
|
||||||
|
# ChronosBolt API: positional `inputs` (first arg). Older variants use `context`.
|
||||||
|
try:
|
||||||
|
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
||||||
|
contexts,
|
||||||
|
prediction_length=prediction_length,
|
||||||
|
quantile_levels=quantile_levels,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
||||||
|
context=contexts,
|
||||||
|
prediction_length=prediction_length,
|
||||||
|
quantile_levels=quantile_levels,
|
||||||
|
)
|
||||||
|
quantiles_np = (
|
||||||
|
quantiles_tensor.cpu().numpy()
|
||||||
|
if hasattr(quantiles_tensor, "cpu")
|
||||||
|
else np.asarray(quantiles_tensor)
|
||||||
|
)
|
||||||
|
# shape: [num_series, prediction_length, 3]
|
||||||
|
for i, ticker in enumerate(tickers):
|
||||||
|
q10_price, q50_price, q90_price = quantiles_np[i, 0, :]
|
||||||
|
last_close = daily_ohlcv_dict[ticker][-1]["close"]
|
||||||
|
median = float((q50_price - last_close) / last_close)
|
||||||
|
q10 = float((q10_price - last_close) / last_close)
|
||||||
|
q90 = float((q90_price - last_close) / last_close)
|
||||||
|
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||||
|
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||||
|
results[ticker] = ChronosPrediction(
|
||||||
|
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Legacy API: sample-based predict (ChronosPipeline T5)
|
||||||
|
forecasts = self._pipeline.predict(
|
||||||
|
context=contexts,
|
||||||
|
prediction_length=prediction_length,
|
||||||
|
num_samples=num_samples,
|
||||||
|
)
|
||||||
|
forecasts_np = forecasts.numpy() if hasattr(forecasts, "numpy") else np.asarray(forecasts)
|
||||||
|
for i, ticker in enumerate(tickers):
|
||||||
|
samples = forecasts_np[i, :, 0]
|
||||||
|
last_close = daily_ohlcv_dict[ticker][-1]["close"]
|
||||||
|
returns = (samples - last_close) / last_close
|
||||||
|
median = float(np.quantile(returns, 0.5))
|
||||||
|
q10 = float(np.quantile(returns, 0.1))
|
||||||
|
q90 = float(np.quantile(returns, 0.9))
|
||||||
|
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||||
|
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||||
|
results[ticker] = ChronosPrediction(
|
||||||
|
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||||
|
)
|
||||||
|
return results
|
||||||
75
signal_v2/config.py
Normal file
75
signal_v2/config.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""Signal V2 환경변수 로딩."""
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv(Path(__file__).parent.parent / ".env")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Settings:
|
||||||
|
stock_api_url: str = field(
|
||||||
|
default_factory=lambda: os.getenv("STOCK_API_URL", "").rstrip("/")
|
||||||
|
)
|
||||||
|
webai_api_key: str = field(
|
||||||
|
default_factory=lambda: os.getenv("WEBAI_API_KEY", "").strip()
|
||||||
|
)
|
||||||
|
port: int = field(default_factory=lambda: int(os.getenv("SIGNAL_V2_PORT", "8001")))
|
||||||
|
db_path: Path = field(
|
||||||
|
default_factory=lambda: Path(__file__).parent / "data" / "signal_v2.db"
|
||||||
|
)
|
||||||
|
# KIS — V1 호환 패턴 (KIS_ENV_TYPE virtual/real)
|
||||||
|
kis_env_type: str = field(default_factory=lambda: os.getenv("KIS_ENV_TYPE", "virtual").lower())
|
||||||
|
kis_real_app_key: str = field(default_factory=lambda: os.getenv("KIS_REAL_APP_KEY", "").strip())
|
||||||
|
kis_real_app_secret: str = field(default_factory=lambda: os.getenv("KIS_REAL_APP_SECRET", "").strip())
|
||||||
|
kis_real_account: str = field(default_factory=lambda: os.getenv("KIS_REAL_ACCOUNT", "").strip())
|
||||||
|
kis_virtual_app_key: str = field(default_factory=lambda: os.getenv("KIS_VIRTUAL_APP_KEY", "").strip())
|
||||||
|
kis_virtual_app_secret: str = field(default_factory=lambda: os.getenv("KIS_VIRTUAL_APP_SECRET", "").strip())
|
||||||
|
kis_virtual_account: str = field(default_factory=lambda: os.getenv("KIS_VIRTUAL_ACCOUNT", "").strip())
|
||||||
|
v1_token_path: Path = field(
|
||||||
|
default_factory=lambda: Path(
|
||||||
|
os.getenv("V1_TOKEN_PATH",
|
||||||
|
str(Path(__file__).parent.parent / "signal_v1" / "data" / "kis_token.json"))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
chronos_model: str = field(default_factory=lambda: os.getenv("CHRONOS_MODEL", "amazon/chronos-2"))
|
||||||
|
stop_loss_pct: float = field(
|
||||||
|
default_factory=lambda: float(os.getenv("STOP_LOSS_PCT", "-0.07"))
|
||||||
|
)
|
||||||
|
take_profit_pct: float = field(
|
||||||
|
default_factory=lambda: float(os.getenv("TAKE_PROFIT_PCT", "0.15"))
|
||||||
|
)
|
||||||
|
chronos_spread_threshold: float = field(
|
||||||
|
default_factory=lambda: float(os.getenv("CHRONOS_SPREAD_THRESHOLD", "0.6"))
|
||||||
|
)
|
||||||
|
asking_bid_ratio_threshold: float = field(
|
||||||
|
default_factory=lambda: float(os.getenv("ASKING_BID_RATIO_THRESHOLD", "0.6"))
|
||||||
|
)
|
||||||
|
confidence_threshold: float = field(
|
||||||
|
default_factory=lambda: float(os.getenv("CONFIDENCE_THRESHOLD", "0.7"))
|
||||||
|
)
|
||||||
|
min_momentum_for_buy: str = field(
|
||||||
|
default_factory=lambda: os.getenv("MIN_MOMENTUM_FOR_BUY", "strong_up")
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kis_is_virtual(self) -> bool:
|
||||||
|
return self.kis_env_type != "real"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kis_app_key(self) -> str:
|
||||||
|
return self.kis_real_app_key if self.kis_env_type == "real" else self.kis_virtual_app_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kis_app_secret(self) -> str:
|
||||||
|
return self.kis_real_app_secret if self.kis_env_type == "real" else self.kis_virtual_app_secret
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kis_account(self) -> str:
|
||||||
|
return self.kis_real_account if self.kis_env_type == "real" else self.kis_virtual_account
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
0
signal_v2/data/.gitkeep
Normal file
0
signal_v2/data/.gitkeep
Normal file
18
signal_v2/holidays.json
Normal file
18
signal_v2/holidays.json
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
[
|
||||||
|
"2026-01-01",
|
||||||
|
"2026-01-28",
|
||||||
|
"2026-01-29",
|
||||||
|
"2026-01-30",
|
||||||
|
"2026-03-01",
|
||||||
|
"2026-05-05",
|
||||||
|
"2026-05-25",
|
||||||
|
"2026-06-06",
|
||||||
|
"2026-08-15",
|
||||||
|
"2026-09-24",
|
||||||
|
"2026-09-25",
|
||||||
|
"2026-09-26",
|
||||||
|
"2026-10-03",
|
||||||
|
"2026-10-09",
|
||||||
|
"2026-12-25",
|
||||||
|
"2026-12-31"
|
||||||
|
]
|
||||||
193
signal_v2/kis_client.py
Normal file
193
signal_v2/kis_client.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""KIS REST API client — 분봉 + 호가. V1 토큰 read-only 공유."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
KST = ZoneInfo("Asia/Seoul")
|
||||||
|
|
||||||
|
_MAX_ATTEMPTS = 3
|
||||||
|
_THROTTLE_INTERVAL = 0.5 # 초당 2회 제한
|
||||||
|
|
||||||
|
|
||||||
|
class KISClient:
|
||||||
|
"""KIS REST (분봉 + 호가). V1 토큰 파일 read-only."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app_key: str, app_secret: str, account: str, is_virtual: bool,
|
||||||
|
v1_token_path: Path,
|
||||||
|
timeout: float = 10.0,
|
||||||
|
):
|
||||||
|
self._app_key = app_key
|
||||||
|
self._app_secret = app_secret
|
||||||
|
self._account = account
|
||||||
|
self._is_virtual = is_virtual
|
||||||
|
self._v1_token_path = Path(v1_token_path)
|
||||||
|
self._base_url = (
|
||||||
|
"https://openapivts.koreainvestment.com:29443" if is_virtual
|
||||||
|
else "https://openapi.koreainvestment.com:9443"
|
||||||
|
)
|
||||||
|
self._client = httpx.AsyncClient(timeout=timeout)
|
||||||
|
self._token_cache: tuple[str, float] | None = None # (token, file_mtime)
|
||||||
|
self._last_throttle_at = 0.0
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await self._client.aclose()
|
||||||
|
|
||||||
|
def _read_v1_token(self) -> str:
|
||||||
|
if not self._v1_token_path.exists():
|
||||||
|
raise RuntimeError(f"V1 token file missing: {self._v1_token_path}")
|
||||||
|
mtime = self._v1_token_path.stat().st_mtime
|
||||||
|
if self._token_cache and self._token_cache[1] == mtime:
|
||||||
|
return self._token_cache[0]
|
||||||
|
data = json.loads(self._v1_token_path.read_text(encoding="utf-8"))
|
||||||
|
token = data.get("access_token", "")
|
||||||
|
if not token:
|
||||||
|
raise RuntimeError("V1 token file has no access_token")
|
||||||
|
self._token_cache = (token, mtime)
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def _throttle(self) -> None:
|
||||||
|
elapsed = time.monotonic() - self._last_throttle_at
|
||||||
|
if elapsed < _THROTTLE_INTERVAL:
|
||||||
|
await asyncio.sleep(_THROTTLE_INTERVAL - elapsed)
|
||||||
|
self._last_throttle_at = time.monotonic()
|
||||||
|
|
||||||
|
def _common_headers(self, tr_id: str) -> dict[str, str]:
|
||||||
|
token = self._read_v1_token()
|
||||||
|
return {
|
||||||
|
"authorization": f"Bearer {token}",
|
||||||
|
"appkey": self._app_key,
|
||||||
|
"appsecret": self._app_secret,
|
||||||
|
"tr_id": tr_id,
|
||||||
|
"custtype": "P",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _request_with_retry(
|
||||||
|
self, method: str, path: str, tr_id: str, **kwargs,
|
||||||
|
) -> dict:
|
||||||
|
url = f"{self._base_url}{path}"
|
||||||
|
headers = self._common_headers(tr_id)
|
||||||
|
for attempt in range(_MAX_ATTEMPTS):
|
||||||
|
await self._throttle()
|
||||||
|
try:
|
||||||
|
response = await self._client.request(
|
||||||
|
method, url, headers=headers, **kwargs
|
||||||
|
)
|
||||||
|
if response.status_code == 429:
|
||||||
|
if attempt < _MAX_ATTEMPTS - 1:
|
||||||
|
await asyncio.sleep(2**attempt)
|
||||||
|
continue
|
||||||
|
response.raise_for_status()
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
if attempt < _MAX_ATTEMPTS - 1:
|
||||||
|
await asyncio.sleep(2**attempt)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
raise RuntimeError("retry exhausted")
|
||||||
|
|
||||||
|
async def get_minute_ohlcv(self, ticker: str) -> list[dict]:
|
||||||
|
"""현재 시점 직전 30개 1분봉 OHLCV (TR_ID FHKST03010200)."""
|
||||||
|
path = "/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||||
|
params = {
|
||||||
|
"FID_ETC_CLS_CODE": "",
|
||||||
|
"FID_COND_MRKT_DIV_CODE": "J",
|
||||||
|
"FID_INPUT_ISCD": ticker,
|
||||||
|
"FID_INPUT_HOUR_1": datetime.now(KST).strftime("%H%M%S"),
|
||||||
|
"FID_PW_DATA_INCU_YN": "N",
|
||||||
|
}
|
||||||
|
raw = await self._request_with_retry(
|
||||||
|
"GET", path, tr_id="FHKST03010200", params=params,
|
||||||
|
)
|
||||||
|
output2 = raw.get("output2", [])
|
||||||
|
bars = []
|
||||||
|
for row in output2:
|
||||||
|
try:
|
||||||
|
date = row["stck_bsop_date"]
|
||||||
|
hhmmss = row["stck_cntg_hour"]
|
||||||
|
dt = datetime.strptime(f"{date} {hhmmss}", "%Y%m%d %H%M%S").replace(tzinfo=KST)
|
||||||
|
bars.append({
|
||||||
|
"datetime": dt.isoformat(),
|
||||||
|
"open": int(row["stck_oprc"]),
|
||||||
|
"high": int(row["stck_hgpr"]),
|
||||||
|
"low": int(row["stck_lwpr"]),
|
||||||
|
"close": int(row["stck_prpr"]),
|
||||||
|
"volume": int(row["cntg_vol"]),
|
||||||
|
})
|
||||||
|
except (KeyError, ValueError) as e:
|
||||||
|
logger.warning("skip malformed bar for %s: %r", ticker, e)
|
||||||
|
# KIS returns descending; reverse to ascending (most recent last)
|
||||||
|
bars.reverse()
|
||||||
|
return bars
|
||||||
|
|
||||||
|
async def get_asking_price(self, ticker: str) -> dict:
|
||||||
|
"""현재 호가 + 매수/매도 잔량 (TR_ID FHKST01010200)."""
|
||||||
|
path = "/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
|
||||||
|
params = {
|
||||||
|
"FID_COND_MRKT_DIV_CODE": "J",
|
||||||
|
"FID_INPUT_ISCD": ticker,
|
||||||
|
}
|
||||||
|
raw = await self._request_with_retry(
|
||||||
|
"GET", path, tr_id="FHKST01010200", params=params,
|
||||||
|
)
|
||||||
|
output1 = raw.get("output1", {})
|
||||||
|
bid_total = int(output1.get("total_bidp_rsqn", 0))
|
||||||
|
ask_total = int(output1.get("total_askp_rsqn", 0))
|
||||||
|
total = bid_total + ask_total
|
||||||
|
bid_ratio = bid_total / total if total > 0 else 0.0
|
||||||
|
current_price = int(output1.get("stck_prpr", 0))
|
||||||
|
return {
|
||||||
|
"bid_total": bid_total,
|
||||||
|
"ask_total": ask_total,
|
||||||
|
"bid_ratio": bid_ratio,
|
||||||
|
"current_price": current_price,
|
||||||
|
"as_of": datetime.now(KST).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_daily_ohlcv(self, ticker: str, days: int = 60) -> list[dict]:
|
||||||
|
"""KRX 일봉 OHLCV (TR_ID FHKST03010100).
|
||||||
|
|
||||||
|
Returns: [{"datetime", "open", "high", "low", "close", "volume"}, ...]
|
||||||
|
시간 오름차순.
|
||||||
|
"""
|
||||||
|
path = "/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
|
||||||
|
today = datetime.now(KST).strftime("%Y%m%d")
|
||||||
|
start_date = (datetime.now(KST) - timedelta(days=days * 2)).strftime("%Y%m%d")
|
||||||
|
params = {
|
||||||
|
"FID_COND_MRKT_DIV_CODE": "J",
|
||||||
|
"FID_INPUT_ISCD": ticker,
|
||||||
|
"FID_INPUT_DATE_1": start_date,
|
||||||
|
"FID_INPUT_DATE_2": today,
|
||||||
|
"FID_PERIOD_DIV_CODE": "D",
|
||||||
|
"FID_ORG_ADJ_PRC": "1",
|
||||||
|
}
|
||||||
|
raw = await self._request_with_retry(
|
||||||
|
"GET", path, tr_id="FHKST03010100", params=params,
|
||||||
|
)
|
||||||
|
output2 = raw.get("output2", [])
|
||||||
|
bars = []
|
||||||
|
for row in output2:
|
||||||
|
try:
|
||||||
|
date = row["stck_bsop_date"]
|
||||||
|
bars.append({
|
||||||
|
"datetime": f"{date[:4]}-{date[4:6]}-{date[6:]}",
|
||||||
|
"open": int(row["stck_oprc"]),
|
||||||
|
"high": int(row["stck_hgpr"]),
|
||||||
|
"low": int(row["stck_lwpr"]),
|
||||||
|
"close": int(row["stck_clpr"]),
|
||||||
|
"volume": int(row["acml_vol"]),
|
||||||
|
})
|
||||||
|
except (KeyError, ValueError):
|
||||||
|
continue
|
||||||
|
bars.reverse()
|
||||||
|
return bars[-days:]
|
||||||
186
signal_v2/kis_websocket.py
Normal file
186
signal_v2/kis_websocket.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""KIS WebSocket — approval_key + 실시간 호가 구독."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Callable
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
KST = ZoneInfo("Asia/Seoul")
|
||||||
|
|
||||||
|
# KIS 호가 메시지 필드 인덱스 (운영 환경 검증 필요)
|
||||||
|
# H0STASP0 응답: ticker | time | current_price | ... | ask_total | bid_total
|
||||||
|
# 본 spec/plan 의 가정: 마지막 2개 필드가 ask_total / bid_total
|
||||||
|
_ASKING_TICKER_IDX = 0
|
||||||
|
_ASKING_TIME_IDX = 1
|
||||||
|
_ASKING_CURRENT_PRICE_IDX = 2
|
||||||
|
_ASKING_TOTAL_ASK_IDX = -2
|
||||||
|
_ASKING_TOTAL_BID_IDX = -1
|
||||||
|
|
||||||
|
|
||||||
|
class KISWebSocket:
|
||||||
|
"""KIS WebSocket client. approval_key 발급 + 호가 실시간."""
|
||||||
|
|
||||||
|
def __init__(self, app_key: str, app_secret: str, is_virtual: bool):
|
||||||
|
self._app_key = app_key
|
||||||
|
self._app_secret = app_secret
|
||||||
|
self._is_virtual = is_virtual
|
||||||
|
self._base_rest = (
|
||||||
|
"https://openapivts.koreainvestment.com:29443" if is_virtual
|
||||||
|
else "https://openapi.koreainvestment.com:9443"
|
||||||
|
)
|
||||||
|
self._ws_url = (
|
||||||
|
"ws://ops.koreainvestment.com:31000" if is_virtual
|
||||||
|
else "ws://ops.koreainvestment.com:21000"
|
||||||
|
)
|
||||||
|
self._approval_key: str | None = None
|
||||||
|
self._ws = None
|
||||||
|
self._subscriptions: set[str] = set()
|
||||||
|
self._on_asking_price: Callable[[str, dict], None] | None = None
|
||||||
|
self._recv_task: asyncio.Task | None = None
|
||||||
|
self._shutdown = asyncio.Event()
|
||||||
|
|
||||||
|
async def _fetch_approval_key(self) -> str:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{self._base_rest}/oauth2/Approval",
|
||||||
|
json={
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"appkey": self._app_key,
|
||||||
|
"secretkey": self._app_secret,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
self._approval_key = data["approval_key"]
|
||||||
|
return self._approval_key
|
||||||
|
|
||||||
|
async def _connect(self):
|
||||||
|
return await websockets.connect(self._ws_url)
|
||||||
|
|
||||||
|
async def _connect_with_backoff(self):
|
||||||
|
"""연결 시도 with exponential backoff (1s → 2s → 4s → max 30s)."""
|
||||||
|
for attempt in range(10):
|
||||||
|
try:
|
||||||
|
ws = await self._connect()
|
||||||
|
return ws
|
||||||
|
except Exception as e:
|
||||||
|
wait = min(2**attempt, 30)
|
||||||
|
logger.warning(
|
||||||
|
"KIS WebSocket connect failed (attempt %d): %r — retrying in %ds",
|
||||||
|
attempt + 1, e, wait,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(wait)
|
||||||
|
raise RuntimeError("KIS WebSocket connect exhausted retries")
|
||||||
|
|
||||||
|
async def start(
|
||||||
|
self, tickers: list[str],
|
||||||
|
on_asking_price: Callable[[str, dict], None],
|
||||||
|
) -> None:
|
||||||
|
if self._approval_key is None:
|
||||||
|
await self._fetch_approval_key()
|
||||||
|
self._on_asking_price = on_asking_price
|
||||||
|
self._ws = await self._connect_with_backoff()
|
||||||
|
for ticker in tickers:
|
||||||
|
await self.subscribe(ticker)
|
||||||
|
self._recv_task = asyncio.create_task(self._receive_loop())
|
||||||
|
|
||||||
|
async def subscribe(self, ticker: str) -> None:
|
||||||
|
if self._ws is None or self._approval_key is None:
|
||||||
|
raise RuntimeError("KIS WebSocket not started")
|
||||||
|
msg = json.dumps({
|
||||||
|
"header": {
|
||||||
|
"approval_key": self._approval_key,
|
||||||
|
"custtype": "P",
|
||||||
|
"tr_type": "1",
|
||||||
|
"content-type": "utf-8",
|
||||||
|
},
|
||||||
|
"body": {
|
||||||
|
"input": {"tr_id": "H0STASP0", "tr_key": ticker},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
await self._ws.send(msg)
|
||||||
|
self._subscriptions.add(ticker)
|
||||||
|
|
||||||
|
async def unsubscribe(self, ticker: str) -> None:
|
||||||
|
if self._ws is None or self._approval_key is None:
|
||||||
|
return
|
||||||
|
msg = json.dumps({
|
||||||
|
"header": {
|
||||||
|
"approval_key": self._approval_key,
|
||||||
|
"custtype": "P",
|
||||||
|
"tr_type": "2",
|
||||||
|
"content-type": "utf-8",
|
||||||
|
},
|
||||||
|
"body": {
|
||||||
|
"input": {"tr_id": "H0STASP0", "tr_key": ticker},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
await self._ws.send(msg)
|
||||||
|
self._subscriptions.discard(ticker)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
self._shutdown.set()
|
||||||
|
if self._recv_task is not None:
|
||||||
|
self._recv_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._recv_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
if self._ws is not None:
|
||||||
|
await self._ws.close()
|
||||||
|
|
||||||
|
async def _receive_loop(self) -> None:
|
||||||
|
while not self._shutdown.is_set():
|
||||||
|
try:
|
||||||
|
raw = await self._ws.recv()
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
logger.warning("KIS WebSocket closed — reconnecting")
|
||||||
|
self._ws = await self._connect_with_backoff()
|
||||||
|
for ticker in list(self._subscriptions):
|
||||||
|
await self.subscribe(ticker)
|
||||||
|
continue
|
||||||
|
if not isinstance(raw, str):
|
||||||
|
continue
|
||||||
|
parsed = self._parse_asking_price(raw)
|
||||||
|
if parsed is not None and self._on_asking_price is not None:
|
||||||
|
ticker, data = parsed
|
||||||
|
try:
|
||||||
|
self._on_asking_price(ticker, data)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("on_asking_price callback failed")
|
||||||
|
|
||||||
|
def _parse_asking_price(self, raw: str) -> tuple[str, dict] | None:
|
||||||
|
"""KIS H0STASP0 raw → (ticker, asking_price dict).
|
||||||
|
|
||||||
|
Raw format: '0|H0STASP0|<count>|<data>' where data = '^'-joined fields.
|
||||||
|
Field indices (운영 검증 필요): 마지막 2개 가정 (ask, bid).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parts = raw.split("|")
|
||||||
|
if len(parts) < 4 or parts[1] != "H0STASP0":
|
||||||
|
return None
|
||||||
|
fields = parts[3].split("^")
|
||||||
|
ticker = fields[_ASKING_TICKER_IDX]
|
||||||
|
current_price_str = fields[_ASKING_CURRENT_PRICE_IDX]
|
||||||
|
current_price = int(current_price_str) if current_price_str.lstrip("-").isdigit() else 0
|
||||||
|
ask_str = fields[_ASKING_TOTAL_ASK_IDX]
|
||||||
|
bid_str = fields[_ASKING_TOTAL_BID_IDX]
|
||||||
|
ask_total = int(ask_str) if ask_str.lstrip("-").isdigit() else 0
|
||||||
|
bid_total = int(bid_str) if bid_str.lstrip("-").isdigit() else 0
|
||||||
|
total = bid_total + ask_total
|
||||||
|
return ticker, {
|
||||||
|
"bid_total": bid_total,
|
||||||
|
"ask_total": ask_total,
|
||||||
|
"bid_ratio": bid_total / total if total > 0 else 0.0,
|
||||||
|
"current_price": current_price,
|
||||||
|
"as_of": datetime.now(KST).isoformat(),
|
||||||
|
}
|
||||||
|
except (IndexError, ValueError) as e:
|
||||||
|
logger.warning("parse_asking_price failed: %r", e)
|
||||||
|
return None
|
||||||
125
signal_v2/main.py
Normal file
125
signal_v2/main.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""FastAPI app — Signal V2 Pull Worker."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from signal_v2 import state as state_mod
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor
|
||||||
|
from signal_v2.config import get_settings
|
||||||
|
from signal_v2.kis_client import KISClient
|
||||||
|
from signal_v2.kis_websocket import KISWebSocket
|
||||||
|
from signal_v2.pull_worker import poll_loop, make_asking_price_callback
|
||||||
|
from signal_v2.rate_limit import SignalDedup
|
||||||
|
from signal_v2.stock_client import StockClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AppContext:
|
||||||
|
client: StockClient | None = None
|
||||||
|
dedup: SignalDedup | None = None
|
||||||
|
shutdown: asyncio.Event | None = None
|
||||||
|
poll_task: asyncio.Task | None = None
|
||||||
|
kis_client: KISClient | None = None
|
||||||
|
kis_ws: KISWebSocket | None = None
|
||||||
|
chronos: ChronosPredictor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
_ctx = AppContext()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
settings = get_settings()
|
||||||
|
if not settings.webai_api_key:
|
||||||
|
logger.warning(
|
||||||
|
"WEBAI_API_KEY not configured — stock API calls will fail with 401"
|
||||||
|
)
|
||||||
|
if not settings.kis_app_key:
|
||||||
|
logger.warning(
|
||||||
|
"KIS app_key not configured (KIS_ENV_TYPE=%s, KIS_%s_APP_KEY missing) — KIS REST/WebSocket disabled",
|
||||||
|
settings.kis_env_type, settings.kis_env_type.upper()
|
||||||
|
)
|
||||||
|
|
||||||
|
_ctx.client = StockClient(settings.stock_api_url, settings.webai_api_key)
|
||||||
|
_ctx.dedup = SignalDedup(settings.db_path)
|
||||||
|
_ctx.shutdown = asyncio.Event()
|
||||||
|
|
||||||
|
# KIS only if app_key configured
|
||||||
|
if settings.kis_app_key:
|
||||||
|
_ctx.kis_client = KISClient(
|
||||||
|
app_key=settings.kis_app_key,
|
||||||
|
app_secret=settings.kis_app_secret,
|
||||||
|
account=settings.kis_account,
|
||||||
|
is_virtual=settings.kis_is_virtual,
|
||||||
|
v1_token_path=settings.v1_token_path,
|
||||||
|
)
|
||||||
|
_ctx.kis_ws = KISWebSocket(
|
||||||
|
app_key=settings.kis_app_key,
|
||||||
|
app_secret=settings.kis_app_secret,
|
||||||
|
is_virtual=settings.kis_is_virtual,
|
||||||
|
)
|
||||||
|
# Subscribe portfolio holdings (if any)
|
||||||
|
try:
|
||||||
|
portfolio = await _ctx.client.get_portfolio()
|
||||||
|
tickers = [h["ticker"] for h in portfolio.get("holdings", []) if "ticker" in h]
|
||||||
|
cb = make_asking_price_callback(state_mod.state)
|
||||||
|
await _ctx.kis_ws.start(tickers, cb)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("KIS WebSocket startup failed — continuing without realtime asking_price")
|
||||||
|
|
||||||
|
# Load Chronos (heavy: ~1GB model download first time)
|
||||||
|
try:
|
||||||
|
_ctx.chronos = ChronosPredictor(model_name=settings.chronos_model)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("ChronosPredictor load failed — continuing without chronos predictions")
|
||||||
|
|
||||||
|
_ctx.poll_task = asyncio.create_task(
|
||||||
|
poll_loop(
|
||||||
|
_ctx.client, state_mod.state, _ctx.shutdown,
|
||||||
|
kis_client=_ctx.kis_client,
|
||||||
|
chronos=_ctx.chronos,
|
||||||
|
dedup=_ctx.dedup,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown
|
||||||
|
if _ctx.shutdown is not None:
|
||||||
|
_ctx.shutdown.set()
|
||||||
|
if _ctx.poll_task is not None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(_ctx.poll_task, timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
_ctx.poll_task.cancel()
|
||||||
|
try:
|
||||||
|
await _ctx.poll_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
if _ctx.kis_ws is not None:
|
||||||
|
await _ctx.kis_ws.close()
|
||||||
|
if _ctx.kis_client is not None:
|
||||||
|
await _ctx.kis_client.close()
|
||||||
|
if _ctx.client is not None:
|
||||||
|
await _ctx.client.close()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Signal V2 Pull Worker", version="0.1.0", lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
settings = get_settings()
|
||||||
|
return {
|
||||||
|
"status": "online",
|
||||||
|
"stock_api_url": settings.stock_api_url,
|
||||||
|
"last_poll": state_mod.state.last_updated,
|
||||||
|
"cache_size": _ctx.client.cache_size() if _ctx.client is not None else 0,
|
||||||
|
}
|
||||||
69
signal_v2/momentum_classifier.py
Normal file
69
signal_v2/momentum_classifier.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""분봉 OHLCV → 5-level 모멘텀 분류."""
|
||||||
|
from __future__ import annotations
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
# 분류 카테고리
|
||||||
|
STRONG_UP = "strong_up"
|
||||||
|
WEAK_UP = "weak_up"
|
||||||
|
NEUTRAL = "neutral"
|
||||||
|
WEAK_DOWN = "weak_down"
|
||||||
|
STRONG_DOWN = "strong_down"
|
||||||
|
|
||||||
|
_BARS_PER_5MIN = 5
|
||||||
|
_LOOKBACK_5MIN_BARS = 5
|
||||||
|
_VOLUME_AVG_WINDOW = 12 # 60분 = 5분봉 12개
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_1min_to_5min(minute_bars: list[dict]) -> list[dict]:
|
||||||
|
"""1분봉 N개 → 5분봉 floor(N/5) 개. 시간 오름차순.
|
||||||
|
|
||||||
|
각 5분봉: open=첫 1분봉 open, high=max, low=min, close=마지막 close, volume=sum.
|
||||||
|
"""
|
||||||
|
bars_5min = []
|
||||||
|
chunks = len(minute_bars) // _BARS_PER_5MIN
|
||||||
|
for i in range(chunks):
|
||||||
|
chunk = minute_bars[i * _BARS_PER_5MIN : (i + 1) * _BARS_PER_5MIN]
|
||||||
|
bars_5min.append({
|
||||||
|
"datetime": chunk[0]["datetime"],
|
||||||
|
"open": chunk[0]["open"],
|
||||||
|
"high": max(b["high"] for b in chunk),
|
||||||
|
"low": min(b["low"] for b in chunk),
|
||||||
|
"close": chunk[-1]["close"],
|
||||||
|
"volume": sum(b["volume"] for b in chunk),
|
||||||
|
})
|
||||||
|
return bars_5min
|
||||||
|
|
||||||
|
|
||||||
|
def classify_minute_momentum(minute_bars: deque) -> str:
|
||||||
|
"""1분봉 deque → 5-level 모멘텀 분류.
|
||||||
|
|
||||||
|
Returns: STRONG_UP / WEAK_UP / NEUTRAL / WEAK_DOWN / STRONG_DOWN
|
||||||
|
"""
|
||||||
|
minute_list = list(minute_bars)
|
||||||
|
if len(minute_list) < _BARS_PER_5MIN * _LOOKBACK_5MIN_BARS:
|
||||||
|
return NEUTRAL # 데이터 부족
|
||||||
|
|
||||||
|
bars_5min = aggregate_1min_to_5min(minute_list)
|
||||||
|
if len(bars_5min) < _LOOKBACK_5MIN_BARS:
|
||||||
|
return NEUTRAL
|
||||||
|
|
||||||
|
recent = bars_5min[-_LOOKBACK_5MIN_BARS:]
|
||||||
|
up_count = sum(1 for b in recent if b["close"] > b["open"])
|
||||||
|
|
||||||
|
# 거래량 multiplier: recent 5 avg vs 60분 avg
|
||||||
|
recent_vol_avg = sum(b["volume"] for b in recent) / len(recent)
|
||||||
|
long_window = bars_5min[-_VOLUME_AVG_WINDOW:]
|
||||||
|
long_vol_avg = sum(b["volume"] for b in long_window) / len(long_window)
|
||||||
|
vol_mult = recent_vol_avg / long_vol_avg if long_vol_avg > 0 else 1.0
|
||||||
|
|
||||||
|
# 5-level 분류
|
||||||
|
if up_count == 5 and vol_mult >= 1.5:
|
||||||
|
return STRONG_UP
|
||||||
|
elif up_count >= 3 and vol_mult >= 1.0:
|
||||||
|
return WEAK_UP
|
||||||
|
elif up_count == 0 and vol_mult >= 1.5:
|
||||||
|
return STRONG_DOWN
|
||||||
|
elif up_count <= 2 and vol_mult < 1.0:
|
||||||
|
return WEAK_DOWN
|
||||||
|
else:
|
||||||
|
return NEUTRAL
|
||||||
193
signal_v2/pull_worker.py
Normal file
193
signal_v2/pull_worker.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""Polling loop — async cron + state update."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from signal_v2.kis_client import KISClient
|
||||||
|
from signal_v2.scheduler import (
|
||||||
|
KST, _is_market_day, _is_polling_window, _next_interval, _is_post_close_trigger,
|
||||||
|
)
|
||||||
|
from signal_v2.state import PollState
|
||||||
|
from signal_v2.stock_client import StockClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def poll_loop(
|
||||||
|
client: StockClient, state: PollState, shutdown: asyncio.Event,
|
||||||
|
kis_client: KISClient | None = None,
|
||||||
|
chronos=None,
|
||||||
|
dedup=None,
|
||||||
|
settings=None,
|
||||||
|
) -> None:
|
||||||
|
"""FastAPI lifespan 에서 asyncio.create_task 로 시작."""
|
||||||
|
logger.info("poll_loop started")
|
||||||
|
while not shutdown.is_set():
|
||||||
|
now = datetime.now(KST)
|
||||||
|
if _is_market_day(now) and _is_polling_window(now):
|
||||||
|
try:
|
||||||
|
await _run_polling_cycle(client, state, kis_client=kis_client)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("poll cycle failed")
|
||||||
|
# Minute momentum 갱신 (매 cycle)
|
||||||
|
try:
|
||||||
|
update_minute_momentum_for_all(state)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("minute momentum update failed")
|
||||||
|
# Post-close trigger (16:00 KST)
|
||||||
|
if _is_post_close_trigger(now) and chronos is not None and kis_client is not None:
|
||||||
|
try:
|
||||||
|
await _run_post_close_cycle(kis_client, chronos, state)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("post-close cycle failed")
|
||||||
|
# Phase 4: generate signals
|
||||||
|
if dedup is not None and settings is not None:
|
||||||
|
try:
|
||||||
|
from signal_v2.signal_generator import generate_signals
|
||||||
|
generate_signals(state, dedup, settings)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("generate_signals failed")
|
||||||
|
interval = _next_interval(now)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(shutdown.wait(), timeout=interval)
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
logger.info("poll_loop ended")
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_polling_cycle(
|
||||||
|
client: StockClient, state: PollState,
|
||||||
|
kis_client: KISClient | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""기존 3 endpoint (stock) + KIS 분봉 fetch."""
|
||||||
|
portfolio, sentiment, screener = await asyncio.gather(
|
||||||
|
client.get_portfolio(),
|
||||||
|
client.get_news_sentiment(),
|
||||||
|
client.run_screener_preview(),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
now_iso = datetime.now(KST).isoformat()
|
||||||
|
|
||||||
|
for name, result in (
|
||||||
|
("portfolio", portfolio),
|
||||||
|
("news_sentiment", sentiment),
|
||||||
|
("screener_preview", screener),
|
||||||
|
):
|
||||||
|
if isinstance(result, dict):
|
||||||
|
setattr(state, name, result)
|
||||||
|
state.last_updated[name] = now_iso
|
||||||
|
state.fetch_errors[name] = 0
|
||||||
|
else:
|
||||||
|
state.fetch_errors[name] = state.fetch_errors.get(name, 0) + 1
|
||||||
|
logger.warning("fetch %s failed: %r", name, result)
|
||||||
|
|
||||||
|
# KIS 분봉 + 호가 (kis_client 주어졌을 때만)
|
||||||
|
if kis_client is not None:
|
||||||
|
try:
|
||||||
|
await _run_kis_minute_cycle(kis_client, state)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("kis minute cycle failed")
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_kis_minute_cycle(kis_client: KISClient, state: PollState) -> None:
|
||||||
|
"""KIS 분봉 + 호가 fetch + state 갱신.
|
||||||
|
|
||||||
|
- 분봉: portfolio + screener Top-N union 종목 모두
|
||||||
|
- 호가 (REST): screener-only 종목 (portfolio 는 WebSocket 으로 들어옴)
|
||||||
|
"""
|
||||||
|
portfolio_tickers = _portfolio_tickers(state)
|
||||||
|
screener_tickers = _screener_tickers(state)
|
||||||
|
all_tickers = list(set(portfolio_tickers) | set(screener_tickers))
|
||||||
|
|
||||||
|
# 분봉 fetch (병렬)
|
||||||
|
minute_results = await asyncio.gather(*[
|
||||||
|
kis_client.get_minute_ohlcv(t) for t in all_tickers
|
||||||
|
], return_exceptions=True)
|
||||||
|
now_iso = datetime.now(KST).isoformat()
|
||||||
|
for ticker, result in zip(all_tickers, minute_results):
|
||||||
|
if isinstance(result, list):
|
||||||
|
buf = state.minute_bars.setdefault(ticker, deque(maxlen=60))
|
||||||
|
buf.extend(result)
|
||||||
|
state.last_updated[f"minute_bars/{ticker}"] = now_iso
|
||||||
|
else:
|
||||||
|
state.fetch_errors[f"minute_bars/{ticker}"] = (
|
||||||
|
state.fetch_errors.get(f"minute_bars/{ticker}", 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 호가 fetch (REST) — screener-only
|
||||||
|
screener_only = list(set(screener_tickers) - set(portfolio_tickers))
|
||||||
|
asking_results = await asyncio.gather(*[
|
||||||
|
kis_client.get_asking_price(t) for t in screener_only
|
||||||
|
], return_exceptions=True)
|
||||||
|
for ticker, result in zip(screener_only, asking_results):
|
||||||
|
if isinstance(result, dict):
|
||||||
|
state.asking_price[ticker] = result
|
||||||
|
state.last_updated[f"asking_price/{ticker}"] = now_iso
|
||||||
|
|
||||||
|
|
||||||
|
def make_asking_price_callback(state: PollState):
|
||||||
|
"""KIS WebSocket on_asking_price callback factory."""
|
||||||
|
def _cb(ticker: str, data: dict) -> None:
|
||||||
|
state.asking_price[ticker] = data
|
||||||
|
state.last_updated[f"asking_price/{ticker}"] = datetime.now(KST).isoformat()
|
||||||
|
return _cb
|
||||||
|
|
||||||
|
|
||||||
|
def _portfolio_tickers(state: PollState) -> list[str]:
|
||||||
|
if state.portfolio is None:
|
||||||
|
return []
|
||||||
|
return [h["ticker"] for h in state.portfolio.get("holdings", []) if "ticker" in h]
|
||||||
|
|
||||||
|
|
||||||
|
def _screener_tickers(state: PollState) -> list[str]:
|
||||||
|
if state.screener_preview is None:
|
||||||
|
return []
|
||||||
|
return [i["ticker"] for i in state.screener_preview.get("items", []) if "ticker" in i]
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_post_close_cycle(kis_client, chronos, state) -> None:
|
||||||
|
"""16:00 KST 종가 후 1회: daily fetch + chronos predict."""
|
||||||
|
tickers = list(set(_portfolio_tickers(state)) | set(_screener_tickers(state)))
|
||||||
|
if not tickers:
|
||||||
|
return
|
||||||
|
|
||||||
|
daily_results = await asyncio.gather(*[
|
||||||
|
kis_client.get_daily_ohlcv(t, days=60) for t in tickers
|
||||||
|
], return_exceptions=True)
|
||||||
|
daily_dict = {}
|
||||||
|
for ticker, result in zip(tickers, daily_results):
|
||||||
|
if isinstance(result, list) and len(result) >= 30:
|
||||||
|
daily_dict[ticker] = result
|
||||||
|
state.daily_ohlcv[ticker] = result
|
||||||
|
elif isinstance(result, Exception):
|
||||||
|
state.fetch_errors[f"daily_ohlcv/{ticker}"] = (
|
||||||
|
state.fetch_errors.get(f"daily_ohlcv/{ticker}", 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
if daily_dict and chronos is not None:
|
||||||
|
try:
|
||||||
|
predictions = chronos.predict_batch(daily_dict)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("chronos predict_batch failed")
|
||||||
|
return
|
||||||
|
for ticker, pred in predictions.items():
|
||||||
|
state.chronos_predictions[ticker] = {
|
||||||
|
"median": pred.median,
|
||||||
|
"q10": pred.q10,
|
||||||
|
"q90": pred.q90,
|
||||||
|
"conf": pred.conf,
|
||||||
|
"as_of": pred.as_of,
|
||||||
|
}
|
||||||
|
state.last_updated[f"chronos/{ticker}"] = pred.as_of
|
||||||
|
|
||||||
|
|
||||||
|
def update_minute_momentum_for_all(state) -> None:
|
||||||
|
"""매 분봉 cycle 후 호출 — 모든 종목 모멘텀 갱신."""
|
||||||
|
from signal_v2.momentum_classifier import classify_minute_momentum
|
||||||
|
now_iso = datetime.now(KST).isoformat()
|
||||||
|
for ticker, bars in state.minute_bars.items():
|
||||||
|
state.minute_momentum[ticker] = classify_minute_momentum(bars)
|
||||||
|
state.last_updated[f"momentum/{ticker}"] = now_iso
|
||||||
3
signal_v2/pytest.ini
Normal file
3
signal_v2/pytest.ini
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
|
testpaths = tests
|
||||||
73
signal_v2/rate_limit.py
Normal file
73
signal_v2/rate_limit.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""SignalDedup — SQLite-backed 24h duplicate signal blocker."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import sqlite3
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
KST = ZoneInfo("Asia/Seoul")
|
||||||
|
|
||||||
|
|
||||||
|
def _now_iso() -> str:
|
||||||
|
"""Test seam — overridable via monkeypatch."""
|
||||||
|
return datetime.now(KST).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS signal_dedup (
|
||||||
|
ticker TEXT NOT NULL,
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
last_sent TEXT NOT NULL,
|
||||||
|
confidence REAL NOT NULL,
|
||||||
|
PRIMARY KEY (ticker, action)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_signal_dedup_last_sent
|
||||||
|
ON signal_dedup(last_sent);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SignalDedup:
|
||||||
|
"""24h dedup interface. WAL + busy_timeout=120000."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: Path):
|
||||||
|
self._db_path = Path(db_path)
|
||||||
|
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._init_schema()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _conn(self):
|
||||||
|
conn = sqlite3.connect(self._db_path, timeout=120.0)
|
||||||
|
try:
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA busy_timeout=120000")
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _init_schema(self) -> None:
|
||||||
|
with self._conn() as conn:
|
||||||
|
conn.executescript(_SCHEMA)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def is_recent(self, ticker: str, action: str, within_hours: int = 24) -> bool:
|
||||||
|
threshold_dt = datetime.fromisoformat(_now_iso()) - timedelta(hours=within_hours)
|
||||||
|
threshold_iso = threshold_dt.isoformat()
|
||||||
|
with self._conn() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT last_sent FROM signal_dedup WHERE ticker = ? AND action = ?",
|
||||||
|
(ticker, action),
|
||||||
|
).fetchone()
|
||||||
|
return row is not None and row[0] >= threshold_iso
|
||||||
|
|
||||||
|
def record(self, ticker: str, action: str, confidence: float) -> None:
|
||||||
|
with self._conn() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""INSERT INTO signal_dedup (ticker, action, last_sent, confidence)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
ON CONFLICT (ticker, action) DO UPDATE
|
||||||
|
SET last_sent = excluded.last_sent,
|
||||||
|
confidence = excluded.confidence""",
|
||||||
|
(ticker, action, _now_iso(), confidence),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
99
signal_v2/scheduler.py
Normal file
99
signal_v2/scheduler.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Polling scheduler — 시간대별 분기 + 휴장일 처리."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, time
|
||||||
|
from pathlib import Path
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KST = ZoneInfo("Asia/Seoul")
|
||||||
|
_HOLIDAYS_PATH = Path(__file__).parent / "holidays.json"
|
||||||
|
_HOLIDAYS: set[str] = set(json.loads(_HOLIDAYS_PATH.read_text(encoding="utf-8")))
|
||||||
|
|
||||||
|
# Market windows (정규장)
|
||||||
|
_PRE_OPEN = time(7, 0)
|
||||||
|
_OPEN = time(9, 0)
|
||||||
|
_CLOSE = time(15, 30)
|
||||||
|
_POST_END = time(20, 0)
|
||||||
|
|
||||||
|
# NXT windows (시간외)
|
||||||
|
_NXT_PRE_END = time(23, 30)
|
||||||
|
_NXT_POST_OPEN = time(4, 30)
|
||||||
|
# 23:30 - 04:30 (dead zone) skip
|
||||||
|
|
||||||
|
|
||||||
|
def _is_market_day(now: datetime) -> bool:
|
||||||
|
"""평일 + 휴장일 아닌 날."""
|
||||||
|
if now.weekday() >= 5: # Sat/Sun
|
||||||
|
return False
|
||||||
|
return now.strftime("%Y-%m-%d") not in _HOLIDAYS
|
||||||
|
|
||||||
|
|
||||||
|
def _is_polling_window(now: datetime) -> bool:
|
||||||
|
"""폴링 윈도우: 07:00-23:30 + 04:30-07:00."""
|
||||||
|
t = now.time()
|
||||||
|
return (
|
||||||
|
(_PRE_OPEN <= t < _NXT_PRE_END)
|
||||||
|
or (_NXT_POST_OPEN <= t < _PRE_OPEN)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _next_interval(now: datetime) -> float:
|
||||||
|
"""다음 폴링까지 sleep 초수."""
|
||||||
|
if not _is_market_day(now):
|
||||||
|
return _seconds_until_next_market_open(now)
|
||||||
|
|
||||||
|
t = now.time()
|
||||||
|
if _PRE_OPEN <= t < _OPEN:
|
||||||
|
return 300.0 # 장전 5분
|
||||||
|
elif _OPEN <= t < _CLOSE:
|
||||||
|
return 60.0 # 장중 1분
|
||||||
|
elif _CLOSE <= t < _POST_END:
|
||||||
|
return 300.0 # 장후 5분
|
||||||
|
elif _POST_END <= t < _NXT_PRE_END:
|
||||||
|
return 300.0 # NXT 야간 5분
|
||||||
|
elif _NXT_POST_OPEN <= t < _PRE_OPEN:
|
||||||
|
return 300.0 # NXT 새벽 5분
|
||||||
|
else:
|
||||||
|
# Dead zone (23:30-04:30) — wait until next 04:30
|
||||||
|
return _seconds_until_nxt_or_market_open(now)
|
||||||
|
|
||||||
|
|
||||||
|
def _seconds_until_nxt_or_market_open(now: datetime) -> float:
|
||||||
|
"""다음 04:30 (NXT 새벽 start) 까지 초수. 휴장일은 다음 영업일 07:00."""
|
||||||
|
candidate = now.replace(hour=4, minute=30, second=0, microsecond=0)
|
||||||
|
if candidate <= now:
|
||||||
|
candidate += timedelta(days=1)
|
||||||
|
|
||||||
|
for _ in range(14):
|
||||||
|
if _is_market_day(candidate):
|
||||||
|
return (candidate - now).total_seconds()
|
||||||
|
candidate += timedelta(days=1)
|
||||||
|
|
||||||
|
logger.warning("could not find next market day within 14 days")
|
||||||
|
return 86400.0
|
||||||
|
|
||||||
|
|
||||||
|
def _is_post_close_trigger(now: datetime) -> bool:
|
||||||
|
"""16:00 KST ±1분 (post-close cycle 트리거). 평일/영업일만."""
|
||||||
|
if not _is_market_day(now):
|
||||||
|
return False
|
||||||
|
t = now.time()
|
||||||
|
return time(16, 0) <= t < time(16, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _seconds_until_next_market_open(now: datetime) -> float:
|
||||||
|
"""다음 영업일의 07:00 KST 까지 초수 (휴장일/주말용)."""
|
||||||
|
candidate = now.replace(hour=7, minute=0, second=0, microsecond=0)
|
||||||
|
if candidate <= now:
|
||||||
|
candidate += timedelta(days=1)
|
||||||
|
|
||||||
|
for _ in range(14): # safety bound (max 2 weeks of holidays)
|
||||||
|
if _is_market_day(candidate):
|
||||||
|
return (candidate - now).total_seconds()
|
||||||
|
candidate += timedelta(days=1)
|
||||||
|
|
||||||
|
logger.warning("could not find next market day within 14 days")
|
||||||
|
return 86400.0
|
||||||
228
signal_v2/signal_generator.py
Normal file
228
signal_v2/signal_generator.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""Phase 4 — 매수/매도 신호 생성.
|
||||||
|
|
||||||
|
순수 함수 generate_signals(state, dedup, settings). state 를 mutate.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
KST = ZoneInfo("Asia/Seoul")
|
||||||
|
|
||||||
|
MOMENTUM_SCORES = {
|
||||||
|
"strong_up": 1.0,
|
||||||
|
"weak_up": 0.7,
|
||||||
|
"neutral": 0.5,
|
||||||
|
"weak_down": 0.3,
|
||||||
|
"strong_down": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_signals(state, dedup, settings) -> None:
|
||||||
|
"""Phase 4 entry — state-mutating. Evaluation order: sell first (priority), then buy. A ticker receiving a sell signal in this cycle is excluded from buy evaluation to avoid silent overwrite."""
|
||||||
|
_evaluate_sell_signals(state, dedup, settings)
|
||||||
|
_evaluate_buy_signals(state, dedup, settings)
|
||||||
|
|
||||||
|
|
||||||
|
# ----- 매수 -----
|
||||||
|
|
||||||
|
def _evaluate_buy_signals(state, dedup, settings) -> None:
|
||||||
|
candidates = _buy_candidates(state)
|
||||||
|
for ticker, name, rank in candidates:
|
||||||
|
existing = state.signals.get(ticker)
|
||||||
|
if existing is not None and existing.get("action") == "sell":
|
||||||
|
logger.debug("buy %s skipped: same-cycle sell precedence", ticker)
|
||||||
|
continue
|
||||||
|
if not _check_buy_hard_gate(state, ticker, settings):
|
||||||
|
logger.debug("buy %s skipped: hard gate failed", ticker)
|
||||||
|
continue
|
||||||
|
confidence = _compute_buy_confidence(state, ticker, rank)
|
||||||
|
if confidence <= settings.confidence_threshold:
|
||||||
|
logger.debug("buy %s skipped: confidence %.3f <= %.3f",
|
||||||
|
ticker, confidence, settings.confidence_threshold)
|
||||||
|
continue
|
||||||
|
if dedup.is_recent(ticker, "buy", within_hours=24):
|
||||||
|
logger.debug("buy %s skipped: dedup 24h", ticker)
|
||||||
|
continue
|
||||||
|
state.signals[ticker] = _build_buy_signal(state, ticker, name, rank, confidence)
|
||||||
|
dedup.record(ticker, "buy", confidence=confidence)
|
||||||
|
logger.info("signal emit %s buy conf=%.3f rank=%s", ticker, confidence, rank)
|
||||||
|
|
||||||
|
|
||||||
|
def _buy_candidates(state) -> list[tuple[str, str, int | None]]:
|
||||||
|
"""screener Top-N (rank 1..N) + portfolio (rank=None)."""
|
||||||
|
candidates: list[tuple[str, str, int | None]] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
if state.screener_preview is not None:
|
||||||
|
for i, item in enumerate(state.screener_preview.get("items", [])):
|
||||||
|
ticker = item.get("ticker")
|
||||||
|
if not ticker or ticker in seen:
|
||||||
|
continue
|
||||||
|
seen.add(ticker)
|
||||||
|
name = item.get("name", ticker)
|
||||||
|
candidates.append((ticker, name, i + 1))
|
||||||
|
if state.portfolio is not None:
|
||||||
|
for h in state.portfolio.get("holdings", []):
|
||||||
|
ticker = h.get("ticker")
|
||||||
|
if not ticker or ticker in seen:
|
||||||
|
continue
|
||||||
|
seen.add(ticker)
|
||||||
|
candidates.append((ticker, h.get("name", ticker), None))
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
|
||||||
|
def _check_buy_hard_gate(state, ticker: str, settings) -> bool:
|
||||||
|
pred = state.chronos_predictions.get(ticker)
|
||||||
|
if pred is None or pred.get("median", 0) <= 0:
|
||||||
|
return False
|
||||||
|
spread = pred.get("q90", 0) - pred.get("q10", 0)
|
||||||
|
if spread >= settings.chronos_spread_threshold:
|
||||||
|
return False
|
||||||
|
momentum = state.minute_momentum.get(ticker)
|
||||||
|
if momentum != settings.min_momentum_for_buy:
|
||||||
|
return False
|
||||||
|
ap = state.asking_price.get(ticker)
|
||||||
|
if ap is None or ap.get("bid_ratio", 0) < settings.asking_bid_ratio_threshold:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_buy_confidence(state, ticker: str, rank: int | None) -> float:
|
||||||
|
pred = state.chronos_predictions[ticker]
|
||||||
|
chronos_conf = pred["conf"]
|
||||||
|
minute_score = MOMENTUM_SCORES.get(state.minute_momentum.get(ticker, "neutral"), 0.5)
|
||||||
|
screener_norm = max(0.0, 1 - (rank - 1) / 20) if rank is not None else 0.0
|
||||||
|
return chronos_conf * 0.5 + minute_score * 0.3 + screener_norm * 0.2
|
||||||
|
|
||||||
|
|
||||||
|
def _build_buy_signal(state, ticker: str, name: str, rank: int | None, confidence: float) -> dict:
|
||||||
|
ap = state.asking_price[ticker]
|
||||||
|
return {
|
||||||
|
"ticker": ticker,
|
||||||
|
"name": name,
|
||||||
|
"action": "buy",
|
||||||
|
"confidence_webai": confidence,
|
||||||
|
"current_price": ap["current_price"],
|
||||||
|
"avg_price": None,
|
||||||
|
"pnl_pct": None,
|
||||||
|
"context": _build_context(state, ticker, rank),
|
||||||
|
"as_of": datetime.now(KST).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ----- 매도 -----
|
||||||
|
|
||||||
|
def _evaluate_sell_signals(state, dedup, settings) -> None:
|
||||||
|
if state.portfolio is None:
|
||||||
|
return
|
||||||
|
for holding in state.portfolio.get("holdings", []):
|
||||||
|
ticker = holding.get("ticker")
|
||||||
|
if not ticker:
|
||||||
|
continue
|
||||||
|
sell = _try_stop_loss(state, holding, settings)
|
||||||
|
if sell is None:
|
||||||
|
sell = _try_anomaly(state, holding, settings)
|
||||||
|
if sell is None:
|
||||||
|
sell = _try_take_profit(state, holding, settings)
|
||||||
|
if sell is None:
|
||||||
|
continue
|
||||||
|
if dedup.is_recent(ticker, "sell", within_hours=24):
|
||||||
|
logger.debug("sell %s skipped: dedup 24h", ticker)
|
||||||
|
continue
|
||||||
|
state.signals[ticker] = sell
|
||||||
|
dedup.record(ticker, "sell", confidence=sell["confidence_webai"])
|
||||||
|
logger.info("signal emit %s sell conf=%.3f reason=%s",
|
||||||
|
ticker, sell["confidence_webai"],
|
||||||
|
sell.get("context", {}).get("sell_reason"))
|
||||||
|
|
||||||
|
|
||||||
|
def _try_stop_loss(state, holding: dict, settings) -> dict | None:
|
||||||
|
pnl = holding.get("pnl_pct")
|
||||||
|
if pnl is None or pnl >= settings.stop_loss_pct:
|
||||||
|
return None
|
||||||
|
return _build_sell_signal(state, holding, confidence=1.0, reason="stop_loss")
|
||||||
|
|
||||||
|
|
||||||
|
def _try_take_profit(state, holding: dict, settings) -> dict | None:
|
||||||
|
pnl = holding.get("pnl_pct")
|
||||||
|
if pnl is None or pnl <= settings.take_profit_pct:
|
||||||
|
return None
|
||||||
|
return _build_sell_signal(state, holding, confidence=0.6, reason="take_profit")
|
||||||
|
|
||||||
|
|
||||||
|
def _try_anomaly(state, holding: dict, settings) -> dict | None:
|
||||||
|
ticker = holding["ticker"]
|
||||||
|
pred = state.chronos_predictions.get(ticker)
|
||||||
|
if pred is None or pred["median"] >= -0.01:
|
||||||
|
return None
|
||||||
|
momentum = state.minute_momentum.get(ticker)
|
||||||
|
if momentum != "strong_down":
|
||||||
|
return None
|
||||||
|
ap = state.asking_price.get(ticker)
|
||||||
|
if ap is None:
|
||||||
|
return None
|
||||||
|
if ap["bid_ratio"] > (1 - settings.asking_bid_ratio_threshold):
|
||||||
|
return None
|
||||||
|
minute_score = 1.0 - MOMENTUM_SCORES.get(momentum, 0.5)
|
||||||
|
confidence = pred["conf"] * 0.5 + minute_score * 0.3 + 1.0 * 0.2
|
||||||
|
if confidence <= settings.confidence_threshold:
|
||||||
|
return None
|
||||||
|
return _build_sell_signal(state, holding, confidence=confidence, reason="anomaly")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_sell_signal(state, holding: dict, confidence: float, reason: str) -> dict:
|
||||||
|
ticker = holding["ticker"]
|
||||||
|
return {
|
||||||
|
"ticker": ticker,
|
||||||
|
"name": holding.get("name", ticker),
|
||||||
|
"action": "sell",
|
||||||
|
"confidence_webai": confidence,
|
||||||
|
"current_price": holding.get("current_price"),
|
||||||
|
"avg_price": holding.get("avg_price"),
|
||||||
|
"pnl_pct": holding.get("pnl_pct"),
|
||||||
|
"context": _build_context(state, ticker, rank=None, sell_reason=reason),
|
||||||
|
"as_of": datetime.now(KST).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ----- Context -----
|
||||||
|
|
||||||
|
def _build_context(state, ticker: str, rank: int | None, sell_reason: str | None = None) -> dict:
|
||||||
|
pred = state.chronos_predictions.get(ticker) or {}
|
||||||
|
ap = state.asking_price.get(ticker) or {}
|
||||||
|
news_item = _find_news_sentiment(state, ticker)
|
||||||
|
screener_scores = _find_screener_scores(state, ticker)
|
||||||
|
context: dict = {
|
||||||
|
"chronos_pred_1d": pred.get("median"),
|
||||||
|
"chronos_pred_conf": pred.get("conf"),
|
||||||
|
"chronos_q10": pred.get("q10"),
|
||||||
|
"chronos_q90": pred.get("q90"),
|
||||||
|
"screener_rank": rank,
|
||||||
|
"screener_scores": screener_scores,
|
||||||
|
"minute_momentum": state.minute_momentum.get(ticker),
|
||||||
|
"asking_bid_ratio": ap.get("bid_ratio"),
|
||||||
|
"news_sentiment": news_item.get("score") if news_item else None,
|
||||||
|
"news_reason": news_item.get("reason") if news_item else None,
|
||||||
|
}
|
||||||
|
if sell_reason is not None:
|
||||||
|
context["sell_reason"] = sell_reason
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
def _find_news_sentiment(state, ticker: str) -> dict | None:
|
||||||
|
if state.news_sentiment is None:
|
||||||
|
return None
|
||||||
|
for item in state.news_sentiment.get("items", []):
|
||||||
|
if item.get("ticker") == ticker:
|
||||||
|
return item
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_screener_scores(state, ticker: str) -> dict | None:
|
||||||
|
if state.screener_preview is None:
|
||||||
|
return None
|
||||||
|
for item in state.screener_preview.get("items", []):
|
||||||
|
if item.get("ticker") == ticker:
|
||||||
|
return item.get("scores")
|
||||||
|
return None
|
||||||
3
signal_v2/start.bat
Normal file
3
signal_v2/start.bat
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
@echo off
|
||||||
|
cd /d "%~dp0\.."
|
||||||
|
python -m uvicorn signal_v2.main:app --host 0.0.0.0 --port 8001
|
||||||
22
signal_v2/state.py
Normal file
22
signal_v2/state.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""PollState — process-wide singleton."""
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PollState:
|
||||||
|
portfolio: dict | None = None
|
||||||
|
news_sentiment: dict | None = None
|
||||||
|
screener_preview: dict | None = None
|
||||||
|
minute_bars: dict[str, deque] = field(default_factory=dict)
|
||||||
|
asking_price: dict[str, dict] = field(default_factory=dict)
|
||||||
|
# Phase 3b additions
|
||||||
|
daily_ohlcv: dict[str, list[dict]] = field(default_factory=dict)
|
||||||
|
chronos_predictions: dict[str, dict] = field(default_factory=dict)
|
||||||
|
minute_momentum: dict[str, str] = field(default_factory=dict)
|
||||||
|
signals: dict[str, dict] = field(default_factory=dict)
|
||||||
|
last_updated: dict[str, str] = field(default_factory=dict)
|
||||||
|
fetch_errors: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
state = PollState()
|
||||||
128
signal_v2/stock_client.py
Normal file
128
signal_v2/stock_client.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Stock API HTTP client — async httpx + retry + memory cache."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cache TTL by endpoint (seconds)
|
||||||
|
_TTL = {
|
||||||
|
"portfolio": 60.0,
|
||||||
|
"news-sentiment": 300.0,
|
||||||
|
"screener-preview": 60.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Retry policy
|
||||||
|
_MAX_ATTEMPTS = 3
|
||||||
|
_RETRY_STATUSES = {429, 500, 502, 503, 504}
|
||||||
|
|
||||||
|
|
||||||
|
class StockClient:
|
||||||
|
"""stock API wrapper. Async httpx + self-retry + memory cache."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str, api_key: str, timeout: float = 10.0):
|
||||||
|
self._base_url = base_url.rstrip("/")
|
||||||
|
self._api_key = api_key
|
||||||
|
self._client = httpx.AsyncClient(timeout=timeout)
|
||||||
|
# cache: key → (data, timestamp_monotonic)
|
||||||
|
self._cache: dict[str, tuple[Any, float]] = {}
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await self._client.aclose()
|
||||||
|
|
||||||
|
def cache_size(self) -> int:
|
||||||
|
"""Number of cached endpoint responses (public surface for /health)."""
|
||||||
|
return len(self._cache)
|
||||||
|
|
||||||
|
async def get_portfolio(self) -> dict:
|
||||||
|
return await self._cached_request(
|
||||||
|
"portfolio", "GET", "/api/webai/portfolio"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_news_sentiment(self, date: str | None = None) -> dict:
|
||||||
|
path = "/api/webai/news-sentiment"
|
||||||
|
if date is not None:
|
||||||
|
path += f"?date={date}"
|
||||||
|
cache_key = f"news-sentiment:{date or 'latest'}"
|
||||||
|
return await self._cached_request(
|
||||||
|
cache_key, "GET", path, _ttl_key="news-sentiment"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_screener_preview(
|
||||||
|
self, weights: dict | None = None, top_n: int = 20
|
||||||
|
) -> dict:
|
||||||
|
body = {"mode": "preview", "top_n": top_n}
|
||||||
|
if weights is not None:
|
||||||
|
body["weights"] = weights
|
||||||
|
return await self._cached_request(
|
||||||
|
"screener-preview",
|
||||||
|
"POST",
|
||||||
|
"/api/stock/screener/run",
|
||||||
|
json=body,
|
||||||
|
_ttl_key="screener-preview",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _cached_request(
|
||||||
|
self,
|
||||||
|
cache_key: str,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
*,
|
||||||
|
_ttl_key: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict:
|
||||||
|
ttl_key = _ttl_key or cache_key
|
||||||
|
ttl = _TTL.get(ttl_key, 60.0)
|
||||||
|
# Fresh cache hit?
|
||||||
|
if cache_key in self._cache:
|
||||||
|
data, ts = self._cache[cache_key]
|
||||||
|
if time.monotonic() - ts < ttl:
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Fetch (with retry)
|
||||||
|
try:
|
||||||
|
data = await self._request_with_retry(method, path, **kwargs)
|
||||||
|
self._cache[cache_key] = (data, time.monotonic())
|
||||||
|
return data
|
||||||
|
except httpx.HTTPError:
|
||||||
|
# Stale fallback: serve old cached value if exists
|
||||||
|
if cache_key in self._cache:
|
||||||
|
stale_data, stale_ts = self._cache[cache_key]
|
||||||
|
age = time.monotonic() - stale_ts
|
||||||
|
logger.warning(
|
||||||
|
"serving stale cache for %s (age=%.1fs)", cache_key, age
|
||||||
|
)
|
||||||
|
return stale_data
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _request_with_retry(self, method: str, path: str, **kwargs) -> dict:
|
||||||
|
url = f"{self._base_url}{path}"
|
||||||
|
headers = self._auth_headers()
|
||||||
|
for attempt in range(_MAX_ATTEMPTS):
|
||||||
|
try:
|
||||||
|
response = await self._client.request(
|
||||||
|
method, url, headers=headers, **kwargs
|
||||||
|
)
|
||||||
|
if response.status_code in _RETRY_STATUSES:
|
||||||
|
if attempt < _MAX_ATTEMPTS - 1:
|
||||||
|
await asyncio.sleep(2**attempt)
|
||||||
|
continue
|
||||||
|
response.raise_for_status()
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
if attempt < _MAX_ATTEMPTS - 1:
|
||||||
|
await asyncio.sleep(2**attempt)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
except httpx.HTTPStatusError:
|
||||||
|
raise
|
||||||
|
# Unreachable: every iteration either returns or raises
|
||||||
|
raise RuntimeError("_request_with_retry exhausted loop without raising")
|
||||||
|
|
||||||
|
def _auth_headers(self) -> dict[str, str]:
|
||||||
|
return {"X-WebAI-Key": self._api_key}
|
||||||
0
signal_v2/tests/__init__.py
Normal file
0
signal_v2/tests/__init__.py
Normal file
18
signal_v2/tests/conftest.py
Normal file
18
signal_v2/tests/conftest.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""Pytest fixtures for signal_v2 tests."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_dedup_db(tmp_path) -> Path:
|
||||||
|
"""SQLite 단위 테스트용 임시 DB path."""
|
||||||
|
return tmp_path / "test_signal_v2.db"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_stock_api():
|
||||||
|
"""respx 로 stock API mock. base_url 은 테스트마다 임의."""
|
||||||
|
with respx.mock(base_url="https://test.stock.local", assert_all_called=False) as mock:
|
||||||
|
yield mock
|
||||||
92
signal_v2/tests/test_chronos_predictor.py
Normal file
92
signal_v2/tests/test_chronos_predictor.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Tests for ChronosPredictor (model mock)."""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pipeline():
|
||||||
|
"""Mock BaseChronosPipeline.from_pretrained returning a mock pipeline object."""
|
||||||
|
with patch("chronos.BaseChronosPipeline") as cls:
|
||||||
|
cls.__name__ = "BaseChronosPipeline"
|
||||||
|
instance = MagicMock()
|
||||||
|
# ChronosBolt API: predict_quantiles returns (quantiles_tensor, mean_tensor)
|
||||||
|
# Modern (predict_quantiles) branch will be used since hasattr(MagicMock, "predict_quantiles") is True.
|
||||||
|
cls.from_pretrained.return_value = instance
|
||||||
|
yield instance
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_torch_cpu():
|
||||||
|
with patch("torch.cuda.is_available", return_value=False):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def _daily_ohlcv(close_seq):
|
||||||
|
return [{"datetime": f"2026-05-{i+1:02d}", "open": c, "high": c, "low": c,
|
||||||
|
"close": c, "volume": 1000} for i, c in enumerate(close_seq)]
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_quantiles_tensor(q10_price: float, q50_price: float, q90_price: float):
|
||||||
|
"""Helper: build predict_quantiles return tensor shape [1, 1, 3]."""
|
||||||
|
import torch
|
||||||
|
return torch.tensor([[[q10_price, q50_price, q90_price]]], dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
|
||||||
|
"""mock predict_quantiles → dict[ticker, ChronosPrediction]. last_close=100, q50=102 → median≈+2%."""
|
||||||
|
quantiles = _mk_quantiles_tensor(101.5, 102.0, 102.5) # narrow around 102
|
||||||
|
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||||
|
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor, ChronosPrediction
|
||||||
|
predictor = ChronosPredictor(model_name="mock-model")
|
||||||
|
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||||
|
result = predictor.predict_batch(daily)
|
||||||
|
assert "005930" in result
|
||||||
|
pred = result["005930"]
|
||||||
|
assert isinstance(pred, ChronosPrediction)
|
||||||
|
assert abs(pred.median - 0.02) < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
|
||||||
|
"""좁은 distribution (q90-q10 작음, median 0 아님) → conf ≈ 1."""
|
||||||
|
# last_close=100, q10=101.99, q50=102.00, q90=102.01
|
||||||
|
# returns: q10=0.0199, q50=0.02, q90=0.0201
|
||||||
|
# spread = (0.0201 - 0.0199) / max(0.02, 0.001) = 0.0002/0.02 = 0.01 → conf = 1 - 0.005 = 0.995
|
||||||
|
quantiles = _mk_quantiles_tensor(101.99, 102.0, 102.01)
|
||||||
|
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||||
|
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor
|
||||||
|
predictor = ChronosPredictor(model_name="mock-model")
|
||||||
|
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||||
|
result = predictor.predict_batch(daily)
|
||||||
|
assert result["005930"].conf > 0.8
|
||||||
|
|
||||||
|
|
||||||
|
def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
|
||||||
|
"""넓은 distribution → conf ≈ 0."""
|
||||||
|
# last_close=100, q10=70, q50=100, q90=130
|
||||||
|
# returns: q10=-0.3, q50=0.0, q90=0.3
|
||||||
|
# spread = (0.3 - (-0.3)) / max(0.0, 0.001) = 0.6 / 0.001 = 600 → conf = max(0, 1 - 300) = 0
|
||||||
|
quantiles = _mk_quantiles_tensor(70.0, 100.0, 130.0)
|
||||||
|
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||||
|
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor
|
||||||
|
predictor = ChronosPredictor(model_name="mock-model")
|
||||||
|
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||||
|
result = predictor.predict_batch(daily)
|
||||||
|
assert result["005930"].conf < 0.3
|
||||||
|
|
||||||
|
|
||||||
|
def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu):
|
||||||
|
"""price 예측 → last_close 대비 return 변환. last_close=100, q50=110 → return ≈ +10%."""
|
||||||
|
quantiles = _mk_quantiles_tensor(109.0, 110.0, 111.0)
|
||||||
|
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||||
|
|
||||||
|
from signal_v2.chronos_predictor import ChronosPredictor
|
||||||
|
predictor = ChronosPredictor(model_name="mock-model")
|
||||||
|
# last close = 100
|
||||||
|
daily = {"005930": _daily_ohlcv(list(range(41, 101)))} # last = 100
|
||||||
|
result = predictor.predict_batch(daily)
|
||||||
|
assert abs(result["005930"].median - 0.10) < 0.001
|
||||||
161
signal_v2/tests/test_kis_client.py
Normal file
161
signal_v2/tests/test_kis_client.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""Tests for KISClient (REST)."""
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from signal_v2.kis_client import KISClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_v1_token(tmp_path):
|
||||||
|
"""V1 토큰 파일 fixture."""
|
||||||
|
token_file = tmp_path / "kis_token.json"
|
||||||
|
token_file.write_text(json.dumps({
|
||||||
|
"access_token": "test-kis-token-abc123",
|
||||||
|
"token_expired": "2099-12-31 23:59:59",
|
||||||
|
}))
|
||||||
|
return token_file
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def kis_client_factory(fake_v1_token):
|
||||||
|
def _make():
|
||||||
|
return KISClient(
|
||||||
|
app_key="test-app-key",
|
||||||
|
app_secret="test-app-secret",
|
||||||
|
account="50000000-01",
|
||||||
|
is_virtual=True,
|
||||||
|
v1_token_path=fake_v1_token,
|
||||||
|
)
|
||||||
|
return _make
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_get_minute_ohlcv_normal_returns_30_bars(kis_client_factory):
|
||||||
|
"""정상 200 → 30개 분봉 list 반환."""
|
||||||
|
sample_output2 = [
|
||||||
|
{
|
||||||
|
"stck_bsop_date": "20260518",
|
||||||
|
"stck_cntg_hour": f"09{m:02d}00",
|
||||||
|
"stck_oprc": "78000", "stck_hgpr": "78500",
|
||||||
|
"stck_lwpr": "77800", "stck_prpr": "78300",
|
||||||
|
"cntg_vol": "12345",
|
||||||
|
}
|
||||||
|
for m in range(30) # 9:00-9:29 = 30 bars
|
||||||
|
]
|
||||||
|
respx.get(
|
||||||
|
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||||
|
).mock(
|
||||||
|
return_value=httpx.Response(200, json={"output2": sample_output2})
|
||||||
|
)
|
||||||
|
|
||||||
|
client = kis_client_factory()
|
||||||
|
try:
|
||||||
|
bars = await client.get_minute_ohlcv("005930")
|
||||||
|
assert len(bars) == 30
|
||||||
|
assert bars[0]["close"] == 78300
|
||||||
|
assert "datetime" in bars[0]
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_get_minute_ohlcv_429_retry_then_success(kis_client_factory, monkeypatch):
|
||||||
|
"""429 → exponential backoff → 200."""
|
||||||
|
sleep_calls = []
|
||||||
|
async def fake_sleep(s): sleep_calls.append(s)
|
||||||
|
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||||
|
|
||||||
|
respx.get(
|
||||||
|
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||||
|
).mock(side_effect=[
|
||||||
|
httpx.Response(429, text="rate limit"),
|
||||||
|
httpx.Response(200, json={"output2": []}),
|
||||||
|
])
|
||||||
|
client = kis_client_factory()
|
||||||
|
try:
|
||||||
|
result = await client.get_minute_ohlcv("005930")
|
||||||
|
assert result == []
|
||||||
|
assert 1 in sleep_calls
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_get_minute_ohlcv_uses_v1_token(kis_client_factory, fake_v1_token):
|
||||||
|
"""KIS 호출 헤더에 V1 토큰 파일의 access_token 사용."""
|
||||||
|
route = respx.get(
|
||||||
|
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||||
|
).mock(return_value=httpx.Response(200, json={"output2": []}))
|
||||||
|
|
||||||
|
client = kis_client_factory()
|
||||||
|
try:
|
||||||
|
await client.get_minute_ohlcv("005930")
|
||||||
|
assert route.called
|
||||||
|
req = route.calls.last.request
|
||||||
|
# check authorization header contains the V1 token
|
||||||
|
auth = req.headers.get("authorization", "")
|
||||||
|
assert "test-kis-token-abc123" in auth
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_get_asking_price_computes_bid_ratio(kis_client_factory):
|
||||||
|
"""호가 응답 → bid_total/(bid+ask) bid_ratio 계산."""
|
||||||
|
respx.get(
|
||||||
|
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
|
||||||
|
).mock(return_value=httpx.Response(200, json={
|
||||||
|
"output1": {
|
||||||
|
"total_bidp_rsqn": "600",
|
||||||
|
"total_askp_rsqn": "400",
|
||||||
|
"stck_prpr": "78500",
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
client = kis_client_factory()
|
||||||
|
try:
|
||||||
|
data = await client.get_asking_price("005930")
|
||||||
|
assert data["bid_total"] == 600
|
||||||
|
assert data["ask_total"] == 400
|
||||||
|
assert abs(data["bid_ratio"] - 0.6) < 1e-9
|
||||||
|
assert data["current_price"] == 78500
|
||||||
|
assert "as_of" in data
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_get_daily_ohlcv_returns_60_bars(kis_client_factory):
|
||||||
|
"""KIS daily endpoint returns 60 ascending bars after parsing."""
|
||||||
|
# Build 60 KIS-format daily bars (descending dates as KIS does)
|
||||||
|
sample_output2 = []
|
||||||
|
for i in range(60):
|
||||||
|
# Generate a fake date 60 days ago, descending
|
||||||
|
day = 60 - i
|
||||||
|
sample_output2.append({
|
||||||
|
"stck_bsop_date": f"2026{(((day-1)//30)+1):02d}{(((day-1)%30)+1):02d}",
|
||||||
|
"stck_oprc": "78000", "stck_hgpr": "78500",
|
||||||
|
"stck_lwpr": "77800", "stck_clpr": str(78000 + i),
|
||||||
|
"acml_vol": "12345",
|
||||||
|
})
|
||||||
|
|
||||||
|
respx.get(
|
||||||
|
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
|
||||||
|
).mock(return_value=httpx.Response(200, json={"output2": sample_output2}))
|
||||||
|
|
||||||
|
client = kis_client_factory()
|
||||||
|
try:
|
||||||
|
bars = await client.get_daily_ohlcv("005930", days=60)
|
||||||
|
# KIS returns descending; client reverses to ascending
|
||||||
|
assert len(bars) == 60
|
||||||
|
# Ascending order: first item has smaller datetime than last
|
||||||
|
assert bars[0]["datetime"] < bars[-1]["datetime"]
|
||||||
|
assert isinstance(bars[0]["open"], int)
|
||||||
|
assert isinstance(bars[0]["close"], int)
|
||||||
|
assert "datetime" in bars[0]
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
94
signal_v2/tests/test_kis_websocket.py
Normal file
94
signal_v2/tests/test_kis_websocket.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""Tests for KISWebSocket."""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from signal_v2.kis_websocket import KISWebSocket
|
||||||
|
|
||||||
|
|
||||||
|
BASE_REST = "https://openapivts.koreainvestment.com:29443"
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_fetch_approval_key_via_oauth_endpoint():
|
||||||
|
"""POST /oauth2/Approval → approval_key 추출."""
|
||||||
|
respx.post(f"{BASE_REST}/oauth2/Approval").mock(
|
||||||
|
return_value=httpx.Response(200, json={"approval_key": "test-approval-key-xyz"})
|
||||||
|
)
|
||||||
|
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||||
|
key = await ws._fetch_approval_key()
|
||||||
|
assert key == "test-approval-key-xyz"
|
||||||
|
assert ws._approval_key == "test-approval-key-xyz"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subscribe_sends_h0stasp0_message():
|
||||||
|
"""subscribe() → WebSocket 으로 H0STASP0 구독 메시지 전송."""
|
||||||
|
sent_messages = []
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
mock_ws.send = AsyncMock(side_effect=lambda m: sent_messages.append(m))
|
||||||
|
|
||||||
|
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||||
|
ws._approval_key = "test-key"
|
||||||
|
ws._ws = mock_ws
|
||||||
|
await ws.subscribe("005930")
|
||||||
|
assert ws._subscriptions == {"005930"}
|
||||||
|
assert len(sent_messages) == 1
|
||||||
|
msg = json.loads(sent_messages[0])
|
||||||
|
assert msg["header"]["tr_type"] == "1" # subscribe
|
||||||
|
assert msg["body"]["input"]["tr_id"] == "H0STASP0"
|
||||||
|
assert msg["body"]["input"]["tr_key"] == "005930"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_asking_price_extracts_bid_ask_totals():
|
||||||
|
"""KIS raw '0|H0STASP0|001|...' → (ticker, dict).
|
||||||
|
|
||||||
|
KIS 호가 메시지 형식 — KIS 공식 spec 의 정확한 필드 인덱스 운영 검증 필요.
|
||||||
|
본 테스트는 implementer 의 _parse_asking_price 구현 인덱스에 맞춰서 sample 작성.
|
||||||
|
"""
|
||||||
|
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||||
|
# Build a sample raw message — implementer 가 _ASKING_TOTAL_BID/ASK 인덱스에
|
||||||
|
# 맞춰서 필드 배치하면 됨. 예: 마지막 2개 필드를 bid_total / ask_total 로.
|
||||||
|
fields = ["005930", "091500", "78500"] # ticker, time, current_price
|
||||||
|
fields.extend(["0"] * 40) # padding (KIS 의 실 필드 수 ~50개)
|
||||||
|
fields.append("400") # ask_total
|
||||||
|
fields.append("600") # bid_total
|
||||||
|
raw = f"0|H0STASP0|001|{'^'.join(fields)}"
|
||||||
|
|
||||||
|
result = ws._parse_asking_price(raw)
|
||||||
|
assert result is not None, "parse_asking_price returned None"
|
||||||
|
ticker, data = result
|
||||||
|
assert ticker == "005930"
|
||||||
|
assert "bid_total" in data
|
||||||
|
assert "ask_total" in data
|
||||||
|
assert "bid_ratio" in data
|
||||||
|
assert "current_price" in data
|
||||||
|
# bid_total=600, ask_total=400, bid_ratio=0.6
|
||||||
|
assert data["bid_total"] == 600
|
||||||
|
assert data["ask_total"] == 400
|
||||||
|
assert abs(data["bid_ratio"] - 0.6) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
async def test_reconnect_on_disconnect_with_backoff(monkeypatch):
|
||||||
|
"""연결 끊김 → exponential backoff retry. _connect_with_backoff() 검증."""
|
||||||
|
sleep_calls = []
|
||||||
|
async def fake_sleep(s): sleep_calls.append(s)
|
||||||
|
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||||
|
|
||||||
|
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||||
|
# Mock _connect to fail twice then succeed
|
||||||
|
call_count = [0]
|
||||||
|
async def fake_connect():
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] < 3:
|
||||||
|
raise ConnectionError("fake disconnect")
|
||||||
|
return AsyncMock()
|
||||||
|
monkeypatch.setattr(ws, "_connect", fake_connect)
|
||||||
|
|
||||||
|
result = await ws._connect_with_backoff()
|
||||||
|
assert call_count[0] == 3 # 2 fails + 1 success
|
||||||
|
# exponential 1s, 2s
|
||||||
|
assert sleep_calls[:2] == [1, 2]
|
||||||
62
signal_v2/tests/test_main.py
Normal file
62
signal_v2/tests/test_main.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Tests for FastAPI main app."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_endpoint_returns_status_online(monkeypatch):
|
||||||
|
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||||
|
monkeypatch.setenv("WEBAI_API_KEY", "test-secret")
|
||||||
|
# Reload modules so they pick up the new env
|
||||||
|
import importlib
|
||||||
|
from signal_v2 import config as cfg
|
||||||
|
importlib.reload(cfg)
|
||||||
|
from signal_v2 import main as main_mod
|
||||||
|
importlib.reload(main_mod)
|
||||||
|
with TestClient(main_mod.app) as client:
|
||||||
|
r = client.get("/health")
|
||||||
|
assert r.status_code == 200
|
||||||
|
body = r.json()
|
||||||
|
assert body["status"] == "online"
|
||||||
|
assert body["stock_api_url"] == "https://test.stock.local"
|
||||||
|
|
||||||
|
|
||||||
|
def test_startup_warns_if_webai_api_key_missing(monkeypatch, caplog):
|
||||||
|
# Use setenv with empty string + no-op load_dotenv to defeat .env re-read on reload
|
||||||
|
monkeypatch.setattr("signal_v2.config.load_dotenv", lambda *a, **k: None)
|
||||||
|
monkeypatch.setenv("WEBAI_API_KEY", "")
|
||||||
|
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||||
|
import importlib
|
||||||
|
from signal_v2 import config as cfg
|
||||||
|
importlib.reload(cfg)
|
||||||
|
# After reload, load_dotenv reference is fresh — re-patch
|
||||||
|
monkeypatch.setattr("signal_v2.config.load_dotenv", lambda *a, **k: None)
|
||||||
|
from signal_v2 import main as main_mod
|
||||||
|
importlib.reload(main_mod)
|
||||||
|
with caplog.at_level(logging.WARNING, logger="signal_v2.main"):
|
||||||
|
with TestClient(main_mod.app) as client:
|
||||||
|
client.get("/health")
|
||||||
|
assert any("WEBAI_API_KEY" in rec.message for rec in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
def test_startup_warns_if_kis_app_key_missing(monkeypatch, caplog):
|
||||||
|
"""KIS app_key 미설정 시 startup WARNING (KIS 호출 disabled) — V1 패턴."""
|
||||||
|
monkeypatch.setattr("signal_v2.config.load_dotenv", lambda *a, **k: None)
|
||||||
|
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||||
|
monkeypatch.setenv("WEBAI_API_KEY", "test-secret")
|
||||||
|
# V1 pattern: kis_env_type=virtual, both virtual keys empty
|
||||||
|
monkeypatch.setenv("KIS_ENV_TYPE", "virtual")
|
||||||
|
monkeypatch.setenv("KIS_VIRTUAL_APP_KEY", "")
|
||||||
|
monkeypatch.setenv("KIS_REAL_APP_KEY", "")
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from signal_v2 import config as cfg
|
||||||
|
importlib.reload(cfg)
|
||||||
|
monkeypatch.setattr("signal_v2.config.load_dotenv", lambda *a, **k: None)
|
||||||
|
from signal_v2 import main as main_mod
|
||||||
|
importlib.reload(main_mod)
|
||||||
|
with caplog.at_level(logging.WARNING, logger="signal_v2.main"):
|
||||||
|
with TestClient(main_mod.app) as client:
|
||||||
|
client.get("/health")
|
||||||
|
assert any("KIS" in rec.message and "app_key" in rec.message.lower() for rec in caplog.records)
|
||||||
92
signal_v2/tests/test_momentum_classifier.py
Normal file
92
signal_v2/tests/test_momentum_classifier.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Tests for minute momentum classifier."""
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from signal_v2.momentum_classifier import (
|
||||||
|
aggregate_1min_to_5min, classify_minute_momentum,
|
||||||
|
STRONG_UP, WEAK_UP, NEUTRAL, WEAK_DOWN, STRONG_DOWN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _bar(open_, high, low, close, volume):
|
||||||
|
return {
|
||||||
|
"datetime": "2026-05-18T09:00:00+09:00",
|
||||||
|
"open": open_, "high": high, "low": low, "close": close, "volume": volume,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chunks(num_chunks_up: int, num_chunks_total: int, base_vol: int = 1000):
|
||||||
|
"""num_chunks_total 개의 5-bar 청크. num_chunks_up 청크는 양봉, 나머지는 음봉.
|
||||||
|
각 청크는 5개 1분봉. 거래량 = base_vol per bar.
|
||||||
|
"""
|
||||||
|
bars = []
|
||||||
|
for i in range(num_chunks_total):
|
||||||
|
is_up = i < num_chunks_up
|
||||||
|
o, c = (100, 110) if is_up else (110, 100)
|
||||||
|
for j in range(5):
|
||||||
|
bars.append(_bar(o, max(o, c) + 5, min(o, c) - 5, c, base_vol))
|
||||||
|
return bars
|
||||||
|
|
||||||
|
|
||||||
|
def test_strong_up_5_consecutive_green_with_high_volume():
|
||||||
|
"""직전 5개 5분봉 모두 양봉 + 거래량 1.5x → STRONG_UP."""
|
||||||
|
# 60분 (12 5분봉) 데이터: 7 normal + 5 high-vol up
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
recent = _make_chunks(num_chunks_up=5, num_chunks_total=5, base_vol=2500)
|
||||||
|
minute_bars = deque(older + recent, maxlen=60)
|
||||||
|
assert classify_minute_momentum(minute_bars) == STRONG_UP
|
||||||
|
|
||||||
|
|
||||||
|
def test_weak_up_3of5_green_normal_volume():
|
||||||
|
"""직전 5개 5분봉 중 3-4개 양봉 + 거래량 ≥ 1.0x → WEAK_UP."""
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
# 5 chunks: 3 up + 2 down, normal vol
|
||||||
|
recent_up = _make_chunks(num_chunks_up=3, num_chunks_total=3, base_vol=1000)
|
||||||
|
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=2, base_vol=1000)
|
||||||
|
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||||
|
assert classify_minute_momentum(minute_bars) == WEAK_UP
|
||||||
|
|
||||||
|
|
||||||
|
def test_neutral_mixed():
|
||||||
|
"""up_count=2, vol normal → NEUTRAL (rule 미해당)."""
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
recent_up = _make_chunks(num_chunks_up=2, num_chunks_total=2, base_vol=1000)
|
||||||
|
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=3, base_vol=1000)
|
||||||
|
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||||
|
# up_count=2, vol_mult=1.0 → 어느 분기 조건도 만족 안 함 → NEUTRAL
|
||||||
|
assert classify_minute_momentum(minute_bars) == NEUTRAL
|
||||||
|
|
||||||
|
|
||||||
|
def test_weak_down_low_green_low_volume():
|
||||||
|
"""up_count <= 2 + vol < 1.0 → WEAK_DOWN."""
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
recent_up = _make_chunks(num_chunks_up=1, num_chunks_total=1, base_vol=500)
|
||||||
|
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=4, base_vol=500)
|
||||||
|
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||||
|
# recent 5 chunks avg vol = 500, long 12 avg ≈ (7*1000 + 5*500) / 12 ≈ 791 → vol_mult ≈ 0.63
|
||||||
|
assert classify_minute_momentum(minute_bars) == WEAK_DOWN
|
||||||
|
|
||||||
|
|
||||||
|
def test_strong_down_5_consecutive_red_high_volume():
|
||||||
|
"""직전 5개 5분봉 모두 음봉 + 거래량 1.5x → STRONG_DOWN."""
|
||||||
|
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||||
|
recent = _make_chunks(num_chunks_up=0, num_chunks_total=5, base_vol=2500)
|
||||||
|
minute_bars = deque(older + recent, maxlen=60)
|
||||||
|
assert classify_minute_momentum(minute_bars) == STRONG_DOWN
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_1min_to_5min_correctness():
|
||||||
|
"""5 1분봉 → 1개 5분봉 — open/close/high/low/volume 정확."""
|
||||||
|
bars = [
|
||||||
|
_bar(100, 105, 99, 102, 1000),
|
||||||
|
_bar(102, 108, 101, 107, 1500),
|
||||||
|
_bar(107, 110, 105, 106, 800),
|
||||||
|
_bar(106, 109, 104, 108, 1200),
|
||||||
|
_bar(108, 112, 107, 111, 900),
|
||||||
|
]
|
||||||
|
result = aggregate_1min_to_5min(bars)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["open"] == 100 # 첫 bar
|
||||||
|
assert result[0]["close"] == 111 # 마지막 bar
|
||||||
|
assert result[0]["high"] == 112 # max
|
||||||
|
assert result[0]["low"] == 99 # min
|
||||||
|
assert result[0]["volume"] == 5400 # sum
|
||||||
131
signal_v2/tests/test_pull_worker.py
Normal file
131
signal_v2/tests/test_pull_worker.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Tests for pull_worker (Phase 3a additions)."""
|
||||||
|
from collections import deque
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from signal_v2.state import PollState
|
||||||
|
|
||||||
|
|
||||||
|
async def test_minute_polling_cycle_updates_state_minute_bars():
|
||||||
|
"""KIS REST mock 의 분봉 데이터가 state.minute_bars[ticker] deque 에 들어간다."""
|
||||||
|
from signal_v2.pull_worker import _run_kis_minute_cycle
|
||||||
|
|
||||||
|
state = PollState()
|
||||||
|
state.portfolio = {"holdings": [{"ticker": "005930"}, {"ticker": "000660"}]}
|
||||||
|
state.screener_preview = {
|
||||||
|
"items": [{"ticker": "005930"}, {"ticker": "035720"}]
|
||||||
|
}
|
||||||
|
|
||||||
|
kis_client_mock = MagicMock()
|
||||||
|
kis_client_mock.get_minute_ohlcv = AsyncMock(side_effect=[
|
||||||
|
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 78000,
|
||||||
|
"high": 78500, "low": 77900, "close": 78300, "volume": 12345}],
|
||||||
|
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 180000,
|
||||||
|
"high": 181000, "low": 179800, "close": 180500, "volume": 5000}],
|
||||||
|
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 51000,
|
||||||
|
"high": 51200, "low": 50800, "close": 51100, "volume": 8000}],
|
||||||
|
])
|
||||||
|
kis_client_mock.get_asking_price = AsyncMock(return_value={
|
||||||
|
"bid_total": 600, "ask_total": 400, "bid_ratio": 0.6,
|
||||||
|
"current_price": 51100, "as_of": "2026-05-18T09:00:30+09:00",
|
||||||
|
})
|
||||||
|
|
||||||
|
await _run_kis_minute_cycle(kis_client_mock, state)
|
||||||
|
|
||||||
|
# 3 unique tickers (005930, 000660, 035720)
|
||||||
|
assert "005930" in state.minute_bars
|
||||||
|
assert "000660" in state.minute_bars
|
||||||
|
assert "035720" in state.minute_bars
|
||||||
|
assert len(state.minute_bars["005930"]) >= 1
|
||||||
|
# asking_price 만 screener-only ticker (035720) 에 들어가야 함
|
||||||
|
# (portfolio = 005930, 000660 는 WebSocket 으로 들어옴)
|
||||||
|
assert "035720" in state.asking_price
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_message_updates_state_asking_price():
|
||||||
|
"""WebSocket callback factory → state.asking_price 갱신."""
|
||||||
|
from signal_v2.pull_worker import make_asking_price_callback
|
||||||
|
|
||||||
|
state = PollState()
|
||||||
|
cb = make_asking_price_callback(state)
|
||||||
|
cb("005930", {"bid_total": 1000, "ask_total": 800, "bid_ratio": 0.555,
|
||||||
|
"current_price": 78500, "as_of": "2026-05-18T10:00:00+09:00"})
|
||||||
|
assert state.asking_price["005930"]["bid_total"] == 1000
|
||||||
|
assert "asking_price/005930" in state.last_updated
|
||||||
|
|
||||||
|
|
||||||
|
async def test_post_close_cycle_updates_chronos_predictions():
|
||||||
|
"""mock kis + mock chronos → state.chronos_predictions + state.daily_ohlcv 갱신."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from signal_v2.pull_worker import _run_post_close_cycle
|
||||||
|
from signal_v2.chronos_predictor import ChronosPrediction
|
||||||
|
from signal_v2.state import PollState
|
||||||
|
|
||||||
|
state = PollState()
|
||||||
|
state.portfolio = {"holdings": [{"ticker": "005930"}]}
|
||||||
|
state.screener_preview = {"items": [{"ticker": "000660"}]}
|
||||||
|
|
||||||
|
kis_mock = MagicMock()
|
||||||
|
daily_005930 = [{"datetime": f"2026-05-{i+1:02d}", "open": 100, "high": 105,
|
||||||
|
"low": 95, "close": 100 + i, "volume": 1000} for i in range(60)]
|
||||||
|
daily_000660 = [{"datetime": f"2026-05-{i+1:02d}", "open": 200, "high": 210,
|
||||||
|
"low": 190, "close": 200 + i, "volume": 2000} for i in range(60)]
|
||||||
|
# _run_post_close_cycle iterates tickers and calls get_daily_ohlcv per ticker.
|
||||||
|
# Order depends on set() so use side_effect mapping if possible, otherwise list.
|
||||||
|
async def fake_daily(ticker, days=60):
|
||||||
|
if ticker == "005930":
|
||||||
|
return daily_005930
|
||||||
|
if ticker == "000660":
|
||||||
|
return daily_000660
|
||||||
|
return []
|
||||||
|
kis_mock.get_daily_ohlcv = AsyncMock(side_effect=fake_daily)
|
||||||
|
|
||||||
|
chronos_mock = MagicMock()
|
||||||
|
chronos_mock.predict_batch = MagicMock(return_value={
|
||||||
|
"005930": ChronosPrediction(0.02, -0.01, 0.04, 0.85, "2026-05-18T16:00:00+09:00"),
|
||||||
|
"000660": ChronosPrediction(0.03, -0.02, 0.06, 0.75, "2026-05-18T16:00:00+09:00"),
|
||||||
|
})
|
||||||
|
|
||||||
|
await _run_post_close_cycle(kis_mock, chronos_mock, state)
|
||||||
|
|
||||||
|
assert "005930" in state.chronos_predictions
|
||||||
|
assert "000660" in state.chronos_predictions
|
||||||
|
assert state.chronos_predictions["005930"]["median"] == 0.02
|
||||||
|
assert state.chronos_predictions["005930"]["conf"] == 0.85
|
||||||
|
assert "005930" in state.daily_ohlcv
|
||||||
|
assert "chronos/005930" in state.last_updated
|
||||||
|
|
||||||
|
|
||||||
|
def test_poll_loop_calls_generate_signals_after_cycle(monkeypatch):
|
||||||
|
"""Phase 4: generate_signals 가 cycle 후 state.signals 를 갱신한다."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from signal_v2.state import PollState
|
||||||
|
from signal_v2.signal_generator import generate_signals
|
||||||
|
|
||||||
|
state = PollState()
|
||||||
|
state.portfolio = {"holdings": [{
|
||||||
|
"ticker": "005930", "name": "삼성전자",
|
||||||
|
"avg_price": 75000, "current_price": 69000,
|
||||||
|
"pnl_pct": -0.08, "profit_rate": -8.0,
|
||||||
|
"quantity": 100, "broker": "키움",
|
||||||
|
}]}
|
||||||
|
state.screener_preview = {"items": []}
|
||||||
|
|
||||||
|
dedup = MagicMock()
|
||||||
|
dedup.is_recent.return_value = False
|
||||||
|
|
||||||
|
settings = MagicMock()
|
||||||
|
settings.stop_loss_pct = -0.07
|
||||||
|
settings.take_profit_pct = 0.15
|
||||||
|
settings.chronos_spread_threshold = 0.6
|
||||||
|
settings.asking_bid_ratio_threshold = 0.6
|
||||||
|
settings.confidence_threshold = 0.7
|
||||||
|
settings.min_momentum_for_buy = "strong_up"
|
||||||
|
|
||||||
|
generate_signals(state, dedup, settings)
|
||||||
|
|
||||||
|
assert "005930" in state.signals
|
||||||
|
assert state.signals["005930"]["action"] == "sell"
|
||||||
|
assert state.signals["005930"]["confidence_webai"] == 1.0
|
||||||
|
dedup.record.assert_called_with("005930", "sell", confidence=1.0)
|
||||||
34
signal_v2/tests/test_rate_limit.py
Normal file
34
signal_v2/tests/test_rate_limit.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Tests for SignalDedup."""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
from signal_v2.rate_limit import SignalDedup
|
||||||
|
|
||||||
|
KST = ZoneInfo("Asia/Seoul")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_recent_returns_false_for_new_ticker_action(tmp_dedup_db):
|
||||||
|
dedup = SignalDedup(tmp_dedup_db)
|
||||||
|
assert dedup.is_recent("005930", "buy") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_recent_returns_true_within_24h(tmp_dedup_db):
|
||||||
|
dedup = SignalDedup(tmp_dedup_db)
|
||||||
|
dedup.record("005930", "buy", confidence=0.82)
|
||||||
|
assert dedup.is_recent("005930", "buy") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_recent_returns_false_after_24h(tmp_dedup_db, monkeypatch):
|
||||||
|
dedup = SignalDedup(tmp_dedup_db)
|
||||||
|
# Record with a timestamp 25 hours ago
|
||||||
|
now = datetime.now(KST)
|
||||||
|
fake_now = now - timedelta(hours=25)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"signal_v2.rate_limit._now_iso", lambda: fake_now.isoformat()
|
||||||
|
)
|
||||||
|
dedup.record("005930", "buy", confidence=0.82)
|
||||||
|
# Reset to real now for is_recent check
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"signal_v2.rate_limit._now_iso", lambda: now.isoformat()
|
||||||
|
)
|
||||||
|
assert dedup.is_recent("005930", "buy", within_hours=24) is False
|
||||||
81
signal_v2/tests/test_scheduler.py
Normal file
81
signal_v2/tests/test_scheduler.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""Tests for scheduler interval logic."""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from signal_v2.scheduler import _next_interval, _is_market_day, KST
|
||||||
|
|
||||||
|
|
||||||
|
def _kst(year, month, day, hour, minute=0):
|
||||||
|
return datetime(year, month, day, hour, minute, tzinfo=KST)
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_pre_market_5min():
|
||||||
|
now = _kst(2026, 5, 18, 8, 30) # Monday 08:30
|
||||||
|
assert _next_interval(now) == 300
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_market_open_1min():
|
||||||
|
now = _kst(2026, 5, 18, 10, 0) # Monday 10:00
|
||||||
|
assert _next_interval(now) == 60
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_post_market_5min():
|
||||||
|
now = _kst(2026, 5, 18, 17, 0) # Monday 17:00
|
||||||
|
assert _next_interval(now) == 300
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_overnight_skip_to_next_morning():
|
||||||
|
now = _kst(2026, 5, 18, 2, 30) # Monday 02:30 (dead zone, not NXT window)
|
||||||
|
interval = _next_interval(now)
|
||||||
|
# Dead zone 23:30-04:30 → next 04:30 is ~2h away
|
||||||
|
assert 2 * 3600 - 60 < interval < 2 * 3600 + 60
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_holiday_skip():
|
||||||
|
# 2026-05-05 어린이날 (Tuesday holiday)
|
||||||
|
now = _kst(2026, 5, 5, 10, 0)
|
||||||
|
assert _is_market_day(now) is False
|
||||||
|
interval = _next_interval(now)
|
||||||
|
# Next: 2026-05-06 (Wed) 07:00, ~21h away
|
||||||
|
assert 20 * 3600 < interval < 22 * 3600
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_at_market_open_boundary():
|
||||||
|
"""09:00:00 정확 second → 60초 (market 구간 진입)."""
|
||||||
|
now = _kst(2026, 5, 18, 9, 0) # Monday 09:00:00
|
||||||
|
assert _next_interval(now) == 60
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_at_market_close_boundary():
|
||||||
|
"""15:30:00 정확 second → 300초 (post-market 구간 진입)."""
|
||||||
|
now = _kst(2026, 5, 18, 15, 30) # Monday 15:30:00
|
||||||
|
assert _next_interval(now) == 300
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_at_polling_window_end_boundary():
|
||||||
|
"""23:30:00 정확 second → dead zone skip (다음 04:30 까지)."""
|
||||||
|
now = _kst(2026, 5, 18, 23, 30) # Monday 23:30:00 (NXT_PRE_END boundary)
|
||||||
|
interval = _next_interval(now)
|
||||||
|
# Dead zone 23:30-04:30 → next 04:30 is ~5h away
|
||||||
|
assert 5 * 3600 - 60 < interval < 5 * 3600 + 60
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_nxt_evening_5min():
|
||||||
|
"""22:00 평일 (NXT 야간) → 300 (5분)."""
|
||||||
|
now = _kst(2026, 5, 18, 22, 0)
|
||||||
|
assert _next_interval(now) == 300
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_nxt_dawn_5min():
|
||||||
|
"""05:30 평일 (NXT 새벽) → 300 (5분)."""
|
||||||
|
now = _kst(2026, 5, 18, 5, 30)
|
||||||
|
assert _next_interval(now) == 300
|
||||||
|
|
||||||
|
|
||||||
|
def test_next_interval_dead_zone_skip():
|
||||||
|
"""02:00 평일 (dead zone 23:30-04:30) → 다음 04:30 까지 (~9000s)."""
|
||||||
|
now = _kst(2026, 5, 18, 2, 0)
|
||||||
|
interval = _next_interval(now)
|
||||||
|
# 02:00 → 04:30 = 2.5h = 9000s
|
||||||
|
assert 9000 - 60 < interval < 9000 + 60
|
||||||
172
signal_v2/tests/test_signal_generator.py
Normal file
172
signal_v2/tests/test_signal_generator.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""Tests for signal_generator."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from signal_v2.signal_generator import generate_signals
|
||||||
|
from signal_v2.state import PollState
|
||||||
|
|
||||||
|
|
||||||
|
def _settings(**overrides):
|
||||||
|
"""Build a Settings-like object for tests (avoid env)."""
|
||||||
|
defaults = dict(
|
||||||
|
stop_loss_pct=-0.07,
|
||||||
|
take_profit_pct=0.15,
|
||||||
|
chronos_spread_threshold=0.6,
|
||||||
|
asking_bid_ratio_threshold=0.6,
|
||||||
|
confidence_threshold=0.7,
|
||||||
|
min_momentum_for_buy="strong_up",
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
m = MagicMock()
|
||||||
|
for k, v in defaults.items():
|
||||||
|
setattr(m, k, v)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def _make_state_with_buy_candidate(
|
||||||
|
ticker="005930", name="삼성전자",
|
||||||
|
chronos_median=0.02, chronos_q10=-0.01, chronos_q90=0.04, chronos_conf=0.85,
|
||||||
|
momentum="strong_up", bid_ratio=0.7, current_price=78500,
|
||||||
|
):
|
||||||
|
state = PollState()
|
||||||
|
state.screener_preview = {"items": [{"ticker": ticker, "name": name}]}
|
||||||
|
state.chronos_predictions[ticker] = {
|
||||||
|
"median": chronos_median, "q10": chronos_q10, "q90": chronos_q90,
|
||||||
|
"conf": chronos_conf, "as_of": "2026-05-17T16:00:00+09:00",
|
||||||
|
}
|
||||||
|
state.minute_momentum[ticker] = momentum
|
||||||
|
state.asking_price[ticker] = {
|
||||||
|
"bid_total": int(bid_ratio * 1000),
|
||||||
|
"ask_total": int((1 - bid_ratio) * 1000),
|
||||||
|
"bid_ratio": bid_ratio,
|
||||||
|
"current_price": current_price,
|
||||||
|
"as_of": "2026-05-17T16:00:01+09:00",
|
||||||
|
}
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def _make_state_with_holding(
|
||||||
|
ticker="005930", name="삼성전자",
|
||||||
|
pnl_pct=0.0, avg_price=75000, current_price=75000,
|
||||||
|
):
|
||||||
|
state = PollState()
|
||||||
|
state.portfolio = {"holdings": [{
|
||||||
|
"ticker": ticker, "name": name,
|
||||||
|
"avg_price": avg_price, "current_price": current_price,
|
||||||
|
"pnl_pct": pnl_pct, "profit_rate": pnl_pct * 100,
|
||||||
|
"quantity": 100, "broker": "키움",
|
||||||
|
}]}
|
||||||
|
state.screener_preview = {"items": []}
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dedup_mock():
|
||||||
|
d = MagicMock()
|
||||||
|
d.is_recent.return_value = False
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def test_buy_signal_when_all_conditions_pass_and_confidence_high(dedup_mock):
|
||||||
|
state = _make_state_with_buy_candidate()
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
assert "005930" in state.signals
|
||||||
|
sig = state.signals["005930"]
|
||||||
|
assert sig["action"] == "buy"
|
||||||
|
assert sig["confidence_webai"] > 0.7
|
||||||
|
dedup_mock.record.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_silent_when_chronos_median_negative(dedup_mock):
|
||||||
|
state = _make_state_with_buy_candidate(chronos_median=-0.01)
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
assert "005930" not in state.signals
|
||||||
|
|
||||||
|
|
||||||
|
def test_silent_when_distribution_spread_too_wide(dedup_mock):
|
||||||
|
# spread = q90 - q10 = 0.5 - (-0.5) = 1.0 > 0.6 → hard gate fails
|
||||||
|
state = _make_state_with_buy_candidate(
|
||||||
|
chronos_median=0.001, chronos_q10=-0.5, chronos_q90=0.5,
|
||||||
|
)
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
assert "005930" not in state.signals
|
||||||
|
|
||||||
|
|
||||||
|
def test_silent_when_momentum_not_strong_up(dedup_mock):
|
||||||
|
state = _make_state_with_buy_candidate(momentum="weak_up")
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
assert "005930" not in state.signals
|
||||||
|
|
||||||
|
|
||||||
|
def test_silent_when_bid_ratio_below_threshold(dedup_mock):
|
||||||
|
state = _make_state_with_buy_candidate(bid_ratio=0.5)
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
assert "005930" not in state.signals
|
||||||
|
|
||||||
|
|
||||||
|
def test_silent_when_confidence_below_threshold(dedup_mock):
|
||||||
|
# chronos_conf low + rank=20 → confidence < 0.7
|
||||||
|
state = _make_state_with_buy_candidate(chronos_conf=0.3)
|
||||||
|
# add 19 fake items to push 005930 rank to 20
|
||||||
|
state.screener_preview["items"] = (
|
||||||
|
[{"ticker": f"FAKE{i:03d}"} for i in range(19)]
|
||||||
|
+ [{"ticker": "005930", "name": "삼성전자"}]
|
||||||
|
)
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
# confidence_webai = 0.3*0.5 + 1.0*0.3 + 0.05*0.2 = 0.46 < 0.7
|
||||||
|
assert "005930" not in state.signals
|
||||||
|
|
||||||
|
|
||||||
|
def test_sell_signal_when_stop_loss_triggered(dedup_mock):
|
||||||
|
state = _make_state_with_holding(pnl_pct=-0.08, current_price=69000, avg_price=75000)
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
assert "005930" in state.signals
|
||||||
|
sig = state.signals["005930"]
|
||||||
|
assert sig["action"] == "sell"
|
||||||
|
assert sig["confidence_webai"] == 1.0
|
||||||
|
assert sig["pnl_pct"] == -0.08
|
||||||
|
|
||||||
|
|
||||||
|
def test_sell_signal_when_take_profit_triggered(dedup_mock):
|
||||||
|
state = _make_state_with_holding(pnl_pct=0.16, current_price=87000, avg_price=75000)
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
assert "005930" in state.signals
|
||||||
|
sig = state.signals["005930"]
|
||||||
|
assert sig["action"] == "sell"
|
||||||
|
assert sig["confidence_webai"] == 0.6
|
||||||
|
|
||||||
|
|
||||||
|
def test_silent_when_dedup_recently_sent(dedup_mock):
|
||||||
|
state = _make_state_with_buy_candidate()
|
||||||
|
dedup_mock.is_recent.return_value = True
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
assert "005930" not in state.signals
|
||||||
|
dedup_mock.record.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sell_signal_triggers_on_anomaly_path(dedup_mock):
|
||||||
|
"""Anomaly sell: median < -1%, momentum strong_down, low bid_ratio, confidence > threshold."""
|
||||||
|
state = PollState()
|
||||||
|
state.portfolio = {"holdings": [{
|
||||||
|
"ticker": "005930", "name": "삼성전자",
|
||||||
|
"avg_price": 75000, "current_price": 70000,
|
||||||
|
"pnl_pct": -0.067, # within stop_loss tolerance (default -0.07): NOT triggering stop_loss
|
||||||
|
"quantity": 100, "broker": "키움",
|
||||||
|
}]}
|
||||||
|
state.screener_preview = {"items": []}
|
||||||
|
state.chronos_predictions["005930"] = {
|
||||||
|
"median": -0.025, "q10": -0.05, "q90": 0.005, "conf": 0.85,
|
||||||
|
}
|
||||||
|
state.minute_momentum["005930"] = "strong_down"
|
||||||
|
state.asking_price["005930"] = {"current_price": 70000, "bid_ratio": 0.30}
|
||||||
|
# bid_ratio 0.30 < (1 - 0.6) = 0.4 → anomaly bid_ratio gate passes
|
||||||
|
# confidence = 0.85*0.5 + 1.0*0.3 + 1.0*0.2 = 0.425 + 0.3 + 0.2 = 0.925 > 0.7
|
||||||
|
|
||||||
|
generate_signals(state, dedup_mock, _settings())
|
||||||
|
|
||||||
|
assert "005930" in state.signals
|
||||||
|
sig = state.signals["005930"]
|
||||||
|
assert sig["action"] == "sell"
|
||||||
|
assert sig["context"]["sell_reason"] == "anomaly"
|
||||||
|
assert sig["confidence_webai"] > 0.7
|
||||||
168
signal_v2/tests/test_stock_client.py
Normal file
168
signal_v2/tests/test_stock_client.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Tests for stock_client.StockClient."""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from signal_v2.stock_client import StockClient
|
||||||
|
|
||||||
|
|
||||||
|
BASE_URL = "https://test.stock.local"
|
||||||
|
API_KEY = "test-secret"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_portfolio_normal_returns_dict_with_pnl_pct(mock_stock_api):
|
||||||
|
"""정상 200 응답 + cache 저장."""
|
||||||
|
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||||
|
return_value=httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"holdings": [{"ticker": "005930", "pnl_pct": 0.047}],
|
||||||
|
"cash": [],
|
||||||
|
"summary": {},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client = StockClient(BASE_URL, API_KEY)
|
||||||
|
try:
|
||||||
|
result = await client.get_portfolio()
|
||||||
|
assert result["holdings"][0]["pnl_pct"] == 0.047
|
||||||
|
# Cache populated
|
||||||
|
assert len(client._cache) >= 1
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_portfolio_uses_cache_within_ttl(mock_stock_api):
|
||||||
|
"""60s TTL 내 두번째 호출 = mock 콜 1회."""
|
||||||
|
route = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||||
|
return_value=httpx.Response(
|
||||||
|
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client = StockClient(BASE_URL, API_KEY)
|
||||||
|
try:
|
||||||
|
await client.get_portfolio()
|
||||||
|
await client.get_portfolio() # second call within TTL
|
||||||
|
assert route.call_count == 1
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_portfolio_refetches_after_ttl_expiry(mock_stock_api, monkeypatch):
|
||||||
|
"""TTL 만료 후 재호출 = mock 콜 2회. time.monotonic 모킹."""
|
||||||
|
route = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||||
|
return_value=httpx.Response(
|
||||||
|
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Fake clock: starts at 0, jumps to 61 between calls
|
||||||
|
fake_time = [0.0]
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"signal_v2.stock_client.time.monotonic", lambda: fake_time[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
client = StockClient(BASE_URL, API_KEY)
|
||||||
|
try:
|
||||||
|
await client.get_portfolio()
|
||||||
|
fake_time[0] = 61.0 # 60s TTL 만료
|
||||||
|
await client.get_portfolio()
|
||||||
|
assert route.call_count == 2
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_portfolio_retries_3_times_on_timeout(mock_stock_api, monkeypatch):
|
||||||
|
"""timeout 2번 + 200 1번 → 최종 성공. exponential sleep 호출 검증."""
|
||||||
|
sleep_calls = []
|
||||||
|
|
||||||
|
async def fake_sleep(s):
|
||||||
|
sleep_calls.append(s)
|
||||||
|
|
||||||
|
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||||
|
|
||||||
|
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||||
|
side_effect=[
|
||||||
|
httpx.TimeoutException("timeout 1"),
|
||||||
|
httpx.TimeoutException("timeout 2"),
|
||||||
|
httpx.Response(
|
||||||
|
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
client = StockClient(BASE_URL, API_KEY)
|
||||||
|
try:
|
||||||
|
result = await client.get_portfolio()
|
||||||
|
assert result["holdings"] == []
|
||||||
|
assert sleep_calls == [1, 2] # exponential 1s, 2s
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_portfolio_429_triggers_backoff(mock_stock_api, monkeypatch):
|
||||||
|
"""429 → 1s backoff → 200."""
|
||||||
|
sleep_calls = []
|
||||||
|
|
||||||
|
async def fake_sleep(s):
|
||||||
|
sleep_calls.append(s)
|
||||||
|
|
||||||
|
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||||
|
|
||||||
|
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||||
|
side_effect=[
|
||||||
|
httpx.Response(429, text="rate limit"),
|
||||||
|
httpx.Response(
|
||||||
|
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
client = StockClient(BASE_URL, API_KEY)
|
||||||
|
try:
|
||||||
|
result = await client.get_portfolio()
|
||||||
|
assert result["holdings"] == []
|
||||||
|
assert sleep_calls == [1]
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_portfolio_falls_back_to_stale_on_all_failures(
|
||||||
|
mock_stock_api, monkeypatch, caplog
|
||||||
|
):
|
||||||
|
"""cache 에 이전 성공 응답 + 모든 retry 5xx → stale 반환 + logger.warning."""
|
||||||
|
# No-op sleep for fast test
|
||||||
|
async def fake_sleep(s):
|
||||||
|
return None
|
||||||
|
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||||
|
|
||||||
|
# Patch time.monotonic BEFORE first call so cached timestamp uses fake clock
|
||||||
|
fake_time = [0.0]
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"signal_v2.stock_client.time.monotonic", lambda: fake_time[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call succeeds
|
||||||
|
route1 = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||||
|
return_value=httpx.Response(
|
||||||
|
200,
|
||||||
|
json={"holdings": [{"ticker": "005930"}], "cash": [], "summary": {}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client = StockClient(BASE_URL, API_KEY)
|
||||||
|
try:
|
||||||
|
first = await client.get_portfolio()
|
||||||
|
assert first["holdings"][0]["ticker"] == "005930"
|
||||||
|
|
||||||
|
# Advance fake clock past TTL (60s) so cache is stale
|
||||||
|
fake_time[0] = 61.0
|
||||||
|
|
||||||
|
# Now mock to return 500s persistently
|
||||||
|
route1.mock(return_value=httpx.Response(500, text="server error"))
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger="signal_v2.stock_client"):
|
||||||
|
result = await client.get_portfolio()
|
||||||
|
assert result["holdings"][0]["ticker"] == "005930" # stale data returned
|
||||||
|
assert any(
|
||||||
|
"stale" in rec.message.lower() for rec in caplog.records
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
Reference in New Issue
Block a user