Compare commits
75 Commits
ad2c65c2b2
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| cb70226f42 | |||
| de24bae984 | |||
| 0e6c893b4e | |||
| fb80973e38 | |||
| 31b0e7dbc4 | |||
| 6169f48eb8 | |||
| 27a6df6cff | |||
| 803fdb6278 | |||
| 77e21b54e6 | |||
| 4d0c89ce79 | |||
| 4b60ab34c3 | |||
| 53a0657027 | |||
| 91f01d126b | |||
| 0702cf052f | |||
| 8aa3f1c3b2 | |||
| 4db0551d33 | |||
| 4d837fdd31 | |||
| 2567a6f10b | |||
| 17ed1943f1 | |||
| 8d246b5b32 | |||
| b4bec9d51b | |||
| f32792e4a9 | |||
| f152545d3b | |||
| bf3d6ee694 | |||
| 44bc065796 | |||
| 9127616669 | |||
| 900f45c2ff | |||
| eb34cbc0f7 | |||
| 0de09613d2 | |||
| a5274a4fa7 | |||
| 4e72f8ca2e | |||
| 44c6811352 | |||
| 9eef2c5015 | |||
| b05e5714e3 | |||
| c8793cc3cf | |||
| 11e73f6960 | |||
| f1fc3e1102 | |||
| e0e56090ee | |||
| e0269bae39 | |||
| bee0add9dd | |||
| 1adf91a19b | |||
| 26ef660c75 | |||
| 139e4e3382 | |||
| bb03cc4525 | |||
| 71ef959310 | |||
| 2aa9f48ea3 | |||
| cc6310d72f | |||
| e574074ca8 | |||
| b9def06993 | |||
| 05ab2846bb | |||
| 760f914d3b | |||
| 8eefe9d79d | |||
| 91de16675b | |||
| 44888d6ede | |||
| 9e5fecb369 | |||
| 28f9c8c3a6 | |||
| c5a88fab66 | |||
| 7056cf2fa6 | |||
| 4ac7da8670 | |||
| b690900cfc | |||
| d85512d036 | |||
| 3ebe95ba29 | |||
| 163c9fb690 | |||
| 27bf360b01 | |||
| eafa73edb1 | |||
| 68eb7b073c | |||
| 8342d38935 | |||
| e47947fb69 | |||
| 94c684bab8 | |||
| 1a6d9fcb39 | |||
| 6cb5085118 | |||
| fdabc69004 | |||
| 90235497ae | |||
| 8469bf7ffa | |||
| 8a2fac03a6 |
25
.gitignore
vendored
25
.gitignore
vendored
@@ -47,9 +47,11 @@ daily_trade_history.json
|
||||
watchlist.json
|
||||
bot_ipc.json
|
||||
|
||||
# Test
|
||||
# Test (top-level only; ai_trade/tests tracked separately)
|
||||
tests/
|
||||
tests/*
|
||||
!ai_trade/tests/
|
||||
!ai_trade/tests/**
|
||||
|
||||
# System
|
||||
Thumbs.db
|
||||
@@ -59,3 +61,24 @@ Desktop.ini
|
||||
KIS_SETUP.md
|
||||
# Claude Code subagent state
|
||||
.claude/
|
||||
|
||||
# Signal V2 runtime data
|
||||
ai_trade/data/*.db
|
||||
ai_trade/data/*.db-*
|
||||
|
||||
# Plan-B-Insta services 예외 (코드는 추적, .env는 무시 유지)
|
||||
!services/
|
||||
!services/**/
|
||||
!services/**/*.py
|
||||
!services/**/Dockerfile
|
||||
!services/**/requirements.txt
|
||||
!services/**/.env.example
|
||||
!services/**/*.j2
|
||||
!services/**/*.html
|
||||
!services/**/*.css
|
||||
!services/**/.gitkeep
|
||||
!services/**/pytest.ini
|
||||
!services/docker-compose.yml
|
||||
# 단 실 .env는 무시 유지
|
||||
services/**/.env
|
||||
services/.env
|
||||
|
||||
277
CHECK_POINT.md
Normal file
277
CHECK_POINT.md
Normal file
@@ -0,0 +1,277 @@
|
||||
# web-ai CHECK_POINT
|
||||
|
||||
> Windows AI Server (192.168.45.59), AMD 9800X3D + RTX 5070 Ti (16GB VRAM).
|
||||
> V1(LSTM 레거시) + V2(Chronos-2 signal pipeline) 이중 구조.
|
||||
> 2026-05-18 작성.
|
||||
|
||||
## 🚀 2026-05-18 박재오 7 결정 — Windows 컴퓨팅 노드 신설 (1주 작업)
|
||||
|
||||
박재오 결정 7건 완료. 상세 가이드: `Obsidian Vault/raw/2026-05-18-Windows-NAS-아키텍처-7결정-통합.md`
|
||||
|
||||
### 결정 6 — 옵션 4 하이브리드 운영
|
||||
|
||||
```
|
||||
[Windows AI Server]
|
||||
|
||||
🔵 Native Python (NSSM 자동 시작, HIGH priority)
|
||||
├─ ai_trade (트레이딩 :8001) ⭐ 절대 우선
|
||||
├─ Ollama qwen3:14b (:11435)
|
||||
└─ MusicGen (:8765)
|
||||
|
||||
🟢 WSL2 + Docker Engine (Docker Desktop X, 라이선스·메모리 ↓)
|
||||
├─ insta-render (:18710) ⭐ NEW
|
||||
├─ music-render (:8771) ⭐ NEW
|
||||
├─ video-render (:18712) ⭐ NEW (외부 API 게이트웨이)
|
||||
└─ task-watcher (옵션 d 작업 감지)
|
||||
```
|
||||
|
||||
### Day 2 — WSL2 + Docker Engine 설치 ⭐ (2시간)
|
||||
|
||||
```powershell
|
||||
# 관리자 PowerShell
|
||||
wsl --install -d Ubuntu-22.04
|
||||
# 재부팅 후 Ubuntu 초기 설정
|
||||
|
||||
wsl -d Ubuntu-22.04
|
||||
# 안에서 Docker Engine 설치
|
||||
sudo apt update && sudo apt install -y ca-certificates curl gnupg
|
||||
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
|
||||
echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list
|
||||
sudo apt update && sudo apt install -y docker-ce docker-compose-plugin
|
||||
sudo usermod -aG docker $USER
|
||||
sudo systemctl enable docker
|
||||
sudo systemctl start docker
|
||||
|
||||
# Tailscale (NAS와 같은 LAN 확인)
|
||||
curl -fsSL https://tailscale.com/install.sh | sh
|
||||
sudo tailscale up
|
||||
```
|
||||
|
||||
- [ ] WSL2 + Ubuntu 22.04 설치
|
||||
- [ ] Docker Engine 설치 (Desktop X)
|
||||
- [ ] Tailscale 가입
|
||||
|
||||
### Day 3~4 — insta-render 컨테이너 신설 ⭐ (4시간)
|
||||
|
||||
**디렉토리**: `C:\Users\jaeoh\Desktop\workspace\web-ai-services\insta-render\`
|
||||
|
||||
**Dockerfile**: Playwright + Python 3.12 (raw 파일에 풀 코드)
|
||||
|
||||
**main.py**: Redis 큐 worker + Playwright Browser pool + Semaphore(1)
|
||||
|
||||
- [ ] 디렉토리 생성 + Dockerfile + main.py + requirements.txt
|
||||
- [ ] `docker compose up -d insta-render`
|
||||
- [ ] 테스트: NAS Redis에서 작업 push → Windows에서 처리 확인
|
||||
|
||||
### Day 6 — NSSM 자동 시작 ⭐ (1시간)
|
||||
|
||||
**NSSM 다운로드**: https://nssm.cc/download
|
||||
|
||||
```powershell
|
||||
# 관리자 PowerShell
|
||||
|
||||
# 트레이딩 자동 시작 (HIGH priority — 결정 b)
|
||||
nssm install ai_trade "C:\Python312\python.exe" "-m uvicorn main:app --host 0.0.0.0 --port 8001"
|
||||
nssm set ai_trade AppDirectory "C:\Users\jaeoh\Desktop\workspace\web-ai\ai_trade"
|
||||
nssm set ai_trade Priority HIGH_PRIORITY_CLASS
|
||||
nssm set ai_trade AppStartup AUTO
|
||||
|
||||
# WSL2 + Docker 자동 시작
|
||||
nssm install wsl_docker "wsl" "-d Ubuntu-22.04 -- sudo service docker start && cd /workspace/web-ai-services && docker compose up -d"
|
||||
nssm set wsl_docker AppStartup AUTO
|
||||
|
||||
# 시작
|
||||
nssm start ai_trade
|
||||
nssm start wsl_docker
|
||||
```
|
||||
|
||||
- [ ] NSSM 다운로드 + 압축 해제 + PATH 추가
|
||||
- [ ] ai_trade service 등록 (HIGH priority)
|
||||
- [ ] wsl_docker service 등록
|
||||
- [ ] 재부팅 후 자동 시작 확인
|
||||
|
||||
### Day 7 — 작업 감지 (옵션 d) — task-watcher ⭐ (2시간)
|
||||
|
||||
박재오 작업 중 → Redis `queue:paused` 플래그 → 워커 일시정지. 트레이딩은 영향 X.
|
||||
|
||||
선택지:
|
||||
- **A. 자동 (Python pynput + PowerShell API)** — 마우스·게임 process 감지, 자동 토글
|
||||
- **B. 수동 토글** — 박재오님이 작업 시작 시 `redis-cli SET queue:paused 1`, 종료 시 `DEL`
|
||||
- **C. NAS frontend에 토글 UI 1개** — 클릭 한 번
|
||||
|
||||
→ 시작은 **B 수동 토글** (구현 0, 즉시 가능), 나중에 A 자동화로 진화.
|
||||
|
||||
- [ ] B 수동 토글 명령어 확인
|
||||
- [ ] 또는 A Python pynput 자동 감지 구현 (선택, 2시간)
|
||||
|
||||
### Day 5 — music-render 컨테이너 (선택 — MusicGen 패턴 정착)
|
||||
|
||||
기존 NAS music-lab → Windows MusicGen 호출 패턴 이미 운영 중. 표준화만:
|
||||
- [ ] Redis 큐 사용으로 전환 (HTTP 직접 호출 X)
|
||||
- [ ] Browser pool 같은 패턴 적용 (Suno + MusicGen 동시 1개)
|
||||
|
||||
### Day 5 — video-render 컨테이너 (선택 — 영상 생성 결정 4)
|
||||
|
||||
외부 영상 API 6개 게이트웨이 (Runway·Sora·Veo·Pika·Kling·Luma):
|
||||
- [ ] 박재오 자금·품질 판단 후 1~2개 가입
|
||||
- [ ] `.env`에 API 키 추가
|
||||
- [ ] video-render Docker 컨테이너 신설
|
||||
|
||||
---
|
||||
|
||||
## 🔥 2026-05-18 추가 — NAS API 부하 진짜 원인 발견
|
||||
|
||||
박재오 발견: 5건 + 중기 2건 적용 후 **web-ai V1+V2 4 process 종료가 NAS CPU 가장 큰 즉시 감소**.
|
||||
→ 진짜 병목은 **web-ai → NAS stock(:18500) 인바운드 API 호출 빈도**.
|
||||
|
||||
상세: `Obsidian Vault/raw/2026-05-18-NAS-Window-AI-API-부하-해결방안.md`
|
||||
|
||||
### 🔴 추가 즉시 작업 (50분으로 70% 부담 감소)
|
||||
|
||||
#### A. 캐시 TTL 대폭 증가 ⭐⭐⭐ (10분)
|
||||
**파일**: `ai_trade/stock_client.py`
|
||||
```python
|
||||
# 변경 전 → 변경 후
|
||||
PORTFOLIO_TTL = 60 → 180 # 3분
|
||||
NEWS_TTL = 300 → 600 # 10분
|
||||
SCREENER_TTL = 60 → 300 # 5분
|
||||
```
|
||||
- [ ] 3 TTL 상수 증가
|
||||
- 효과: 분당 12 호출 → 3~4 호출 (70% 감소)
|
||||
|
||||
#### B. V1·V2 단일화 결정 ⭐⭐ (10분 결정)
|
||||
- 동시 운영 시 NAS API 부담 2배 + KIS rate limit 충돌
|
||||
- **권장**: V1 폐기 + V2 단독 (Phase 4 자산 활용)
|
||||
- 또는 V2 임시 종료 + V1 유지 → Phase 5 진입 시 V1 폐기
|
||||
- [ ] 박재오 결정
|
||||
- [ ] 선택 안 된 쪽 `legacy/` 또는 `.disabled`
|
||||
- [ ] start.bat 한 쪽만
|
||||
|
||||
#### C. (NAS 측) stock TTLCache — web-backend CHECK_POINT.md #13 참고
|
||||
- 박재오가 web-backend/stock에서 별도 적용
|
||||
|
||||
## 🟢 현재 상태
|
||||
|
||||
- **V2 Phase 4 완료** (5/17, 56/56 tests pass, main push)
|
||||
- Chronos-2 zero-shot 1일 수익률 + 분봉 모멘텀 5단계
|
||||
- 09:00~15:30 매 1분 / 16:00 일봉 추론
|
||||
- Sell-first 우선순위 (stop_loss · anomaly · take_profit) + 매수 hard gate
|
||||
- Confidence = chronos×0.5 + momentum×0.3 + screener×0.2
|
||||
|
||||
---
|
||||
|
||||
## 🔴 즉시 (이번 주)
|
||||
|
||||
### 1. V1 vs V2 KIS rate limit 충돌 해결
|
||||
- **현재**: V1 + V2 동시 실행 시 KIS EGW00201 (초당 2회 제한) 충돌
|
||||
- **임시 해결**: V2 종료 상태 (현재 V1만 운영 중)
|
||||
- **결정 필요**: V1 deprecation 시점 (Phase 6)
|
||||
- [ ] 박재오 결정 — V1 폐기 일자 (예: 5/20 / 5/31)
|
||||
- [ ] Phase 5 진입 전 V1 정리
|
||||
|
||||
### 2. `.venv` 한글 경로 문제 해결
|
||||
- 가상환경 한글 경로 깨짐 → 시스템 Python 사용 강제
|
||||
- 다른 개발 머신 협업 시 문제
|
||||
- [ ] `.venv`를 영문 경로로 이전 또는 시스템 Python으로 통일
|
||||
- [ ] start.bat에 경로 명시
|
||||
|
||||
### 3. `state.signals` consumer-drain protocol 정의
|
||||
- Phase 5 prereq
|
||||
- 신호가 누적되기만 하고 소비 로직 미정의
|
||||
- [ ] consumer (예: agent-office /signal endpoint)가 처리한 신호 marking
|
||||
- [ ] 24h 만료 dedup 외에 *처리됨* 상태 추가
|
||||
|
||||
---
|
||||
|
||||
## 🟡 중기 (1~2주, Phase 5)
|
||||
|
||||
### 4. agent-office `/signal` 엔드포인트 통합
|
||||
- web-ai V2가 매수/매도 신호 생성 → agent-office로 push
|
||||
- agent-office가 텔레그램 발송 + 사용자 결정 대기
|
||||
- [ ] V2에서 agent-office HTTP POST 호출 추가
|
||||
- [ ] payload: ticker, action, confidence, reasoning, chronos_quantile
|
||||
|
||||
### 5. Ollama Qwen3 14B 통합 (Windows 로컬)
|
||||
- 신호 *해석* 레이어 — 단순 규칙 결과 → LLM 자연어 설명
|
||||
- 9.3GB VRAM 사용 (Chronos 7GB와 동시 가능, 15.5GB)
|
||||
- [ ] Ollama Windows 설치 (이미 실행 중인지 확인)
|
||||
- [ ] `state.signals` 큐에서 신호 pop → Qwen3 prompt → 결과 add
|
||||
- [ ] 텔레그램 전송 시 LLM 해석 텍스트 포함
|
||||
|
||||
### 6. 이중 텔레그램 전송
|
||||
- 현재: V1만 텔레그램 발송 (Telegram Bot + KIS 자동주문)
|
||||
- Phase 5: V2도 별도 chat_id로 발송 (박재오 본인용 + 검증용)
|
||||
- [ ] V2 텔레그램 chat_id 환경변수 (`TELEGRAM_V2_CHAT_ID`)
|
||||
- [ ] V1·V2 메시지 톤 차별화 (V1 = 자동주문 / V2 = 신호 알림)
|
||||
|
||||
### 7. holidays.json 자동 동기화
|
||||
- 현재: NAS에서 수동 copy
|
||||
- 한국 휴장일 누락 시 V2 폴링 실수 (휴장일에도 KIS 호출)
|
||||
- [ ] NAS realestate-lab 또는 별도 컨테이너에서 휴장일 자동 발급
|
||||
- [ ] V2가 부팅 시·매일 00:00에 GET 갱신
|
||||
|
||||
---
|
||||
|
||||
## 🟢 장기 (1개월+, Phase 6+)
|
||||
|
||||
### 8. V1 완전 deprecation
|
||||
- LSTM 7-feature 모델 + main_server.py 폐기
|
||||
- 모든 자동매매 V2로 통일
|
||||
- [ ] V1 종료 일자 박재오 결정
|
||||
- [ ] V1 코드 `legacy/` 폴더로 이동 (read-only)
|
||||
|
||||
### 9. Chronos-2 모델 미세조정 검토
|
||||
- 현재 zero-shot. 한국 주식 데이터로 미세조정 시 정확도 ↑ 가능?
|
||||
- 박재오 자체 학습 데이터 (KIS 1년치) → finetune
|
||||
- [ ] 데이터 수집·전처리
|
||||
- [ ] LoRA 또는 full finetune 결정
|
||||
|
||||
### 10. KIS WebSocket 실 운영 검증
|
||||
- 현재 코드는 있으나 검증 부족
|
||||
- 실시간 호가가 1분 폴링보다 빠른 신호 확보 가능
|
||||
- [ ] 1주일 운영 후 latency·드롭 측정
|
||||
|
||||
---
|
||||
|
||||
## ✅ 최근 완료 (참고)
|
||||
|
||||
- 2026-05-17: emit/skip 로깅 추가 (`2aa9f48`)
|
||||
- 2026-05-17: signal_generator poll_loop 통합 — Phase 4 완료 (`cc6310d`)
|
||||
- 2026-05-16: 코드 리뷰 수정 — sell-first 순서, anomaly 테스트 (`e574074`)
|
||||
- 2026-05-16: signal_generator 초안 — 9개 unit test (`b9def06`)
|
||||
- 2026-05-15: Foundation — 6개 env 임계값 + state.signals 필드 (`05ab284`)
|
||||
- 2026-05-14: FP32 강제 — Chronos FP16 overflow 회피 (`760f914`)
|
||||
|
||||
---
|
||||
|
||||
## 🔧 운영 커맨드 (Windows PowerShell)
|
||||
|
||||
```powershell
|
||||
# V2 시작
|
||||
cd C:\Users\jaeoh\Desktop\workspace\web-ai\ai_trade
|
||||
.\start.bat
|
||||
|
||||
# 테스트 실행 (56 tests)
|
||||
pytest tests/ -v
|
||||
|
||||
# 로그 확인 (V2 실시간)
|
||||
Get-Content logs/ai_trade.log -Wait
|
||||
|
||||
# Chronos 메모리 사용 확인
|
||||
nvidia-smi
|
||||
|
||||
# KIS API 헬스 (REST)
|
||||
curl http://localhost:8001/health
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📚 참고
|
||||
|
||||
- 위키: [[자산-주식-자동매매]] (V3.1 → V2 Phase 4 정정 필요)
|
||||
- NAS stock 연동: `192.168.45.54:18500` (X-WebAI-Key 인증)
|
||||
- CLAUDE.md (본 디렉토리 루트): Phase 진행도 + 환경변수 명세
|
||||
|
||||
## 변경 이력
|
||||
|
||||
- 2026-05-18: 페이지 신설. 즉시 3건 (V1/V2 충돌, .venv 한글, consumer protocol) + 중기 4건 (Phase 5) + 장기 3건 (V1 deprecation·Chronos finetune·WebSocket).
|
||||
141
CLAUDE.md
141
CLAUDE.md
@@ -1,24 +1,139 @@
|
||||
# web-ai — Workspace 가이드
|
||||
|
||||
Windows AI 머신 (AMD 9800X3D + RTX 5070 Ti) 의 두 시그널 파이프라인 컨테이너.
|
||||
Windows AI 머신 (AMD 9800X3D + RTX 5070 Ti 16GB) 의 두 신호 파이프라인.
|
||||
**Confidence Signal Pipeline V2 의 Windows-side 구현체** (NAS stock 백엔드와 HTTP 연동).
|
||||
|
||||
상위 워크스페이스 컨텍스트는 `../CLAUDE.md` 참조.
|
||||
|
||||
---
|
||||
|
||||
## 디렉토리 구조
|
||||
|
||||
| 경로 | 역할 | 상태 |
|
||||
|------|------|------|
|
||||
| `signal_v1/` | V1 자체 자동매매 시스템 (main_server.py + Trading Bot + Telegram Bot + LSTM + Ollama + KIS 자동주문) | 운영 중. Confidence Signal Pipeline V2 Phase 6 에서 deprecation 예정 |
|
||||
| `signal_v2/` | V2 신호 파이프라인 (stock pull worker + Chronos-2 + signal API client) | Phase 2 에서 신설 |
|
||||
| `.env` | V1 + V2 환경변수 공유 | KIS_*, TELEGRAM_*, STOCK_API_URL, WEBAI_API_KEY 등 |
|
||||
| `start.bat` | V1 진입 (signal_v1 디렉토리 안 main_server.py 실행) | V2 별도 start 스크립트는 signal_v2/start.bat |
|
||||
| 경로 | 역할 | 포트 | 상태 |
|
||||
|------|------|------|------|
|
||||
| `signal_v1/` | ⚠️ **DEPRECATED 2026-05-19** — 레거시 LSTM 봇. 사용 안 함. `legacy/signal_v1/`로 이동 완료 (2026-05-19) | `:8000` | **OFF** |
|
||||
| `ai_trade/` | 자동매매 메인 (구 `signal_v2` 2026-05-19 rename) — Chronos-bolt + 분봉 모멘텀 + KIS WebSocket + 신호 생성 | `:8001` | **Phase 4 완료 (2026-05-17)**, Phase 5 대기 |
|
||||
| `legacy/start_v1.bat` | (deprecated) V1 진입점 — root `start.bat`에서 이동됨. 자동 실행 차단 | — | **OFF** |
|
||||
| `ai_trade/start.bat` | 자동매매 진입점 | — | `ai_trade/main.py` uvicorn 실행 |
|
||||
| `services/` | (예정) NAS↔Windows 분산 worker — insta-render·music-render·video-render·task-watcher | 18710~ | **Plan-B-Insta 작업 중** |
|
||||
| `.env` | 환경변수 (`KIS_REAL_*`, `TELEGRAM_*`, `STOCK_API_URL`, `WEBAI_API_KEY`, `LOG_LEVEL`) | — | |
|
||||
| `requirements.txt` | 공용 의존성 | — | torch, chronos-forecasting, fastapi, httpx, websockets 등 |
|
||||
|
||||
## 운영 가이드
|
||||
`.venv` 는 **구조적으로 깨짐**: `pyvenv.cfg` 가 한글 사용자 경로(`C:\Users\박재오\...`) 를 포함하여 콘솔 코드페이지가 roundtrip 못함. 테스트는 시스템 Python 으로 실행: `C:\Users\jaeoh\AppData\Local\Programs\Python\Python312\python.exe -m pytest ai_trade/tests -q`.
|
||||
|
||||
- V1 시작: `start.bat` 또는 `cd signal_v1 && python main_server.py`
|
||||
- V2 시작 (Phase 2 이후): `cd signal_v2 && python -m uvicorn main:app --port 8001`
|
||||
- 둘 다 동시 실행 가능 (포트 분리: V1=8000, V2=8001)
|
||||
---
|
||||
|
||||
## 서버 시작 방식
|
||||
|
||||
### V1 (⚠️ DEPRECATED — 운영 X)
|
||||
2026-05-19부터 자동 시작 차단. `legacy/start_v1.bat`에 보존 (참고용만).
|
||||
별도 backtest 등 1회성 시 필요 시 박재오 직접 `legacy/start_v1.bat` 실행.
|
||||
|
||||
### ai_trade 단독 (smoke/검증)
|
||||
```bat
|
||||
cd C:\Users\jaeoh\Desktop\workspace\web-ai\ai_trade
|
||||
.\start.bat
|
||||
```
|
||||
기대 로그: `Uvicorn running on http://0.0.0.0:8001`, `poll_loop started`, `[KIS] minute bars ... OK`, `[Chronos] predicted N tickers`, `signal emit XXXXXX buy conf=0.xxx`.
|
||||
|
||||
휴장일/장 외 시간엔 `poll_loop` 만 idle. `Application startup complete` 만 보이면 정상.
|
||||
|
||||
### V1 + V2 동시 실행 — **권장 안 함**
|
||||
**KIS app_key 초당 2회 한도 (EGW00201)** 충돌. V1 cycle + V2 분봉 cron 이 같은 KIS app_key 로 동시 호출하면 rate limit. 채택 해결책: V2 임시 종료 (Phase 3a 결정), Phase 6 V1 deprecation 시 자연 해소. 별도 app_key 발급은 옵션 B.
|
||||
|
||||
---
|
||||
|
||||
## Phase 진행 상태 (Confidence Signal Pipeline V2)
|
||||
|
||||
`web-ui/docs/superpowers/specs/2026-05-15-confidence-signal-pipeline-v2-architecture.md` 참조.
|
||||
| Phase | 내용 | 상태 |
|
||||
|-------|------|------|
|
||||
| 0 | Architecture & contract spec | ✅ Chronos-2 + Qwen3 14B 채택 |
|
||||
| 1 | stock 백엔드 WebAI API 보강 (NAS) | ✅ 102/102 tests, 운영 배포 |
|
||||
| 1.5 | V1 → `signal_v1/` rename | ✅ V1 정상 기동 |
|
||||
| 2 | ai_trade pull worker + signal API client + scheduler | ✅ 19/19 tests, `:8001` 기동 |
|
||||
| 3a | KIS REST 분봉 + WebSocket 호가 + NXT 스케줄 | ✅ 33/33 tests |
|
||||
| 3b | Chronos-bolt-base 추론 + 5분봉 모멘텀 분류기 | ✅ 45/45 tests, 실 KIS+Chronos chain 검증 |
|
||||
| 4 | Signal Generator (매수/매도 룰) + pull_worker 통합 + 로깅 | ✅ **2026-05-17 완료, 56/56 tests, push 완료** |
|
||||
| 5 | agent-office `/signal` + Ollama Qwen3 14B + 이중 텔레그램 | ⏳ 2주 예상 |
|
||||
| 6 | signal_v1 deprecation | ⏳ 1주 |
|
||||
| 7 | 운영 모니터링 + 4주 IC 검증 | ⏳ 1주 + 4주 |
|
||||
|
||||
자세한 V1 가이드는 `signal_v1/CLAUDE.md` 참조.
|
||||
상세 spec/plan: `../web-ui/docs/superpowers/specs/` 및 `../web-ui/docs/superpowers/plans/` (web-ui repo 안에 보관됨 — V2 자체 코드와 분리 보관).
|
||||
|
||||
---
|
||||
|
||||
## ai_trade 디렉토리 내부
|
||||
|
||||
| 파일 | 역할 |
|
||||
|------|------|
|
||||
| `main.py` | FastAPI app + lifespan (StockClient + KISClient + KISWebSocket + ChronosPredictor + SignalDedup 초기화). poll_loop task 생성 |
|
||||
| `config.py` | Settings dataclass — 환경변수 로드. Phase 4 추가 6 필드: `stop_loss_pct`, `take_profit_pct`, `chronos_spread_threshold`, `asking_bid_ratio_threshold`, `confidence_threshold`, `min_momentum_for_buy` |
|
||||
| `state.py` | PollState (process-wide singleton) — portfolio, screener_preview, news_sentiment, chronos_predictions, minute_bars, asking_price, **signals** (Phase 4) |
|
||||
| `stock_client.py` | NAS stock 백엔드 pull (X-WebAI-Key + 메모리 cache 60s/300s/60s + retry) |
|
||||
| `kis_client.py` | KIS REST 분봉/호가 — V1 토큰 read-only 공유 (mtime cache) + 초당 2회 throttle + 지수 backoff |
|
||||
| `kis_websocket.py` | KIS WebSocket H0STASP0 호가 + approval_key + 재연결 (1→2→4→max 30s) |
|
||||
| `chronos_predictor.py` | `amazon/chronos-bolt-base` zero-shot quantile (FP32 강제 — FP16 overflow 회피) |
|
||||
| `minute_momentum.py` | 5분봉 → strong_up/weak_up/neutral/weak_down/strong_down 5단계 분류 |
|
||||
| `signal_generator.py` | **Phase 4 — 매수/매도 룰 엔진**. `generate_signals(state, dedup, settings)` 진입. sell-first → buy 순서. 신호 emit/skip INFO/DEBUG 로그 |
|
||||
| `pull_worker.py` | asyncio cron — 장전 5분 / 장중 1분 / 장후 5분 / NXT / dead zone skip. cycle 끝에 `generate_signals` 호출 |
|
||||
| `scheduler.py` | polling window 판정 (KST 캘린더 + 휴장일) |
|
||||
| `rate_limit.py` | 초당 N회 token bucket |
|
||||
| `dedup.py` | SignalDedup SQLite WAL — `(ticker, action)` PK 24h |
|
||||
| `tests/` | 56 tests (pytest + respx HTTP mock + monkeypatch) |
|
||||
| `data/` | dedup.db (SQLite WAL) + `holidays.json` (NAS stock 에서 manual copy) |
|
||||
| `start.bat` | V2 진입 |
|
||||
|
||||
---
|
||||
|
||||
## 신호 룰 요약 (Phase 4)
|
||||
|
||||
### 매수 (screener Top-N + portfolio, sell 신호 받은 종목은 skip)
|
||||
모두 충족:
|
||||
1. `chronos.median > 0`
|
||||
2. **`chronos.q90 - chronos.q10 < 0.6`** (absolute spread — 2026-05-17 spec amend, 기존 relative formula 가 zero-shot median≈0 빈번에서 모든 신호 거부)
|
||||
3. `minute_momentum == strong_up` (env 로 조정 가능)
|
||||
4. `asking_price.bid_ratio >= 0.6`
|
||||
|
||||
종합 confidence = `chronos_conf * 0.5 + minute_score * 0.3 + screener_norm * 0.2`. `> 0.7` 시 emit.
|
||||
|
||||
### 매도 (portfolio only, 우선순위 stop_loss → anomaly → take_profit)
|
||||
- **stop_loss**: `pnl_pct < -7%` 즉시 (confidence=1.0)
|
||||
- **anomaly**: `chronos.median < -1%` + `strong_down` + `bid_ratio < 0.4` + 종합 conf > 0.7
|
||||
- **take_profit**: `pnl_pct > 15%` 검토 (confidence=0.6)
|
||||
|
||||
---
|
||||
|
||||
## 알려진 함정 / Phase 7 백로그
|
||||
|
||||
1. **KIS rate limit (EGW00201)** — V1+V2 동시 실행 시 충돌. Phase 6 자연 해소
|
||||
2. **`.venv` 한글 경로 깨짐** — 시스템 Python 사용
|
||||
3. **Chronos FP16 overflow** — 한국 주가 5만+ 시 inf. FP32 강제 (`chronos_predictor.py:39-41`)
|
||||
4. **`predict_quantiles` positional `inputs`** — ChronosBolt API 새 변경. `try/except TypeError` fallback 처리됨
|
||||
5. **`state.signals` consumer-drain protocol 미정의** — Phase 5 prereq. dict 무한 누적 위험 (실제로는 bounded by unique ticker count)
|
||||
6. **integration test 가 poll_loop 실제 호출 안 함** — `test_pull_worker.py:test_poll_loop_calls_generate_signals_after_cycle` 가 `generate_signals` 직접 호출. Phase 7 hardening 시 mock-iteration 으로 강화
|
||||
7. **KIS WebSocket URL `ws://ops.koreainvestment.com:21000/31000`** — 첫 운영 시 실제 KIS API docs 와 대조 필요
|
||||
8. **`_parse_asking_price` 필드 인덱스** — 마지막 2 필드 가정. 실 운영 raw 메시지 캡처 후 매핑 검증 필요
|
||||
9. **`holidays.json` 자동 동기화 부재** — NAS stock 의 `holidays.json` 을 수동 copy
|
||||
10. **schema rename** — Phase 0 §5.2 의 `lstm_pred_*`, `news_top[]` 는 `chronos_pred_*`, `news_reason(string)` 으로 변경됨. Phase 5 prompt 작성 시 반영
|
||||
11. **6개 env 필드가 `.env` 에 미기재** — 기본값으로 동작 가능하나 discoverability 위해 `.env.example` 또는 commented block 추가 권장
|
||||
|
||||
---
|
||||
|
||||
## 다음 단계 (Phase 5 진입 시 brainstorming 주제)
|
||||
|
||||
- `state.signals` consumer 패턴: pop vs leave + Phase 5 자체 dedup
|
||||
- agent-office 의 `/signal` endpoint 설계 — POST 페이로드 schema
|
||||
- Ollama Qwen3 14B Q4 로컬 호출 — 타임아웃, retry, VRAM 공존 (Chronos + Qwen3 동시 메모리 9.3GB / 15.5GB 가용)
|
||||
- 이중 텔레그램 (본인 풀 / 아내 lite) — context augmentation 단일 호출에서 양쪽 메시지 생성
|
||||
- LLM 비용: ₩0 목표 유지 (로컬)
|
||||
|
||||
---
|
||||
|
||||
## 양쪽 디렉토리 (web-ui ↔ web-ai) 작업 시 주의
|
||||
|
||||
- **코드**: ai_trade 는 web-ai/, spec/plan/메모리는 web-ui/
|
||||
- **커밋**: `web-ai` 와 `web-ui` 는 **별도 Gitea 저장소**. 각각 경로에서만 `git add/commit/push`
|
||||
- **메모리**: Claude Code 의 auto-memory 는 디렉토리별 격리. 핵심 reference 는 양쪽에 미러됨 (`./memory-mirror/` 또는 `~/.claude/projects/C--Users-jaeoh-Desktop-workspace-web-ai/memory/`)
|
||||
- **spec amendment 발생 시**: 코드는 `web-ai` 에 commit, spec 갱신은 `web-ui/docs/superpowers/specs/` 에 commit (Phase 4 spread formula 변경 사례 = web-ui commit `534ded5`)
|
||||
|
||||
자세한 V1 가이드는 `signal_v1/CLAUDE.md` 참조 (있다면).
|
||||
|
||||
0
ai_trade/__init__.py
Normal file
0
ai_trade/__init__.py
Normal file
132
ai_trade/chronos_predictor.py
Normal file
132
ai_trade/chronos_predictor.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Chronos-2 zero-shot forecaster wrapper."""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronosPrediction:
|
||||
median: float
|
||||
q10: float
|
||||
q90: float
|
||||
conf: float
|
||||
as_of: str
|
||||
|
||||
|
||||
class ChronosPredictor:
|
||||
"""HuggingFace Chronos-2 zero-shot forecaster."""
|
||||
|
||||
def __init__(self, model_name: str = "amazon/chronos-2", device: str | None = None):
|
||||
# BaseChronosPipeline auto-detects model variant (Chronos / ChronosBolt / Chronos-2)
|
||||
# and returns the appropriate sub-pipeline. ChronosPipeline only supports legacy T5.
|
||||
import torch
|
||||
try:
|
||||
from chronos import BaseChronosPipeline
|
||||
pipeline_cls = BaseChronosPipeline
|
||||
except ImportError:
|
||||
from chronos import ChronosPipeline
|
||||
pipeline_cls = ChronosPipeline
|
||||
|
||||
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# Always use float32 — Korean stock prices (e.g. 280,000원) exceed FP16 max (~65,504)
|
||||
# causing inf in quantile output. FP32 is safe for typical price magnitudes.
|
||||
dtype = torch.float32
|
||||
logger.info("Loading Chronos pipeline: %s on %s (cls=%s)",
|
||||
model_name, self._device, pipeline_cls.__name__)
|
||||
# Try `dtype` (newer API) first, fall back to `torch_dtype` (older)
|
||||
try:
|
||||
self._pipeline = pipeline_cls.from_pretrained(
|
||||
model_name, device_map=self._device, dtype=dtype,
|
||||
)
|
||||
except TypeError:
|
||||
self._pipeline = pipeline_cls.from_pretrained(
|
||||
model_name, device_map=self._device, torch_dtype=dtype,
|
||||
)
|
||||
logger.info("Chronos pipeline loaded.")
|
||||
|
||||
def predict_batch(
|
||||
self,
|
||||
daily_ohlcv_dict: dict[str, list[dict]],
|
||||
prediction_length: int = 1,
|
||||
num_samples: int = 100,
|
||||
) -> dict[str, ChronosPrediction]:
|
||||
"""종목별 1-day return 분포 예측.
|
||||
|
||||
ChronosBolt / Chronos-2 등 신모델은 predict_quantiles 사용 (deterministic).
|
||||
Legacy ChronosPipeline (T5) 는 sample-based predict.
|
||||
"""
|
||||
import torch
|
||||
|
||||
tickers = list(daily_ohlcv_dict.keys())
|
||||
if not tickers:
|
||||
return {}
|
||||
|
||||
contexts = [
|
||||
torch.tensor([bar["close"] for bar in daily_ohlcv_dict[t]], dtype=torch.float32)
|
||||
for t in tickers
|
||||
]
|
||||
now_iso = datetime.now(KST).isoformat()
|
||||
results: dict[str, ChronosPrediction] = {}
|
||||
|
||||
# Modern API: predict_quantiles (ChronosBolt / Chronos-2)
|
||||
if hasattr(self._pipeline, "predict_quantiles"):
|
||||
quantile_levels = [0.1, 0.5, 0.9]
|
||||
# ChronosBolt API: positional `inputs` (first arg). Older variants use `context`.
|
||||
try:
|
||||
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
||||
contexts,
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
except TypeError:
|
||||
quantiles_tensor, _ = self._pipeline.predict_quantiles(
|
||||
context=contexts,
|
||||
prediction_length=prediction_length,
|
||||
quantile_levels=quantile_levels,
|
||||
)
|
||||
quantiles_np = (
|
||||
quantiles_tensor.cpu().numpy()
|
||||
if hasattr(quantiles_tensor, "cpu")
|
||||
else np.asarray(quantiles_tensor)
|
||||
)
|
||||
# shape: [num_series, prediction_length, 3]
|
||||
for i, ticker in enumerate(tickers):
|
||||
q10_price, q50_price, q90_price = quantiles_np[i, 0, :]
|
||||
last_close = daily_ohlcv_dict[ticker][-1]["close"]
|
||||
median = float((q50_price - last_close) / last_close)
|
||||
q10 = float((q10_price - last_close) / last_close)
|
||||
q90 = float((q90_price - last_close) / last_close)
|
||||
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||
results[ticker] = ChronosPrediction(
|
||||
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||
)
|
||||
return results
|
||||
|
||||
# Legacy API: sample-based predict (ChronosPipeline T5)
|
||||
forecasts = self._pipeline.predict(
|
||||
context=contexts,
|
||||
prediction_length=prediction_length,
|
||||
num_samples=num_samples,
|
||||
)
|
||||
forecasts_np = forecasts.numpy() if hasattr(forecasts, "numpy") else np.asarray(forecasts)
|
||||
for i, ticker in enumerate(tickers):
|
||||
samples = forecasts_np[i, :, 0]
|
||||
last_close = daily_ohlcv_dict[ticker][-1]["close"]
|
||||
returns = (samples - last_close) / last_close
|
||||
median = float(np.quantile(returns, 0.5))
|
||||
q10 = float(np.quantile(returns, 0.1))
|
||||
q90 = float(np.quantile(returns, 0.9))
|
||||
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||
results[ticker] = ChronosPrediction(
|
||||
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||
)
|
||||
return results
|
||||
75
ai_trade/config.py
Normal file
75
ai_trade/config.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Signal V2 환경변수 로딩."""
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(Path(__file__).parent.parent / ".env")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Settings:
|
||||
stock_api_url: str = field(
|
||||
default_factory=lambda: os.getenv("STOCK_API_URL", "").rstrip("/")
|
||||
)
|
||||
webai_api_key: str = field(
|
||||
default_factory=lambda: os.getenv("WEBAI_API_KEY", "").strip()
|
||||
)
|
||||
port: int = field(default_factory=lambda: int(os.getenv("SIGNAL_V2_PORT", "8001")))
|
||||
db_path: Path = field(
|
||||
default_factory=lambda: Path(__file__).parent / "data" / "ai_trade.db"
|
||||
)
|
||||
# KIS — V1 호환 패턴 (KIS_ENV_TYPE virtual/real)
|
||||
kis_env_type: str = field(default_factory=lambda: os.getenv("KIS_ENV_TYPE", "virtual").lower())
|
||||
kis_real_app_key: str = field(default_factory=lambda: os.getenv("KIS_REAL_APP_KEY", "").strip())
|
||||
kis_real_app_secret: str = field(default_factory=lambda: os.getenv("KIS_REAL_APP_SECRET", "").strip())
|
||||
kis_real_account: str = field(default_factory=lambda: os.getenv("KIS_REAL_ACCOUNT", "").strip())
|
||||
kis_virtual_app_key: str = field(default_factory=lambda: os.getenv("KIS_VIRTUAL_APP_KEY", "").strip())
|
||||
kis_virtual_app_secret: str = field(default_factory=lambda: os.getenv("KIS_VIRTUAL_APP_SECRET", "").strip())
|
||||
kis_virtual_account: str = field(default_factory=lambda: os.getenv("KIS_VIRTUAL_ACCOUNT", "").strip())
|
||||
v1_token_path: Path = field(
|
||||
default_factory=lambda: Path(
|
||||
os.getenv("V1_TOKEN_PATH",
|
||||
str(Path(__file__).parent.parent / "signal_v1" / "data" / "kis_token.json"))
|
||||
)
|
||||
)
|
||||
chronos_model: str = field(default_factory=lambda: os.getenv("CHRONOS_MODEL", "amazon/chronos-2"))
|
||||
stop_loss_pct: float = field(
|
||||
default_factory=lambda: float(os.getenv("STOP_LOSS_PCT", "-0.07"))
|
||||
)
|
||||
take_profit_pct: float = field(
|
||||
default_factory=lambda: float(os.getenv("TAKE_PROFIT_PCT", "0.15"))
|
||||
)
|
||||
chronos_spread_threshold: float = field(
|
||||
default_factory=lambda: float(os.getenv("CHRONOS_SPREAD_THRESHOLD", "0.6"))
|
||||
)
|
||||
asking_bid_ratio_threshold: float = field(
|
||||
default_factory=lambda: float(os.getenv("ASKING_BID_RATIO_THRESHOLD", "0.6"))
|
||||
)
|
||||
confidence_threshold: float = field(
|
||||
default_factory=lambda: float(os.getenv("CONFIDENCE_THRESHOLD", "0.7"))
|
||||
)
|
||||
min_momentum_for_buy: str = field(
|
||||
default_factory=lambda: os.getenv("MIN_MOMENTUM_FOR_BUY", "strong_up")
|
||||
)
|
||||
|
||||
@property
|
||||
def kis_is_virtual(self) -> bool:
|
||||
return self.kis_env_type != "real"
|
||||
|
||||
@property
|
||||
def kis_app_key(self) -> str:
|
||||
return self.kis_real_app_key if self.kis_env_type == "real" else self.kis_virtual_app_key
|
||||
|
||||
@property
|
||||
def kis_app_secret(self) -> str:
|
||||
return self.kis_real_app_secret if self.kis_env_type == "real" else self.kis_virtual_app_secret
|
||||
|
||||
@property
|
||||
def kis_account(self) -> str:
|
||||
return self.kis_real_account if self.kis_env_type == "real" else self.kis_virtual_account
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
0
ai_trade/data/.gitkeep
Normal file
0
ai_trade/data/.gitkeep
Normal file
18
ai_trade/holidays.json
Normal file
18
ai_trade/holidays.json
Normal file
@@ -0,0 +1,18 @@
|
||||
[
|
||||
"2026-01-01",
|
||||
"2026-01-28",
|
||||
"2026-01-29",
|
||||
"2026-01-30",
|
||||
"2026-03-01",
|
||||
"2026-05-05",
|
||||
"2026-05-25",
|
||||
"2026-06-06",
|
||||
"2026-08-15",
|
||||
"2026-09-24",
|
||||
"2026-09-25",
|
||||
"2026-09-26",
|
||||
"2026-10-03",
|
||||
"2026-10-09",
|
||||
"2026-12-25",
|
||||
"2026-12-31"
|
||||
]
|
||||
193
ai_trade/kis_client.py
Normal file
193
ai_trade/kis_client.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""KIS REST API client — 분봉 + 호가. V1 토큰 read-only 공유."""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
_MAX_ATTEMPTS = 3
|
||||
_THROTTLE_INTERVAL = 0.5 # 초당 2회 제한
|
||||
|
||||
|
||||
class KISClient:
|
||||
"""KIS REST (분봉 + 호가). V1 토큰 파일 read-only."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_key: str, app_secret: str, account: str, is_virtual: bool,
|
||||
v1_token_path: Path,
|
||||
timeout: float = 10.0,
|
||||
):
|
||||
self._app_key = app_key
|
||||
self._app_secret = app_secret
|
||||
self._account = account
|
||||
self._is_virtual = is_virtual
|
||||
self._v1_token_path = Path(v1_token_path)
|
||||
self._base_url = (
|
||||
"https://openapivts.koreainvestment.com:29443" if is_virtual
|
||||
else "https://openapi.koreainvestment.com:9443"
|
||||
)
|
||||
self._client = httpx.AsyncClient(timeout=timeout)
|
||||
self._token_cache: tuple[str, float] | None = None # (token, file_mtime)
|
||||
self._last_throttle_at = 0.0
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
|
||||
def _read_v1_token(self) -> str:
|
||||
if not self._v1_token_path.exists():
|
||||
raise RuntimeError(f"V1 token file missing: {self._v1_token_path}")
|
||||
mtime = self._v1_token_path.stat().st_mtime
|
||||
if self._token_cache and self._token_cache[1] == mtime:
|
||||
return self._token_cache[0]
|
||||
data = json.loads(self._v1_token_path.read_text(encoding="utf-8"))
|
||||
token = data.get("access_token", "")
|
||||
if not token:
|
||||
raise RuntimeError("V1 token file has no access_token")
|
||||
self._token_cache = (token, mtime)
|
||||
return token
|
||||
|
||||
async def _throttle(self) -> None:
|
||||
elapsed = time.monotonic() - self._last_throttle_at
|
||||
if elapsed < _THROTTLE_INTERVAL:
|
||||
await asyncio.sleep(_THROTTLE_INTERVAL - elapsed)
|
||||
self._last_throttle_at = time.monotonic()
|
||||
|
||||
def _common_headers(self, tr_id: str) -> dict[str, str]:
|
||||
token = self._read_v1_token()
|
||||
return {
|
||||
"authorization": f"Bearer {token}",
|
||||
"appkey": self._app_key,
|
||||
"appsecret": self._app_secret,
|
||||
"tr_id": tr_id,
|
||||
"custtype": "P",
|
||||
}
|
||||
|
||||
async def _request_with_retry(
|
||||
self, method: str, path: str, tr_id: str, **kwargs,
|
||||
) -> dict:
|
||||
url = f"{self._base_url}{path}"
|
||||
headers = self._common_headers(tr_id)
|
||||
for attempt in range(_MAX_ATTEMPTS):
|
||||
await self._throttle()
|
||||
try:
|
||||
response = await self._client.request(
|
||||
method, url, headers=headers, **kwargs
|
||||
)
|
||||
if response.status_code == 429:
|
||||
if attempt < _MAX_ATTEMPTS - 1:
|
||||
await asyncio.sleep(2**attempt)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.TimeoutException:
|
||||
if attempt < _MAX_ATTEMPTS - 1:
|
||||
await asyncio.sleep(2**attempt)
|
||||
continue
|
||||
raise
|
||||
raise RuntimeError("retry exhausted")
|
||||
|
||||
async def get_minute_ohlcv(self, ticker: str) -> list[dict]:
|
||||
"""현재 시점 직전 30개 1분봉 OHLCV (TR_ID FHKST03010200)."""
|
||||
path = "/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||
params = {
|
||||
"FID_ETC_CLS_CODE": "",
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_INPUT_ISCD": ticker,
|
||||
"FID_INPUT_HOUR_1": datetime.now(KST).strftime("%H%M%S"),
|
||||
"FID_PW_DATA_INCU_YN": "N",
|
||||
}
|
||||
raw = await self._request_with_retry(
|
||||
"GET", path, tr_id="FHKST03010200", params=params,
|
||||
)
|
||||
output2 = raw.get("output2", [])
|
||||
bars = []
|
||||
for row in output2:
|
||||
try:
|
||||
date = row["stck_bsop_date"]
|
||||
hhmmss = row["stck_cntg_hour"]
|
||||
dt = datetime.strptime(f"{date} {hhmmss}", "%Y%m%d %H%M%S").replace(tzinfo=KST)
|
||||
bars.append({
|
||||
"datetime": dt.isoformat(),
|
||||
"open": int(row["stck_oprc"]),
|
||||
"high": int(row["stck_hgpr"]),
|
||||
"low": int(row["stck_lwpr"]),
|
||||
"close": int(row["stck_prpr"]),
|
||||
"volume": int(row["cntg_vol"]),
|
||||
})
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.warning("skip malformed bar for %s: %r", ticker, e)
|
||||
# KIS returns descending; reverse to ascending (most recent last)
|
||||
bars.reverse()
|
||||
return bars
|
||||
|
||||
async def get_asking_price(self, ticker: str) -> dict:
|
||||
"""현재 호가 + 매수/매도 잔량 (TR_ID FHKST01010200)."""
|
||||
path = "/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
|
||||
params = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_INPUT_ISCD": ticker,
|
||||
}
|
||||
raw = await self._request_with_retry(
|
||||
"GET", path, tr_id="FHKST01010200", params=params,
|
||||
)
|
||||
output1 = raw.get("output1", {})
|
||||
bid_total = int(output1.get("total_bidp_rsqn", 0))
|
||||
ask_total = int(output1.get("total_askp_rsqn", 0))
|
||||
total = bid_total + ask_total
|
||||
bid_ratio = bid_total / total if total > 0 else 0.0
|
||||
current_price = int(output1.get("stck_prpr", 0))
|
||||
return {
|
||||
"bid_total": bid_total,
|
||||
"ask_total": ask_total,
|
||||
"bid_ratio": bid_ratio,
|
||||
"current_price": current_price,
|
||||
"as_of": datetime.now(KST).isoformat(),
|
||||
}
|
||||
|
||||
async def get_daily_ohlcv(self, ticker: str, days: int = 60) -> list[dict]:
|
||||
"""KRX 일봉 OHLCV (TR_ID FHKST03010100).
|
||||
|
||||
Returns: [{"datetime", "open", "high", "low", "close", "volume"}, ...]
|
||||
시간 오름차순.
|
||||
"""
|
||||
path = "/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
|
||||
today = datetime.now(KST).strftime("%Y%m%d")
|
||||
start_date = (datetime.now(KST) - timedelta(days=days * 2)).strftime("%Y%m%d")
|
||||
params = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_INPUT_ISCD": ticker,
|
||||
"FID_INPUT_DATE_1": start_date,
|
||||
"FID_INPUT_DATE_2": today,
|
||||
"FID_PERIOD_DIV_CODE": "D",
|
||||
"FID_ORG_ADJ_PRC": "1",
|
||||
}
|
||||
raw = await self._request_with_retry(
|
||||
"GET", path, tr_id="FHKST03010100", params=params,
|
||||
)
|
||||
output2 = raw.get("output2", [])
|
||||
bars = []
|
||||
for row in output2:
|
||||
try:
|
||||
date = row["stck_bsop_date"]
|
||||
bars.append({
|
||||
"datetime": f"{date[:4]}-{date[4:6]}-{date[6:]}",
|
||||
"open": int(row["stck_oprc"]),
|
||||
"high": int(row["stck_hgpr"]),
|
||||
"low": int(row["stck_lwpr"]),
|
||||
"close": int(row["stck_clpr"]),
|
||||
"volume": int(row["acml_vol"]),
|
||||
})
|
||||
except (KeyError, ValueError):
|
||||
continue
|
||||
bars.reverse()
|
||||
return bars[-days:]
|
||||
186
ai_trade/kis_websocket.py
Normal file
186
ai_trade/kis_websocket.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""KIS WebSocket — approval_key + 실시간 호가 구독."""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Callable
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
# KIS 호가 메시지 필드 인덱스 (운영 환경 검증 필요)
|
||||
# H0STASP0 응답: ticker | time | current_price | ... | ask_total | bid_total
|
||||
# 본 spec/plan 의 가정: 마지막 2개 필드가 ask_total / bid_total
|
||||
_ASKING_TICKER_IDX = 0
|
||||
_ASKING_TIME_IDX = 1
|
||||
_ASKING_CURRENT_PRICE_IDX = 2
|
||||
_ASKING_TOTAL_ASK_IDX = -2
|
||||
_ASKING_TOTAL_BID_IDX = -1
|
||||
|
||||
|
||||
class KISWebSocket:
|
||||
"""KIS WebSocket client. approval_key 발급 + 호가 실시간."""
|
||||
|
||||
def __init__(self, app_key: str, app_secret: str, is_virtual: bool):
|
||||
self._app_key = app_key
|
||||
self._app_secret = app_secret
|
||||
self._is_virtual = is_virtual
|
||||
self._base_rest = (
|
||||
"https://openapivts.koreainvestment.com:29443" if is_virtual
|
||||
else "https://openapi.koreainvestment.com:9443"
|
||||
)
|
||||
self._ws_url = (
|
||||
"ws://ops.koreainvestment.com:31000" if is_virtual
|
||||
else "ws://ops.koreainvestment.com:21000"
|
||||
)
|
||||
self._approval_key: str | None = None
|
||||
self._ws = None
|
||||
self._subscriptions: set[str] = set()
|
||||
self._on_asking_price: Callable[[str, dict], None] | None = None
|
||||
self._recv_task: asyncio.Task | None = None
|
||||
self._shutdown = asyncio.Event()
|
||||
|
||||
async def _fetch_approval_key(self) -> str:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
f"{self._base_rest}/oauth2/Approval",
|
||||
json={
|
||||
"grant_type": "client_credentials",
|
||||
"appkey": self._app_key,
|
||||
"secretkey": self._app_secret,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
self._approval_key = data["approval_key"]
|
||||
return self._approval_key
|
||||
|
||||
async def _connect(self):
|
||||
return await websockets.connect(self._ws_url)
|
||||
|
||||
async def _connect_with_backoff(self):
|
||||
"""연결 시도 with exponential backoff (1s → 2s → 4s → max 30s)."""
|
||||
for attempt in range(10):
|
||||
try:
|
||||
ws = await self._connect()
|
||||
return ws
|
||||
except Exception as e:
|
||||
wait = min(2**attempt, 30)
|
||||
logger.warning(
|
||||
"KIS WebSocket connect failed (attempt %d): %r — retrying in %ds",
|
||||
attempt + 1, e, wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
raise RuntimeError("KIS WebSocket connect exhausted retries")
|
||||
|
||||
async def start(
|
||||
self, tickers: list[str],
|
||||
on_asking_price: Callable[[str, dict], None],
|
||||
) -> None:
|
||||
if self._approval_key is None:
|
||||
await self._fetch_approval_key()
|
||||
self._on_asking_price = on_asking_price
|
||||
self._ws = await self._connect_with_backoff()
|
||||
for ticker in tickers:
|
||||
await self.subscribe(ticker)
|
||||
self._recv_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def subscribe(self, ticker: str) -> None:
|
||||
if self._ws is None or self._approval_key is None:
|
||||
raise RuntimeError("KIS WebSocket not started")
|
||||
msg = json.dumps({
|
||||
"header": {
|
||||
"approval_key": self._approval_key,
|
||||
"custtype": "P",
|
||||
"tr_type": "1",
|
||||
"content-type": "utf-8",
|
||||
},
|
||||
"body": {
|
||||
"input": {"tr_id": "H0STASP0", "tr_key": ticker},
|
||||
},
|
||||
})
|
||||
await self._ws.send(msg)
|
||||
self._subscriptions.add(ticker)
|
||||
|
||||
async def unsubscribe(self, ticker: str) -> None:
|
||||
if self._ws is None or self._approval_key is None:
|
||||
return
|
||||
msg = json.dumps({
|
||||
"header": {
|
||||
"approval_key": self._approval_key,
|
||||
"custtype": "P",
|
||||
"tr_type": "2",
|
||||
"content-type": "utf-8",
|
||||
},
|
||||
"body": {
|
||||
"input": {"tr_id": "H0STASP0", "tr_key": ticker},
|
||||
},
|
||||
})
|
||||
await self._ws.send(msg)
|
||||
self._subscriptions.discard(ticker)
|
||||
|
||||
async def close(self) -> None:
|
||||
self._shutdown.set()
|
||||
if self._recv_task is not None:
|
||||
self._recv_task.cancel()
|
||||
try:
|
||||
await self._recv_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
while not self._shutdown.is_set():
|
||||
try:
|
||||
raw = await self._ws.recv()
|
||||
except websockets.ConnectionClosed:
|
||||
logger.warning("KIS WebSocket closed — reconnecting")
|
||||
self._ws = await self._connect_with_backoff()
|
||||
for ticker in list(self._subscriptions):
|
||||
await self.subscribe(ticker)
|
||||
continue
|
||||
if not isinstance(raw, str):
|
||||
continue
|
||||
parsed = self._parse_asking_price(raw)
|
||||
if parsed is not None and self._on_asking_price is not None:
|
||||
ticker, data = parsed
|
||||
try:
|
||||
self._on_asking_price(ticker, data)
|
||||
except Exception:
|
||||
logger.exception("on_asking_price callback failed")
|
||||
|
||||
def _parse_asking_price(self, raw: str) -> tuple[str, dict] | None:
|
||||
"""KIS H0STASP0 raw → (ticker, asking_price dict).
|
||||
|
||||
Raw format: '0|H0STASP0|<count>|<data>' where data = '^'-joined fields.
|
||||
Field indices (운영 검증 필요): 마지막 2개 가정 (ask, bid).
|
||||
"""
|
||||
try:
|
||||
parts = raw.split("|")
|
||||
if len(parts) < 4 or parts[1] != "H0STASP0":
|
||||
return None
|
||||
fields = parts[3].split("^")
|
||||
ticker = fields[_ASKING_TICKER_IDX]
|
||||
current_price_str = fields[_ASKING_CURRENT_PRICE_IDX]
|
||||
current_price = int(current_price_str) if current_price_str.lstrip("-").isdigit() else 0
|
||||
ask_str = fields[_ASKING_TOTAL_ASK_IDX]
|
||||
bid_str = fields[_ASKING_TOTAL_BID_IDX]
|
||||
ask_total = int(ask_str) if ask_str.lstrip("-").isdigit() else 0
|
||||
bid_total = int(bid_str) if bid_str.lstrip("-").isdigit() else 0
|
||||
total = bid_total + ask_total
|
||||
return ticker, {
|
||||
"bid_total": bid_total,
|
||||
"ask_total": ask_total,
|
||||
"bid_ratio": bid_total / total if total > 0 else 0.0,
|
||||
"current_price": current_price,
|
||||
"as_of": datetime.now(KST).isoformat(),
|
||||
}
|
||||
except (IndexError, ValueError) as e:
|
||||
logger.warning("parse_asking_price failed: %r", e)
|
||||
return None
|
||||
125
ai_trade/main.py
Normal file
125
ai_trade/main.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""FastAPI app — Signal V2 Pull Worker."""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from ai_trade import state as state_mod
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
from ai_trade.config import get_settings
|
||||
from ai_trade.kis_client import KISClient
|
||||
from ai_trade.kis_websocket import KISWebSocket
|
||||
from ai_trade.pull_worker import poll_loop, make_asking_price_callback
|
||||
from ai_trade.rate_limit import SignalDedup
|
||||
from ai_trade.stock_client import StockClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppContext:
|
||||
client: StockClient | None = None
|
||||
dedup: SignalDedup | None = None
|
||||
shutdown: asyncio.Event | None = None
|
||||
poll_task: asyncio.Task | None = None
|
||||
kis_client: KISClient | None = None
|
||||
kis_ws: KISWebSocket | None = None
|
||||
chronos: ChronosPredictor | None = None
|
||||
|
||||
|
||||
_ctx = AppContext()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
settings = get_settings()
|
||||
if not settings.webai_api_key:
|
||||
logger.warning(
|
||||
"WEBAI_API_KEY not configured — stock API calls will fail with 401"
|
||||
)
|
||||
if not settings.kis_app_key:
|
||||
logger.warning(
|
||||
"KIS app_key not configured (KIS_ENV_TYPE=%s, KIS_%s_APP_KEY missing) — KIS REST/WebSocket disabled",
|
||||
settings.kis_env_type, settings.kis_env_type.upper()
|
||||
)
|
||||
|
||||
_ctx.client = StockClient(settings.stock_api_url, settings.webai_api_key)
|
||||
_ctx.dedup = SignalDedup(settings.db_path)
|
||||
_ctx.shutdown = asyncio.Event()
|
||||
|
||||
# KIS only if app_key configured
|
||||
if settings.kis_app_key:
|
||||
_ctx.kis_client = KISClient(
|
||||
app_key=settings.kis_app_key,
|
||||
app_secret=settings.kis_app_secret,
|
||||
account=settings.kis_account,
|
||||
is_virtual=settings.kis_is_virtual,
|
||||
v1_token_path=settings.v1_token_path,
|
||||
)
|
||||
_ctx.kis_ws = KISWebSocket(
|
||||
app_key=settings.kis_app_key,
|
||||
app_secret=settings.kis_app_secret,
|
||||
is_virtual=settings.kis_is_virtual,
|
||||
)
|
||||
# Subscribe portfolio holdings (if any)
|
||||
try:
|
||||
portfolio = await _ctx.client.get_portfolio()
|
||||
tickers = [h["ticker"] for h in portfolio.get("holdings", []) if "ticker" in h]
|
||||
cb = make_asking_price_callback(state_mod.state)
|
||||
await _ctx.kis_ws.start(tickers, cb)
|
||||
except Exception:
|
||||
logger.exception("KIS WebSocket startup failed — continuing without realtime asking_price")
|
||||
|
||||
# Load Chronos (heavy: ~1GB model download first time)
|
||||
try:
|
||||
_ctx.chronos = ChronosPredictor(model_name=settings.chronos_model)
|
||||
except Exception:
|
||||
logger.exception("ChronosPredictor load failed — continuing without chronos predictions")
|
||||
|
||||
_ctx.poll_task = asyncio.create_task(
|
||||
poll_loop(
|
||||
_ctx.client, state_mod.state, _ctx.shutdown,
|
||||
kis_client=_ctx.kis_client,
|
||||
chronos=_ctx.chronos,
|
||||
dedup=_ctx.dedup,
|
||||
settings=settings,
|
||||
)
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
if _ctx.shutdown is not None:
|
||||
_ctx.shutdown.set()
|
||||
if _ctx.poll_task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(_ctx.poll_task, timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
_ctx.poll_task.cancel()
|
||||
try:
|
||||
await _ctx.poll_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if _ctx.kis_ws is not None:
|
||||
await _ctx.kis_ws.close()
|
||||
if _ctx.kis_client is not None:
|
||||
await _ctx.kis_client.close()
|
||||
if _ctx.client is not None:
|
||||
await _ctx.client.close()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Signal V2 Pull Worker", version="0.1.0", lifespan=lifespan
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
settings = get_settings()
|
||||
return {
|
||||
"status": "online",
|
||||
"stock_api_url": settings.stock_api_url,
|
||||
"last_poll": state_mod.state.last_updated,
|
||||
"cache_size": _ctx.client.cache_size() if _ctx.client is not None else 0,
|
||||
}
|
||||
69
ai_trade/momentum_classifier.py
Normal file
69
ai_trade/momentum_classifier.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""분봉 OHLCV → 5-level 모멘텀 분류."""
|
||||
from __future__ import annotations
|
||||
from collections import deque
|
||||
|
||||
# 분류 카테고리
|
||||
STRONG_UP = "strong_up"
|
||||
WEAK_UP = "weak_up"
|
||||
NEUTRAL = "neutral"
|
||||
WEAK_DOWN = "weak_down"
|
||||
STRONG_DOWN = "strong_down"
|
||||
|
||||
_BARS_PER_5MIN = 5
|
||||
_LOOKBACK_5MIN_BARS = 5
|
||||
_VOLUME_AVG_WINDOW = 12 # 60분 = 5분봉 12개
|
||||
|
||||
|
||||
def aggregate_1min_to_5min(minute_bars: list[dict]) -> list[dict]:
|
||||
"""1분봉 N개 → 5분봉 floor(N/5) 개. 시간 오름차순.
|
||||
|
||||
각 5분봉: open=첫 1분봉 open, high=max, low=min, close=마지막 close, volume=sum.
|
||||
"""
|
||||
bars_5min = []
|
||||
chunks = len(minute_bars) // _BARS_PER_5MIN
|
||||
for i in range(chunks):
|
||||
chunk = minute_bars[i * _BARS_PER_5MIN : (i + 1) * _BARS_PER_5MIN]
|
||||
bars_5min.append({
|
||||
"datetime": chunk[0]["datetime"],
|
||||
"open": chunk[0]["open"],
|
||||
"high": max(b["high"] for b in chunk),
|
||||
"low": min(b["low"] for b in chunk),
|
||||
"close": chunk[-1]["close"],
|
||||
"volume": sum(b["volume"] for b in chunk),
|
||||
})
|
||||
return bars_5min
|
||||
|
||||
|
||||
def classify_minute_momentum(minute_bars: deque) -> str:
|
||||
"""1분봉 deque → 5-level 모멘텀 분류.
|
||||
|
||||
Returns: STRONG_UP / WEAK_UP / NEUTRAL / WEAK_DOWN / STRONG_DOWN
|
||||
"""
|
||||
minute_list = list(minute_bars)
|
||||
if len(minute_list) < _BARS_PER_5MIN * _LOOKBACK_5MIN_BARS:
|
||||
return NEUTRAL # 데이터 부족
|
||||
|
||||
bars_5min = aggregate_1min_to_5min(minute_list)
|
||||
if len(bars_5min) < _LOOKBACK_5MIN_BARS:
|
||||
return NEUTRAL
|
||||
|
||||
recent = bars_5min[-_LOOKBACK_5MIN_BARS:]
|
||||
up_count = sum(1 for b in recent if b["close"] > b["open"])
|
||||
|
||||
# 거래량 multiplier: recent 5 avg vs 60분 avg
|
||||
recent_vol_avg = sum(b["volume"] for b in recent) / len(recent)
|
||||
long_window = bars_5min[-_VOLUME_AVG_WINDOW:]
|
||||
long_vol_avg = sum(b["volume"] for b in long_window) / len(long_window)
|
||||
vol_mult = recent_vol_avg / long_vol_avg if long_vol_avg > 0 else 1.0
|
||||
|
||||
# 5-level 분류
|
||||
if up_count == 5 and vol_mult >= 1.5:
|
||||
return STRONG_UP
|
||||
elif up_count >= 3 and vol_mult >= 1.0:
|
||||
return WEAK_UP
|
||||
elif up_count == 0 and vol_mult >= 1.5:
|
||||
return STRONG_DOWN
|
||||
elif up_count <= 2 and vol_mult < 1.0:
|
||||
return WEAK_DOWN
|
||||
else:
|
||||
return NEUTRAL
|
||||
193
ai_trade/pull_worker.py
Normal file
193
ai_trade/pull_worker.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Polling loop — async cron + state update."""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
|
||||
from ai_trade.kis_client import KISClient
|
||||
from ai_trade.scheduler import (
|
||||
KST, _is_market_day, _is_polling_window, _next_interval, _is_post_close_trigger,
|
||||
)
|
||||
from ai_trade.state import PollState
|
||||
from ai_trade.stock_client import StockClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def poll_loop(
|
||||
client: StockClient, state: PollState, shutdown: asyncio.Event,
|
||||
kis_client: KISClient | None = None,
|
||||
chronos=None,
|
||||
dedup=None,
|
||||
settings=None,
|
||||
) -> None:
|
||||
"""FastAPI lifespan 에서 asyncio.create_task 로 시작."""
|
||||
logger.info("poll_loop started")
|
||||
while not shutdown.is_set():
|
||||
now = datetime.now(KST)
|
||||
if _is_market_day(now) and _is_polling_window(now):
|
||||
try:
|
||||
await _run_polling_cycle(client, state, kis_client=kis_client)
|
||||
except Exception:
|
||||
logger.exception("poll cycle failed")
|
||||
# Minute momentum 갱신 (매 cycle)
|
||||
try:
|
||||
update_minute_momentum_for_all(state)
|
||||
except Exception:
|
||||
logger.exception("minute momentum update failed")
|
||||
# Post-close trigger (16:00 KST)
|
||||
if _is_post_close_trigger(now) and chronos is not None and kis_client is not None:
|
||||
try:
|
||||
await _run_post_close_cycle(kis_client, chronos, state)
|
||||
except Exception:
|
||||
logger.exception("post-close cycle failed")
|
||||
# Phase 4: generate signals
|
||||
if dedup is not None and settings is not None:
|
||||
try:
|
||||
from ai_trade.signal_generator import generate_signals
|
||||
generate_signals(state, dedup, settings)
|
||||
except Exception:
|
||||
logger.exception("generate_signals failed")
|
||||
interval = _next_interval(now)
|
||||
try:
|
||||
await asyncio.wait_for(shutdown.wait(), timeout=interval)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
logger.info("poll_loop ended")
|
||||
|
||||
|
||||
async def _run_polling_cycle(
|
||||
client: StockClient, state: PollState,
|
||||
kis_client: KISClient | None = None,
|
||||
) -> None:
|
||||
"""기존 3 endpoint (stock) + KIS 분봉 fetch."""
|
||||
portfolio, sentiment, screener = await asyncio.gather(
|
||||
client.get_portfolio(),
|
||||
client.get_news_sentiment(),
|
||||
client.run_screener_preview(),
|
||||
return_exceptions=True,
|
||||
)
|
||||
now_iso = datetime.now(KST).isoformat()
|
||||
|
||||
for name, result in (
|
||||
("portfolio", portfolio),
|
||||
("news_sentiment", sentiment),
|
||||
("screener_preview", screener),
|
||||
):
|
||||
if isinstance(result, dict):
|
||||
setattr(state, name, result)
|
||||
state.last_updated[name] = now_iso
|
||||
state.fetch_errors[name] = 0
|
||||
else:
|
||||
state.fetch_errors[name] = state.fetch_errors.get(name, 0) + 1
|
||||
logger.warning("fetch %s failed: %r", name, result)
|
||||
|
||||
# KIS 분봉 + 호가 (kis_client 주어졌을 때만)
|
||||
if kis_client is not None:
|
||||
try:
|
||||
await _run_kis_minute_cycle(kis_client, state)
|
||||
except Exception:
|
||||
logger.exception("kis minute cycle failed")
|
||||
|
||||
|
||||
async def _run_kis_minute_cycle(kis_client: KISClient, state: PollState) -> None:
|
||||
"""KIS 분봉 + 호가 fetch + state 갱신.
|
||||
|
||||
- 분봉: portfolio + screener Top-N union 종목 모두
|
||||
- 호가 (REST): screener-only 종목 (portfolio 는 WebSocket 으로 들어옴)
|
||||
"""
|
||||
portfolio_tickers = _portfolio_tickers(state)
|
||||
screener_tickers = _screener_tickers(state)
|
||||
all_tickers = list(set(portfolio_tickers) | set(screener_tickers))
|
||||
|
||||
# 분봉 fetch (병렬)
|
||||
minute_results = await asyncio.gather(*[
|
||||
kis_client.get_minute_ohlcv(t) for t in all_tickers
|
||||
], return_exceptions=True)
|
||||
now_iso = datetime.now(KST).isoformat()
|
||||
for ticker, result in zip(all_tickers, minute_results):
|
||||
if isinstance(result, list):
|
||||
buf = state.minute_bars.setdefault(ticker, deque(maxlen=60))
|
||||
buf.extend(result)
|
||||
state.last_updated[f"minute_bars/{ticker}"] = now_iso
|
||||
else:
|
||||
state.fetch_errors[f"minute_bars/{ticker}"] = (
|
||||
state.fetch_errors.get(f"minute_bars/{ticker}", 0) + 1
|
||||
)
|
||||
|
||||
# 호가 fetch (REST) — screener-only
|
||||
screener_only = list(set(screener_tickers) - set(portfolio_tickers))
|
||||
asking_results = await asyncio.gather(*[
|
||||
kis_client.get_asking_price(t) for t in screener_only
|
||||
], return_exceptions=True)
|
||||
for ticker, result in zip(screener_only, asking_results):
|
||||
if isinstance(result, dict):
|
||||
state.asking_price[ticker] = result
|
||||
state.last_updated[f"asking_price/{ticker}"] = now_iso
|
||||
|
||||
|
||||
def make_asking_price_callback(state: PollState):
|
||||
"""KIS WebSocket on_asking_price callback factory."""
|
||||
def _cb(ticker: str, data: dict) -> None:
|
||||
state.asking_price[ticker] = data
|
||||
state.last_updated[f"asking_price/{ticker}"] = datetime.now(KST).isoformat()
|
||||
return _cb
|
||||
|
||||
|
||||
def _portfolio_tickers(state: PollState) -> list[str]:
|
||||
if state.portfolio is None:
|
||||
return []
|
||||
return [h["ticker"] for h in state.portfolio.get("holdings", []) if "ticker" in h]
|
||||
|
||||
|
||||
def _screener_tickers(state: PollState) -> list[str]:
|
||||
if state.screener_preview is None:
|
||||
return []
|
||||
return [i["ticker"] for i in state.screener_preview.get("items", []) if "ticker" in i]
|
||||
|
||||
|
||||
async def _run_post_close_cycle(kis_client, chronos, state) -> None:
|
||||
"""16:00 KST 종가 후 1회: daily fetch + chronos predict."""
|
||||
tickers = list(set(_portfolio_tickers(state)) | set(_screener_tickers(state)))
|
||||
if not tickers:
|
||||
return
|
||||
|
||||
daily_results = await asyncio.gather(*[
|
||||
kis_client.get_daily_ohlcv(t, days=60) for t in tickers
|
||||
], return_exceptions=True)
|
||||
daily_dict = {}
|
||||
for ticker, result in zip(tickers, daily_results):
|
||||
if isinstance(result, list) and len(result) >= 30:
|
||||
daily_dict[ticker] = result
|
||||
state.daily_ohlcv[ticker] = result
|
||||
elif isinstance(result, Exception):
|
||||
state.fetch_errors[f"daily_ohlcv/{ticker}"] = (
|
||||
state.fetch_errors.get(f"daily_ohlcv/{ticker}", 0) + 1
|
||||
)
|
||||
|
||||
if daily_dict and chronos is not None:
|
||||
try:
|
||||
predictions = chronos.predict_batch(daily_dict)
|
||||
except Exception:
|
||||
logger.exception("chronos predict_batch failed")
|
||||
return
|
||||
for ticker, pred in predictions.items():
|
||||
state.chronos_predictions[ticker] = {
|
||||
"median": pred.median,
|
||||
"q10": pred.q10,
|
||||
"q90": pred.q90,
|
||||
"conf": pred.conf,
|
||||
"as_of": pred.as_of,
|
||||
}
|
||||
state.last_updated[f"chronos/{ticker}"] = pred.as_of
|
||||
|
||||
|
||||
def update_minute_momentum_for_all(state) -> None:
|
||||
"""매 분봉 cycle 후 호출 — 모든 종목 모멘텀 갱신."""
|
||||
from ai_trade.momentum_classifier import classify_minute_momentum
|
||||
now_iso = datetime.now(KST).isoformat()
|
||||
for ticker, bars in state.minute_bars.items():
|
||||
state.minute_momentum[ticker] = classify_minute_momentum(bars)
|
||||
state.last_updated[f"momentum/{ticker}"] = now_iso
|
||||
3
ai_trade/pytest.ini
Normal file
3
ai_trade/pytest.ini
Normal file
@@ -0,0 +1,3 @@
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
testpaths = tests
|
||||
73
ai_trade/rate_limit.py
Normal file
73
ai_trade/rate_limit.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""SignalDedup — SQLite-backed 24h duplicate signal blocker."""
|
||||
from __future__ import annotations
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
"""Test seam — overridable via monkeypatch."""
|
||||
return datetime.now(KST).isoformat()
|
||||
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS signal_dedup (
|
||||
ticker TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
last_sent TEXT NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
PRIMARY KEY (ticker, action)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_signal_dedup_last_sent
|
||||
ON signal_dedup(last_sent);
|
||||
"""
|
||||
|
||||
|
||||
class SignalDedup:
|
||||
"""24h dedup interface. WAL + busy_timeout=120000."""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self._db_path = Path(db_path)
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_schema()
|
||||
|
||||
@contextmanager
|
||||
def _conn(self):
|
||||
conn = sqlite3.connect(self._db_path, timeout=120.0)
|
||||
try:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=120000")
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _init_schema(self) -> None:
|
||||
with self._conn() as conn:
|
||||
conn.executescript(_SCHEMA)
|
||||
conn.commit()
|
||||
|
||||
def is_recent(self, ticker: str, action: str, within_hours: int = 24) -> bool:
|
||||
threshold_dt = datetime.fromisoformat(_now_iso()) - timedelta(hours=within_hours)
|
||||
threshold_iso = threshold_dt.isoformat()
|
||||
with self._conn() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT last_sent FROM signal_dedup WHERE ticker = ? AND action = ?",
|
||||
(ticker, action),
|
||||
).fetchone()
|
||||
return row is not None and row[0] >= threshold_iso
|
||||
|
||||
def record(self, ticker: str, action: str, confidence: float) -> None:
|
||||
with self._conn() as conn:
|
||||
conn.execute(
|
||||
"""INSERT INTO signal_dedup (ticker, action, last_sent, confidence)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT (ticker, action) DO UPDATE
|
||||
SET last_sent = excluded.last_sent,
|
||||
confidence = excluded.confidence""",
|
||||
(ticker, action, _now_iso(), confidence),
|
||||
)
|
||||
conn.commit()
|
||||
99
ai_trade/scheduler.py
Normal file
99
ai_trade/scheduler.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Polling scheduler — 시간대별 분기 + 휴장일 처리."""
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, time
|
||||
from pathlib import Path
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
_HOLIDAYS_PATH = Path(__file__).parent / "holidays.json"
|
||||
_HOLIDAYS: set[str] = set(json.loads(_HOLIDAYS_PATH.read_text(encoding="utf-8")))
|
||||
|
||||
# Market windows (정규장)
|
||||
_PRE_OPEN = time(7, 0)
|
||||
_OPEN = time(9, 0)
|
||||
_CLOSE = time(15, 30)
|
||||
_POST_END = time(20, 0)
|
||||
|
||||
# NXT windows (시간외)
|
||||
_NXT_PRE_END = time(23, 30)
|
||||
_NXT_POST_OPEN = time(4, 30)
|
||||
# 23:30 - 04:30 (dead zone) skip
|
||||
|
||||
|
||||
def _is_market_day(now: datetime) -> bool:
|
||||
"""평일 + 휴장일 아닌 날."""
|
||||
if now.weekday() >= 5: # Sat/Sun
|
||||
return False
|
||||
return now.strftime("%Y-%m-%d") not in _HOLIDAYS
|
||||
|
||||
|
||||
def _is_polling_window(now: datetime) -> bool:
|
||||
"""폴링 윈도우: 07:00-23:30 + 04:30-07:00."""
|
||||
t = now.time()
|
||||
return (
|
||||
(_PRE_OPEN <= t < _NXT_PRE_END)
|
||||
or (_NXT_POST_OPEN <= t < _PRE_OPEN)
|
||||
)
|
||||
|
||||
|
||||
def _next_interval(now: datetime) -> float:
|
||||
"""다음 폴링까지 sleep 초수."""
|
||||
if not _is_market_day(now):
|
||||
return _seconds_until_next_market_open(now)
|
||||
|
||||
t = now.time()
|
||||
if _PRE_OPEN <= t < _OPEN:
|
||||
return 300.0 # 장전 5분
|
||||
elif _OPEN <= t < _CLOSE:
|
||||
return 60.0 # 장중 1분
|
||||
elif _CLOSE <= t < _POST_END:
|
||||
return 300.0 # 장후 5분
|
||||
elif _POST_END <= t < _NXT_PRE_END:
|
||||
return 300.0 # NXT 야간 5분
|
||||
elif _NXT_POST_OPEN <= t < _PRE_OPEN:
|
||||
return 300.0 # NXT 새벽 5분
|
||||
else:
|
||||
# Dead zone (23:30-04:30) — wait until next 04:30
|
||||
return _seconds_until_nxt_or_market_open(now)
|
||||
|
||||
|
||||
def _seconds_until_nxt_or_market_open(now: datetime) -> float:
|
||||
"""다음 04:30 (NXT 새벽 start) 까지 초수. 휴장일은 다음 영업일 07:00."""
|
||||
candidate = now.replace(hour=4, minute=30, second=0, microsecond=0)
|
||||
if candidate <= now:
|
||||
candidate += timedelta(days=1)
|
||||
|
||||
for _ in range(14):
|
||||
if _is_market_day(candidate):
|
||||
return (candidate - now).total_seconds()
|
||||
candidate += timedelta(days=1)
|
||||
|
||||
logger.warning("could not find next market day within 14 days")
|
||||
return 86400.0
|
||||
|
||||
|
||||
def _is_post_close_trigger(now: datetime) -> bool:
|
||||
"""16:00 KST ±1분 (post-close cycle 트리거). 평일/영업일만."""
|
||||
if not _is_market_day(now):
|
||||
return False
|
||||
t = now.time()
|
||||
return time(16, 0) <= t < time(16, 1)
|
||||
|
||||
|
||||
def _seconds_until_next_market_open(now: datetime) -> float:
|
||||
"""다음 영업일의 07:00 KST 까지 초수 (휴장일/주말용)."""
|
||||
candidate = now.replace(hour=7, minute=0, second=0, microsecond=0)
|
||||
if candidate <= now:
|
||||
candidate += timedelta(days=1)
|
||||
|
||||
for _ in range(14): # safety bound (max 2 weeks of holidays)
|
||||
if _is_market_day(candidate):
|
||||
return (candidate - now).total_seconds()
|
||||
candidate += timedelta(days=1)
|
||||
|
||||
logger.warning("could not find next market day within 14 days")
|
||||
return 86400.0
|
||||
228
ai_trade/signal_generator.py
Normal file
228
ai_trade/signal_generator.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Phase 4 — 매수/매도 신호 생성.
|
||||
|
||||
순수 함수 generate_signals(state, dedup, settings). state 를 mutate.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
MOMENTUM_SCORES = {
|
||||
"strong_up": 1.0,
|
||||
"weak_up": 0.7,
|
||||
"neutral": 0.5,
|
||||
"weak_down": 0.3,
|
||||
"strong_down": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def generate_signals(state, dedup, settings) -> None:
|
||||
"""Phase 4 entry — state-mutating. Evaluation order: sell first (priority), then buy. A ticker receiving a sell signal in this cycle is excluded from buy evaluation to avoid silent overwrite."""
|
||||
_evaluate_sell_signals(state, dedup, settings)
|
||||
_evaluate_buy_signals(state, dedup, settings)
|
||||
|
||||
|
||||
# ----- 매수 -----
|
||||
|
||||
def _evaluate_buy_signals(state, dedup, settings) -> None:
|
||||
candidates = _buy_candidates(state)
|
||||
for ticker, name, rank in candidates:
|
||||
existing = state.signals.get(ticker)
|
||||
if existing is not None and existing.get("action") == "sell":
|
||||
logger.debug("buy %s skipped: same-cycle sell precedence", ticker)
|
||||
continue
|
||||
if not _check_buy_hard_gate(state, ticker, settings):
|
||||
logger.debug("buy %s skipped: hard gate failed", ticker)
|
||||
continue
|
||||
confidence = _compute_buy_confidence(state, ticker, rank)
|
||||
if confidence <= settings.confidence_threshold:
|
||||
logger.debug("buy %s skipped: confidence %.3f <= %.3f",
|
||||
ticker, confidence, settings.confidence_threshold)
|
||||
continue
|
||||
if dedup.is_recent(ticker, "buy", within_hours=24):
|
||||
logger.debug("buy %s skipped: dedup 24h", ticker)
|
||||
continue
|
||||
state.signals[ticker] = _build_buy_signal(state, ticker, name, rank, confidence)
|
||||
dedup.record(ticker, "buy", confidence=confidence)
|
||||
logger.info("signal emit %s buy conf=%.3f rank=%s", ticker, confidence, rank)
|
||||
|
||||
|
||||
def _buy_candidates(state) -> list[tuple[str, str, int | None]]:
|
||||
"""screener Top-N (rank 1..N) + portfolio (rank=None)."""
|
||||
candidates: list[tuple[str, str, int | None]] = []
|
||||
seen: set[str] = set()
|
||||
if state.screener_preview is not None:
|
||||
for i, item in enumerate(state.screener_preview.get("items", [])):
|
||||
ticker = item.get("ticker")
|
||||
if not ticker or ticker in seen:
|
||||
continue
|
||||
seen.add(ticker)
|
||||
name = item.get("name", ticker)
|
||||
candidates.append((ticker, name, i + 1))
|
||||
if state.portfolio is not None:
|
||||
for h in state.portfolio.get("holdings", []):
|
||||
ticker = h.get("ticker")
|
||||
if not ticker or ticker in seen:
|
||||
continue
|
||||
seen.add(ticker)
|
||||
candidates.append((ticker, h.get("name", ticker), None))
|
||||
return candidates
|
||||
|
||||
|
||||
def _check_buy_hard_gate(state, ticker: str, settings) -> bool:
|
||||
pred = state.chronos_predictions.get(ticker)
|
||||
if pred is None or pred.get("median", 0) <= 0:
|
||||
return False
|
||||
spread = pred.get("q90", 0) - pred.get("q10", 0)
|
||||
if spread >= settings.chronos_spread_threshold:
|
||||
return False
|
||||
momentum = state.minute_momentum.get(ticker)
|
||||
if momentum != settings.min_momentum_for_buy:
|
||||
return False
|
||||
ap = state.asking_price.get(ticker)
|
||||
if ap is None or ap.get("bid_ratio", 0) < settings.asking_bid_ratio_threshold:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _compute_buy_confidence(state, ticker: str, rank: int | None) -> float:
|
||||
pred = state.chronos_predictions[ticker]
|
||||
chronos_conf = pred["conf"]
|
||||
minute_score = MOMENTUM_SCORES.get(state.minute_momentum.get(ticker, "neutral"), 0.5)
|
||||
screener_norm = max(0.0, 1 - (rank - 1) / 20) if rank is not None else 0.0
|
||||
return chronos_conf * 0.5 + minute_score * 0.3 + screener_norm * 0.2
|
||||
|
||||
|
||||
def _build_buy_signal(state, ticker: str, name: str, rank: int | None, confidence: float) -> dict:
|
||||
ap = state.asking_price[ticker]
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"name": name,
|
||||
"action": "buy",
|
||||
"confidence_webai": confidence,
|
||||
"current_price": ap["current_price"],
|
||||
"avg_price": None,
|
||||
"pnl_pct": None,
|
||||
"context": _build_context(state, ticker, rank),
|
||||
"as_of": datetime.now(KST).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ----- 매도 -----
|
||||
|
||||
def _evaluate_sell_signals(state, dedup, settings) -> None:
|
||||
if state.portfolio is None:
|
||||
return
|
||||
for holding in state.portfolio.get("holdings", []):
|
||||
ticker = holding.get("ticker")
|
||||
if not ticker:
|
||||
continue
|
||||
sell = _try_stop_loss(state, holding, settings)
|
||||
if sell is None:
|
||||
sell = _try_anomaly(state, holding, settings)
|
||||
if sell is None:
|
||||
sell = _try_take_profit(state, holding, settings)
|
||||
if sell is None:
|
||||
continue
|
||||
if dedup.is_recent(ticker, "sell", within_hours=24):
|
||||
logger.debug("sell %s skipped: dedup 24h", ticker)
|
||||
continue
|
||||
state.signals[ticker] = sell
|
||||
dedup.record(ticker, "sell", confidence=sell["confidence_webai"])
|
||||
logger.info("signal emit %s sell conf=%.3f reason=%s",
|
||||
ticker, sell["confidence_webai"],
|
||||
sell.get("context", {}).get("sell_reason"))
|
||||
|
||||
|
||||
def _try_stop_loss(state, holding: dict, settings) -> dict | None:
|
||||
pnl = holding.get("pnl_pct")
|
||||
if pnl is None or pnl >= settings.stop_loss_pct:
|
||||
return None
|
||||
return _build_sell_signal(state, holding, confidence=1.0, reason="stop_loss")
|
||||
|
||||
|
||||
def _try_take_profit(state, holding: dict, settings) -> dict | None:
|
||||
pnl = holding.get("pnl_pct")
|
||||
if pnl is None or pnl <= settings.take_profit_pct:
|
||||
return None
|
||||
return _build_sell_signal(state, holding, confidence=0.6, reason="take_profit")
|
||||
|
||||
|
||||
def _try_anomaly(state, holding: dict, settings) -> dict | None:
|
||||
ticker = holding["ticker"]
|
||||
pred = state.chronos_predictions.get(ticker)
|
||||
if pred is None or pred["median"] >= -0.01:
|
||||
return None
|
||||
momentum = state.minute_momentum.get(ticker)
|
||||
if momentum != "strong_down":
|
||||
return None
|
||||
ap = state.asking_price.get(ticker)
|
||||
if ap is None:
|
||||
return None
|
||||
if ap["bid_ratio"] > (1 - settings.asking_bid_ratio_threshold):
|
||||
return None
|
||||
minute_score = 1.0 - MOMENTUM_SCORES.get(momentum, 0.5)
|
||||
confidence = pred["conf"] * 0.5 + minute_score * 0.3 + 1.0 * 0.2
|
||||
if confidence <= settings.confidence_threshold:
|
||||
return None
|
||||
return _build_sell_signal(state, holding, confidence=confidence, reason="anomaly")
|
||||
|
||||
|
||||
def _build_sell_signal(state, holding: dict, confidence: float, reason: str) -> dict:
|
||||
ticker = holding["ticker"]
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"name": holding.get("name", ticker),
|
||||
"action": "sell",
|
||||
"confidence_webai": confidence,
|
||||
"current_price": holding.get("current_price"),
|
||||
"avg_price": holding.get("avg_price"),
|
||||
"pnl_pct": holding.get("pnl_pct"),
|
||||
"context": _build_context(state, ticker, rank=None, sell_reason=reason),
|
||||
"as_of": datetime.now(KST).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ----- Context -----
|
||||
|
||||
def _build_context(state, ticker: str, rank: int | None, sell_reason: str | None = None) -> dict:
|
||||
pred = state.chronos_predictions.get(ticker) or {}
|
||||
ap = state.asking_price.get(ticker) or {}
|
||||
news_item = _find_news_sentiment(state, ticker)
|
||||
screener_scores = _find_screener_scores(state, ticker)
|
||||
context: dict = {
|
||||
"chronos_pred_1d": pred.get("median"),
|
||||
"chronos_pred_conf": pred.get("conf"),
|
||||
"chronos_q10": pred.get("q10"),
|
||||
"chronos_q90": pred.get("q90"),
|
||||
"screener_rank": rank,
|
||||
"screener_scores": screener_scores,
|
||||
"minute_momentum": state.minute_momentum.get(ticker),
|
||||
"asking_bid_ratio": ap.get("bid_ratio"),
|
||||
"news_sentiment": news_item.get("score") if news_item else None,
|
||||
"news_reason": news_item.get("reason") if news_item else None,
|
||||
}
|
||||
if sell_reason is not None:
|
||||
context["sell_reason"] = sell_reason
|
||||
return context
|
||||
|
||||
|
||||
def _find_news_sentiment(state, ticker: str) -> dict | None:
|
||||
if state.news_sentiment is None:
|
||||
return None
|
||||
for item in state.news_sentiment.get("items", []):
|
||||
if item.get("ticker") == ticker:
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def _find_screener_scores(state, ticker: str) -> dict | None:
|
||||
if state.screener_preview is None:
|
||||
return None
|
||||
for item in state.screener_preview.get("items", []):
|
||||
if item.get("ticker") == ticker:
|
||||
return item.get("scores")
|
||||
return None
|
||||
3
ai_trade/start.bat
Normal file
3
ai_trade/start.bat
Normal file
@@ -0,0 +1,3 @@
|
||||
@echo off
|
||||
cd /d "%~dp0\.."
|
||||
python -m uvicorn ai_trade.main:app --host 0.0.0.0 --port 8001
|
||||
22
ai_trade/state.py
Normal file
22
ai_trade/state.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""PollState — process-wide singleton."""
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class PollState:
|
||||
portfolio: dict | None = None
|
||||
news_sentiment: dict | None = None
|
||||
screener_preview: dict | None = None
|
||||
minute_bars: dict[str, deque] = field(default_factory=dict)
|
||||
asking_price: dict[str, dict] = field(default_factory=dict)
|
||||
# Phase 3b additions
|
||||
daily_ohlcv: dict[str, list[dict]] = field(default_factory=dict)
|
||||
chronos_predictions: dict[str, dict] = field(default_factory=dict)
|
||||
minute_momentum: dict[str, str] = field(default_factory=dict)
|
||||
signals: dict[str, dict] = field(default_factory=dict)
|
||||
last_updated: dict[str, str] = field(default_factory=dict)
|
||||
fetch_errors: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
state = PollState()
|
||||
129
ai_trade/stock_client.py
Normal file
129
ai_trade/stock_client.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Stock API HTTP client — async httpx + retry + memory cache."""
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache TTL by endpoint (seconds).
|
||||
# 2026-05-18 — NAS 인바운드 호출 부담 완화 (Plan-A SP-A1).
|
||||
_TTL = {
|
||||
"portfolio": 180.0, # 3분 (1분 폴링 시 3 폴링당 1회 실제 fetch)
|
||||
"news-sentiment": 600.0, # 10분 (뉴스 sentiment는 자주 안 바뀜)
|
||||
"screener-preview": 300.0, # 5분 (Top-20은 분 단위로 거의 안 바뀜)
|
||||
}
|
||||
|
||||
# Retry policy
|
||||
_MAX_ATTEMPTS = 3
|
||||
_RETRY_STATUSES = {429, 500, 502, 503, 504}
|
||||
|
||||
|
||||
class StockClient:
|
||||
"""stock API wrapper. Async httpx + self-retry + memory cache."""
|
||||
|
||||
def __init__(self, base_url: str, api_key: str, timeout: float = 10.0):
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._api_key = api_key
|
||||
self._client = httpx.AsyncClient(timeout=timeout)
|
||||
# cache: key → (data, timestamp_monotonic)
|
||||
self._cache: dict[str, tuple[Any, float]] = {}
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
|
||||
def cache_size(self) -> int:
|
||||
"""Number of cached endpoint responses (public surface for /health)."""
|
||||
return len(self._cache)
|
||||
|
||||
async def get_portfolio(self) -> dict:
|
||||
return await self._cached_request(
|
||||
"portfolio", "GET", "/api/webai/portfolio"
|
||||
)
|
||||
|
||||
async def get_news_sentiment(self, date: str | None = None) -> dict:
|
||||
path = "/api/webai/news-sentiment"
|
||||
if date is not None:
|
||||
path += f"?date={date}"
|
||||
cache_key = f"news-sentiment:{date or 'latest'}"
|
||||
return await self._cached_request(
|
||||
cache_key, "GET", path, _ttl_key="news-sentiment"
|
||||
)
|
||||
|
||||
async def run_screener_preview(
|
||||
self, weights: dict | None = None, top_n: int = 20
|
||||
) -> dict:
|
||||
body = {"mode": "preview", "top_n": top_n}
|
||||
if weights is not None:
|
||||
body["weights"] = weights
|
||||
return await self._cached_request(
|
||||
"screener-preview",
|
||||
"POST",
|
||||
"/api/stock/screener/run",
|
||||
json=body,
|
||||
_ttl_key="screener-preview",
|
||||
)
|
||||
|
||||
async def _cached_request(
|
||||
self,
|
||||
cache_key: str,
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
_ttl_key: str | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
ttl_key = _ttl_key or cache_key
|
||||
ttl = _TTL.get(ttl_key, 60.0)
|
||||
# Fresh cache hit?
|
||||
if cache_key in self._cache:
|
||||
data, ts = self._cache[cache_key]
|
||||
if time.monotonic() - ts < ttl:
|
||||
return data
|
||||
|
||||
# Fetch (with retry)
|
||||
try:
|
||||
data = await self._request_with_retry(method, path, **kwargs)
|
||||
self._cache[cache_key] = (data, time.monotonic())
|
||||
return data
|
||||
except httpx.HTTPError:
|
||||
# Stale fallback: serve old cached value if exists
|
||||
if cache_key in self._cache:
|
||||
stale_data, stale_ts = self._cache[cache_key]
|
||||
age = time.monotonic() - stale_ts
|
||||
logger.warning(
|
||||
"serving stale cache for %s (age=%.1fs)", cache_key, age
|
||||
)
|
||||
return stale_data
|
||||
raise
|
||||
|
||||
async def _request_with_retry(self, method: str, path: str, **kwargs) -> dict:
|
||||
url = f"{self._base_url}{path}"
|
||||
headers = self._auth_headers()
|
||||
for attempt in range(_MAX_ATTEMPTS):
|
||||
try:
|
||||
response = await self._client.request(
|
||||
method, url, headers=headers, **kwargs
|
||||
)
|
||||
if response.status_code in _RETRY_STATUSES:
|
||||
if attempt < _MAX_ATTEMPTS - 1:
|
||||
await asyncio.sleep(2**attempt)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.TimeoutException:
|
||||
if attempt < _MAX_ATTEMPTS - 1:
|
||||
await asyncio.sleep(2**attempt)
|
||||
continue
|
||||
raise
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
# Unreachable: every iteration either returns or raises
|
||||
raise RuntimeError("_request_with_retry exhausted loop without raising")
|
||||
|
||||
def _auth_headers(self) -> dict[str, str]:
|
||||
return {"X-WebAI-Key": self._api_key}
|
||||
0
ai_trade/tests/__init__.py
Normal file
0
ai_trade/tests/__init__.py
Normal file
BIN
ai_trade/tests/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
ai_trade/tests/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
ai_trade/tests/__pycache__/conftest.cpython-312-pytest-9.0.2.pyc
Normal file
BIN
ai_trade/tests/__pycache__/conftest.cpython-312-pytest-9.0.2.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
18
ai_trade/tests/conftest.py
Normal file
18
ai_trade/tests/conftest.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Pytest fixtures for ai_trade tests."""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_dedup_db(tmp_path) -> Path:
|
||||
"""SQLite 단위 테스트용 임시 DB path."""
|
||||
return tmp_path / "test_ai_trade.db"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stock_api():
|
||||
"""respx 로 stock API mock. base_url 은 테스트마다 임의."""
|
||||
with respx.mock(base_url="https://test.stock.local", assert_all_called=False) as mock:
|
||||
yield mock
|
||||
92
ai_trade/tests/test_chronos_predictor.py
Normal file
92
ai_trade/tests/test_chronos_predictor.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for ChronosPredictor (model mock)."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline():
|
||||
"""Mock BaseChronosPipeline.from_pretrained returning a mock pipeline object."""
|
||||
with patch("chronos.BaseChronosPipeline") as cls:
|
||||
cls.__name__ = "BaseChronosPipeline"
|
||||
instance = MagicMock()
|
||||
# ChronosBolt API: predict_quantiles returns (quantiles_tensor, mean_tensor)
|
||||
# Modern (predict_quantiles) branch will be used since hasattr(MagicMock, "predict_quantiles") is True.
|
||||
cls.from_pretrained.return_value = instance
|
||||
yield instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torch_cpu():
|
||||
with patch("torch.cuda.is_available", return_value=False):
|
||||
yield
|
||||
|
||||
|
||||
def _daily_ohlcv(close_seq):
|
||||
return [{"datetime": f"2026-05-{i+1:02d}", "open": c, "high": c, "low": c,
|
||||
"close": c, "volume": 1000} for i, c in enumerate(close_seq)]
|
||||
|
||||
|
||||
def _mk_quantiles_tensor(q10_price: float, q50_price: float, q90_price: float):
|
||||
"""Helper: build predict_quantiles return tensor shape [1, 1, 3]."""
|
||||
import torch
|
||||
return torch.tensor([[[q10_price, q50_price, q90_price]]], dtype=torch.float32)
|
||||
|
||||
|
||||
def test_predict_batch_returns_prediction_dict(mock_pipeline, mock_torch_cpu):
|
||||
"""mock predict_quantiles → dict[ticker, ChronosPrediction]. last_close=100, q50=102 → median≈+2%."""
|
||||
quantiles = _mk_quantiles_tensor(101.5, 102.0, 102.5) # narrow around 102
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor, ChronosPrediction
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert "005930" in result
|
||||
pred = result["005930"]
|
||||
assert isinstance(pred, ChronosPrediction)
|
||||
assert abs(pred.median - 0.02) < 0.001
|
||||
|
||||
|
||||
def test_conf_high_when_distribution_narrow(mock_pipeline, mock_torch_cpu):
|
||||
"""좁은 distribution (q90-q10 작음, median 0 아님) → conf ≈ 1."""
|
||||
# last_close=100, q10=101.99, q50=102.00, q90=102.01
|
||||
# returns: q10=0.0199, q50=0.02, q90=0.0201
|
||||
# spread = (0.0201 - 0.0199) / max(0.02, 0.001) = 0.0002/0.02 = 0.01 → conf = 1 - 0.005 = 0.995
|
||||
quantiles = _mk_quantiles_tensor(101.99, 102.0, 102.01)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert result["005930"].conf > 0.8
|
||||
|
||||
|
||||
def test_conf_low_when_distribution_wide(mock_pipeline, mock_torch_cpu):
|
||||
"""넓은 distribution → conf ≈ 0."""
|
||||
# last_close=100, q10=70, q50=100, q90=130
|
||||
# returns: q10=-0.3, q50=0.0, q90=0.3
|
||||
# spread = (0.3 - (-0.3)) / max(0.0, 0.001) = 0.6 / 0.001 = 600 → conf = max(0, 1 - 300) = 0
|
||||
quantiles = _mk_quantiles_tensor(70.0, 100.0, 130.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert result["005930"].conf < 0.3
|
||||
|
||||
|
||||
def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_torch_cpu):
|
||||
"""price 예측 → last_close 대비 return 변환. last_close=100, q50=110 → return ≈ +10%."""
|
||||
quantiles = _mk_quantiles_tensor(109.0, 110.0, 111.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
# last close = 100
|
||||
daily = {"005930": _daily_ohlcv(list(range(41, 101)))} # last = 100
|
||||
result = predictor.predict_batch(daily)
|
||||
assert abs(result["005930"].median - 0.10) < 0.001
|
||||
161
ai_trade/tests/test_kis_client.py
Normal file
161
ai_trade/tests/test_kis_client.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Tests for KISClient (REST)."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from ai_trade.kis_client import KISClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_v1_token(tmp_path):
|
||||
"""V1 토큰 파일 fixture."""
|
||||
token_file = tmp_path / "kis_token.json"
|
||||
token_file.write_text(json.dumps({
|
||||
"access_token": "test-kis-token-abc123",
|
||||
"token_expired": "2099-12-31 23:59:59",
|
||||
}))
|
||||
return token_file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kis_client_factory(fake_v1_token):
|
||||
def _make():
|
||||
return KISClient(
|
||||
app_key="test-app-key",
|
||||
app_secret="test-app-secret",
|
||||
account="50000000-01",
|
||||
is_virtual=True,
|
||||
v1_token_path=fake_v1_token,
|
||||
)
|
||||
return _make
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_minute_ohlcv_normal_returns_30_bars(kis_client_factory):
|
||||
"""정상 200 → 30개 분봉 list 반환."""
|
||||
sample_output2 = [
|
||||
{
|
||||
"stck_bsop_date": "20260518",
|
||||
"stck_cntg_hour": f"09{m:02d}00",
|
||||
"stck_oprc": "78000", "stck_hgpr": "78500",
|
||||
"stck_lwpr": "77800", "stck_prpr": "78300",
|
||||
"cntg_vol": "12345",
|
||||
}
|
||||
for m in range(30) # 9:00-9:29 = 30 bars
|
||||
]
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||
).mock(
|
||||
return_value=httpx.Response(200, json={"output2": sample_output2})
|
||||
)
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
bars = await client.get_minute_ohlcv("005930")
|
||||
assert len(bars) == 30
|
||||
assert bars[0]["close"] == 78300
|
||||
assert "datetime" in bars[0]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_minute_ohlcv_429_retry_then_success(kis_client_factory, monkeypatch):
|
||||
"""429 → exponential backoff → 200."""
|
||||
sleep_calls = []
|
||||
async def fake_sleep(s): sleep_calls.append(s)
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||
).mock(side_effect=[
|
||||
httpx.Response(429, text="rate limit"),
|
||||
httpx.Response(200, json={"output2": []}),
|
||||
])
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
result = await client.get_minute_ohlcv("005930")
|
||||
assert result == []
|
||||
assert 1 in sleep_calls
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_minute_ohlcv_uses_v1_token(kis_client_factory, fake_v1_token):
|
||||
"""KIS 호출 헤더에 V1 토큰 파일의 access_token 사용."""
|
||||
route = respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||
).mock(return_value=httpx.Response(200, json={"output2": []}))
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
await client.get_minute_ohlcv("005930")
|
||||
assert route.called
|
||||
req = route.calls.last.request
|
||||
# check authorization header contains the V1 token
|
||||
auth = req.headers.get("authorization", "")
|
||||
assert "test-kis-token-abc123" in auth
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_asking_price_computes_bid_ratio(kis_client_factory):
|
||||
"""호가 응답 → bid_total/(bid+ask) bid_ratio 계산."""
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
|
||||
).mock(return_value=httpx.Response(200, json={
|
||||
"output1": {
|
||||
"total_bidp_rsqn": "600",
|
||||
"total_askp_rsqn": "400",
|
||||
"stck_prpr": "78500",
|
||||
}
|
||||
}))
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
data = await client.get_asking_price("005930")
|
||||
assert data["bid_total"] == 600
|
||||
assert data["ask_total"] == 400
|
||||
assert abs(data["bid_ratio"] - 0.6) < 1e-9
|
||||
assert data["current_price"] == 78500
|
||||
assert "as_of" in data
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_get_daily_ohlcv_returns_60_bars(kis_client_factory):
|
||||
"""KIS daily endpoint returns 60 ascending bars after parsing."""
|
||||
# Build 60 KIS-format daily bars (descending dates as KIS does)
|
||||
sample_output2 = []
|
||||
for i in range(60):
|
||||
# Generate a fake date 60 days ago, descending
|
||||
day = 60 - i
|
||||
sample_output2.append({
|
||||
"stck_bsop_date": f"2026{(((day-1)//30)+1):02d}{(((day-1)%30)+1):02d}",
|
||||
"stck_oprc": "78000", "stck_hgpr": "78500",
|
||||
"stck_lwpr": "77800", "stck_clpr": str(78000 + i),
|
||||
"acml_vol": "12345",
|
||||
})
|
||||
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
|
||||
).mock(return_value=httpx.Response(200, json={"output2": sample_output2}))
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
bars = await client.get_daily_ohlcv("005930", days=60)
|
||||
# KIS returns descending; client reverses to ascending
|
||||
assert len(bars) == 60
|
||||
# Ascending order: first item has smaller datetime than last
|
||||
assert bars[0]["datetime"] < bars[-1]["datetime"]
|
||||
assert isinstance(bars[0]["open"], int)
|
||||
assert isinstance(bars[0]["close"], int)
|
||||
assert "datetime" in bars[0]
|
||||
finally:
|
||||
await client.close()
|
||||
94
ai_trade/tests/test_kis_websocket.py
Normal file
94
ai_trade/tests/test_kis_websocket.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Tests for KISWebSocket."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from ai_trade.kis_websocket import KISWebSocket
|
||||
|
||||
|
||||
BASE_REST = "https://openapivts.koreainvestment.com:29443"
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_fetch_approval_key_via_oauth_endpoint():
|
||||
"""POST /oauth2/Approval → approval_key 추출."""
|
||||
respx.post(f"{BASE_REST}/oauth2/Approval").mock(
|
||||
return_value=httpx.Response(200, json={"approval_key": "test-approval-key-xyz"})
|
||||
)
|
||||
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||
key = await ws._fetch_approval_key()
|
||||
assert key == "test-approval-key-xyz"
|
||||
assert ws._approval_key == "test-approval-key-xyz"
|
||||
|
||||
|
||||
async def test_subscribe_sends_h0stasp0_message():
|
||||
"""subscribe() → WebSocket 으로 H0STASP0 구독 메시지 전송."""
|
||||
sent_messages = []
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send = AsyncMock(side_effect=lambda m: sent_messages.append(m))
|
||||
|
||||
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||
ws._approval_key = "test-key"
|
||||
ws._ws = mock_ws
|
||||
await ws.subscribe("005930")
|
||||
assert ws._subscriptions == {"005930"}
|
||||
assert len(sent_messages) == 1
|
||||
msg = json.loads(sent_messages[0])
|
||||
assert msg["header"]["tr_type"] == "1" # subscribe
|
||||
assert msg["body"]["input"]["tr_id"] == "H0STASP0"
|
||||
assert msg["body"]["input"]["tr_key"] == "005930"
|
||||
|
||||
|
||||
def test_parse_asking_price_extracts_bid_ask_totals():
|
||||
"""KIS raw '0|H0STASP0|001|...' → (ticker, dict).
|
||||
|
||||
KIS 호가 메시지 형식 — KIS 공식 spec 의 정확한 필드 인덱스 운영 검증 필요.
|
||||
본 테스트는 implementer 의 _parse_asking_price 구현 인덱스에 맞춰서 sample 작성.
|
||||
"""
|
||||
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||
# Build a sample raw message — implementer 가 _ASKING_TOTAL_BID/ASK 인덱스에
|
||||
# 맞춰서 필드 배치하면 됨. 예: 마지막 2개 필드를 bid_total / ask_total 로.
|
||||
fields = ["005930", "091500", "78500"] # ticker, time, current_price
|
||||
fields.extend(["0"] * 40) # padding (KIS 의 실 필드 수 ~50개)
|
||||
fields.append("400") # ask_total
|
||||
fields.append("600") # bid_total
|
||||
raw = f"0|H0STASP0|001|{'^'.join(fields)}"
|
||||
|
||||
result = ws._parse_asking_price(raw)
|
||||
assert result is not None, "parse_asking_price returned None"
|
||||
ticker, data = result
|
||||
assert ticker == "005930"
|
||||
assert "bid_total" in data
|
||||
assert "ask_total" in data
|
||||
assert "bid_ratio" in data
|
||||
assert "current_price" in data
|
||||
# bid_total=600, ask_total=400, bid_ratio=0.6
|
||||
assert data["bid_total"] == 600
|
||||
assert data["ask_total"] == 400
|
||||
assert abs(data["bid_ratio"] - 0.6) < 1e-9
|
||||
|
||||
|
||||
async def test_reconnect_on_disconnect_with_backoff(monkeypatch):
|
||||
"""연결 끊김 → exponential backoff retry. _connect_with_backoff() 검증."""
|
||||
sleep_calls = []
|
||||
async def fake_sleep(s): sleep_calls.append(s)
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
ws = KISWebSocket(app_key="k", app_secret="s", is_virtual=True)
|
||||
# Mock _connect to fail twice then succeed
|
||||
call_count = [0]
|
||||
async def fake_connect():
|
||||
call_count[0] += 1
|
||||
if call_count[0] < 3:
|
||||
raise ConnectionError("fake disconnect")
|
||||
return AsyncMock()
|
||||
monkeypatch.setattr(ws, "_connect", fake_connect)
|
||||
|
||||
result = await ws._connect_with_backoff()
|
||||
assert call_count[0] == 3 # 2 fails + 1 success
|
||||
# exponential 1s, 2s
|
||||
assert sleep_calls[:2] == [1, 2]
|
||||
62
ai_trade/tests/test_main.py
Normal file
62
ai_trade/tests/test_main.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Tests for FastAPI main app."""
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_health_endpoint_returns_status_online(monkeypatch):
|
||||
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||
monkeypatch.setenv("WEBAI_API_KEY", "test-secret")
|
||||
# Reload modules so they pick up the new env
|
||||
import importlib
|
||||
from ai_trade import config as cfg
|
||||
importlib.reload(cfg)
|
||||
from ai_trade import main as main_mod
|
||||
importlib.reload(main_mod)
|
||||
with TestClient(main_mod.app) as client:
|
||||
r = client.get("/health")
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["status"] == "online"
|
||||
assert body["stock_api_url"] == "https://test.stock.local"
|
||||
|
||||
|
||||
def test_startup_warns_if_webai_api_key_missing(monkeypatch, caplog):
|
||||
# Use setenv with empty string + no-op load_dotenv to defeat .env re-read on reload
|
||||
monkeypatch.setattr("ai_trade.config.load_dotenv", lambda *a, **k: None)
|
||||
monkeypatch.setenv("WEBAI_API_KEY", "")
|
||||
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||
import importlib
|
||||
from ai_trade import config as cfg
|
||||
importlib.reload(cfg)
|
||||
# After reload, load_dotenv reference is fresh — re-patch
|
||||
monkeypatch.setattr("ai_trade.config.load_dotenv", lambda *a, **k: None)
|
||||
from ai_trade import main as main_mod
|
||||
importlib.reload(main_mod)
|
||||
with caplog.at_level(logging.WARNING, logger="ai_trade.main"):
|
||||
with TestClient(main_mod.app) as client:
|
||||
client.get("/health")
|
||||
assert any("WEBAI_API_KEY" in rec.message for rec in caplog.records)
|
||||
|
||||
|
||||
def test_startup_warns_if_kis_app_key_missing(monkeypatch, caplog):
|
||||
"""KIS app_key 미설정 시 startup WARNING (KIS 호출 disabled) — V1 패턴."""
|
||||
monkeypatch.setattr("ai_trade.config.load_dotenv", lambda *a, **k: None)
|
||||
monkeypatch.setenv("STOCK_API_URL", "https://test.stock.local")
|
||||
monkeypatch.setenv("WEBAI_API_KEY", "test-secret")
|
||||
# V1 pattern: kis_env_type=virtual, both virtual keys empty
|
||||
monkeypatch.setenv("KIS_ENV_TYPE", "virtual")
|
||||
monkeypatch.setenv("KIS_VIRTUAL_APP_KEY", "")
|
||||
monkeypatch.setenv("KIS_REAL_APP_KEY", "")
|
||||
|
||||
import importlib
|
||||
from ai_trade import config as cfg
|
||||
importlib.reload(cfg)
|
||||
monkeypatch.setattr("ai_trade.config.load_dotenv", lambda *a, **k: None)
|
||||
from ai_trade import main as main_mod
|
||||
importlib.reload(main_mod)
|
||||
with caplog.at_level(logging.WARNING, logger="ai_trade.main"):
|
||||
with TestClient(main_mod.app) as client:
|
||||
client.get("/health")
|
||||
assert any("KIS" in rec.message and "app_key" in rec.message.lower() for rec in caplog.records)
|
||||
92
ai_trade/tests/test_momentum_classifier.py
Normal file
92
ai_trade/tests/test_momentum_classifier.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for minute momentum classifier."""
|
||||
from collections import deque
|
||||
|
||||
from ai_trade.momentum_classifier import (
|
||||
aggregate_1min_to_5min, classify_minute_momentum,
|
||||
STRONG_UP, WEAK_UP, NEUTRAL, WEAK_DOWN, STRONG_DOWN,
|
||||
)
|
||||
|
||||
|
||||
def _bar(open_, high, low, close, volume):
|
||||
return {
|
||||
"datetime": "2026-05-18T09:00:00+09:00",
|
||||
"open": open_, "high": high, "low": low, "close": close, "volume": volume,
|
||||
}
|
||||
|
||||
|
||||
def _make_chunks(num_chunks_up: int, num_chunks_total: int, base_vol: int = 1000):
|
||||
"""num_chunks_total 개의 5-bar 청크. num_chunks_up 청크는 양봉, 나머지는 음봉.
|
||||
각 청크는 5개 1분봉. 거래량 = base_vol per bar.
|
||||
"""
|
||||
bars = []
|
||||
for i in range(num_chunks_total):
|
||||
is_up = i < num_chunks_up
|
||||
o, c = (100, 110) if is_up else (110, 100)
|
||||
for j in range(5):
|
||||
bars.append(_bar(o, max(o, c) + 5, min(o, c) - 5, c, base_vol))
|
||||
return bars
|
||||
|
||||
|
||||
def test_strong_up_5_consecutive_green_with_high_volume():
|
||||
"""직전 5개 5분봉 모두 양봉 + 거래량 1.5x → STRONG_UP."""
|
||||
# 60분 (12 5분봉) 데이터: 7 normal + 5 high-vol up
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
recent = _make_chunks(num_chunks_up=5, num_chunks_total=5, base_vol=2500)
|
||||
minute_bars = deque(older + recent, maxlen=60)
|
||||
assert classify_minute_momentum(minute_bars) == STRONG_UP
|
||||
|
||||
|
||||
def test_weak_up_3of5_green_normal_volume():
|
||||
"""직전 5개 5분봉 중 3-4개 양봉 + 거래량 ≥ 1.0x → WEAK_UP."""
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
# 5 chunks: 3 up + 2 down, normal vol
|
||||
recent_up = _make_chunks(num_chunks_up=3, num_chunks_total=3, base_vol=1000)
|
||||
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=2, base_vol=1000)
|
||||
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||
assert classify_minute_momentum(minute_bars) == WEAK_UP
|
||||
|
||||
|
||||
def test_neutral_mixed():
|
||||
"""up_count=2, vol normal → NEUTRAL (rule 미해당)."""
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
recent_up = _make_chunks(num_chunks_up=2, num_chunks_total=2, base_vol=1000)
|
||||
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=3, base_vol=1000)
|
||||
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||
# up_count=2, vol_mult=1.0 → 어느 분기 조건도 만족 안 함 → NEUTRAL
|
||||
assert classify_minute_momentum(minute_bars) == NEUTRAL
|
||||
|
||||
|
||||
def test_weak_down_low_green_low_volume():
|
||||
"""up_count <= 2 + vol < 1.0 → WEAK_DOWN."""
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
recent_up = _make_chunks(num_chunks_up=1, num_chunks_total=1, base_vol=500)
|
||||
recent_down = _make_chunks(num_chunks_up=0, num_chunks_total=4, base_vol=500)
|
||||
minute_bars = deque(older + recent_up + recent_down, maxlen=60)
|
||||
# recent 5 chunks avg vol = 500, long 12 avg ≈ (7*1000 + 5*500) / 12 ≈ 791 → vol_mult ≈ 0.63
|
||||
assert classify_minute_momentum(minute_bars) == WEAK_DOWN
|
||||
|
||||
|
||||
def test_strong_down_5_consecutive_red_high_volume():
|
||||
"""직전 5개 5분봉 모두 음봉 + 거래량 1.5x → STRONG_DOWN."""
|
||||
older = _make_chunks(num_chunks_up=3, num_chunks_total=7, base_vol=1000)
|
||||
recent = _make_chunks(num_chunks_up=0, num_chunks_total=5, base_vol=2500)
|
||||
minute_bars = deque(older + recent, maxlen=60)
|
||||
assert classify_minute_momentum(minute_bars) == STRONG_DOWN
|
||||
|
||||
|
||||
def test_aggregate_1min_to_5min_correctness():
|
||||
"""5 1분봉 → 1개 5분봉 — open/close/high/low/volume 정확."""
|
||||
bars = [
|
||||
_bar(100, 105, 99, 102, 1000),
|
||||
_bar(102, 108, 101, 107, 1500),
|
||||
_bar(107, 110, 105, 106, 800),
|
||||
_bar(106, 109, 104, 108, 1200),
|
||||
_bar(108, 112, 107, 111, 900),
|
||||
]
|
||||
result = aggregate_1min_to_5min(bars)
|
||||
assert len(result) == 1
|
||||
assert result[0]["open"] == 100 # 첫 bar
|
||||
assert result[0]["close"] == 111 # 마지막 bar
|
||||
assert result[0]["high"] == 112 # max
|
||||
assert result[0]["low"] == 99 # min
|
||||
assert result[0]["volume"] == 5400 # sum
|
||||
131
ai_trade/tests/test_pull_worker.py
Normal file
131
ai_trade/tests/test_pull_worker.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Tests for pull_worker (Phase 3a additions)."""
|
||||
from collections import deque
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from ai_trade.state import PollState
|
||||
|
||||
|
||||
async def test_minute_polling_cycle_updates_state_minute_bars():
|
||||
"""KIS REST mock 의 분봉 데이터가 state.minute_bars[ticker] deque 에 들어간다."""
|
||||
from ai_trade.pull_worker import _run_kis_minute_cycle
|
||||
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{"ticker": "005930"}, {"ticker": "000660"}]}
|
||||
state.screener_preview = {
|
||||
"items": [{"ticker": "005930"}, {"ticker": "035720"}]
|
||||
}
|
||||
|
||||
kis_client_mock = MagicMock()
|
||||
kis_client_mock.get_minute_ohlcv = AsyncMock(side_effect=[
|
||||
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 78000,
|
||||
"high": 78500, "low": 77900, "close": 78300, "volume": 12345}],
|
||||
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 180000,
|
||||
"high": 181000, "low": 179800, "close": 180500, "volume": 5000}],
|
||||
[{"datetime": "2026-05-18T09:00:00+09:00", "open": 51000,
|
||||
"high": 51200, "low": 50800, "close": 51100, "volume": 8000}],
|
||||
])
|
||||
kis_client_mock.get_asking_price = AsyncMock(return_value={
|
||||
"bid_total": 600, "ask_total": 400, "bid_ratio": 0.6,
|
||||
"current_price": 51100, "as_of": "2026-05-18T09:00:30+09:00",
|
||||
})
|
||||
|
||||
await _run_kis_minute_cycle(kis_client_mock, state)
|
||||
|
||||
# 3 unique tickers (005930, 000660, 035720)
|
||||
assert "005930" in state.minute_bars
|
||||
assert "000660" in state.minute_bars
|
||||
assert "035720" in state.minute_bars
|
||||
assert len(state.minute_bars["005930"]) >= 1
|
||||
# asking_price 만 screener-only ticker (035720) 에 들어가야 함
|
||||
# (portfolio = 005930, 000660 는 WebSocket 으로 들어옴)
|
||||
assert "035720" in state.asking_price
|
||||
|
||||
|
||||
def test_websocket_message_updates_state_asking_price():
|
||||
"""WebSocket callback factory → state.asking_price 갱신."""
|
||||
from ai_trade.pull_worker import make_asking_price_callback
|
||||
|
||||
state = PollState()
|
||||
cb = make_asking_price_callback(state)
|
||||
cb("005930", {"bid_total": 1000, "ask_total": 800, "bid_ratio": 0.555,
|
||||
"current_price": 78500, "as_of": "2026-05-18T10:00:00+09:00"})
|
||||
assert state.asking_price["005930"]["bid_total"] == 1000
|
||||
assert "asking_price/005930" in state.last_updated
|
||||
|
||||
|
||||
async def test_post_close_cycle_updates_chronos_predictions():
|
||||
"""mock kis + mock chronos → state.chronos_predictions + state.daily_ohlcv 갱신."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from ai_trade.pull_worker import _run_post_close_cycle
|
||||
from ai_trade.chronos_predictor import ChronosPrediction
|
||||
from ai_trade.state import PollState
|
||||
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{"ticker": "005930"}]}
|
||||
state.screener_preview = {"items": [{"ticker": "000660"}]}
|
||||
|
||||
kis_mock = MagicMock()
|
||||
daily_005930 = [{"datetime": f"2026-05-{i+1:02d}", "open": 100, "high": 105,
|
||||
"low": 95, "close": 100 + i, "volume": 1000} for i in range(60)]
|
||||
daily_000660 = [{"datetime": f"2026-05-{i+1:02d}", "open": 200, "high": 210,
|
||||
"low": 190, "close": 200 + i, "volume": 2000} for i in range(60)]
|
||||
# _run_post_close_cycle iterates tickers and calls get_daily_ohlcv per ticker.
|
||||
# Order depends on set() so use side_effect mapping if possible, otherwise list.
|
||||
async def fake_daily(ticker, days=60):
|
||||
if ticker == "005930":
|
||||
return daily_005930
|
||||
if ticker == "000660":
|
||||
return daily_000660
|
||||
return []
|
||||
kis_mock.get_daily_ohlcv = AsyncMock(side_effect=fake_daily)
|
||||
|
||||
chronos_mock = MagicMock()
|
||||
chronos_mock.predict_batch = MagicMock(return_value={
|
||||
"005930": ChronosPrediction(0.02, -0.01, 0.04, 0.85, "2026-05-18T16:00:00+09:00"),
|
||||
"000660": ChronosPrediction(0.03, -0.02, 0.06, 0.75, "2026-05-18T16:00:00+09:00"),
|
||||
})
|
||||
|
||||
await _run_post_close_cycle(kis_mock, chronos_mock, state)
|
||||
|
||||
assert "005930" in state.chronos_predictions
|
||||
assert "000660" in state.chronos_predictions
|
||||
assert state.chronos_predictions["005930"]["median"] == 0.02
|
||||
assert state.chronos_predictions["005930"]["conf"] == 0.85
|
||||
assert "005930" in state.daily_ohlcv
|
||||
assert "chronos/005930" in state.last_updated
|
||||
|
||||
|
||||
def test_poll_loop_calls_generate_signals_after_cycle(monkeypatch):
|
||||
"""Phase 4: generate_signals 가 cycle 후 state.signals 를 갱신한다."""
|
||||
from unittest.mock import MagicMock
|
||||
from ai_trade.state import PollState
|
||||
from ai_trade.signal_generator import generate_signals
|
||||
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{
|
||||
"ticker": "005930", "name": "삼성전자",
|
||||
"avg_price": 75000, "current_price": 69000,
|
||||
"pnl_pct": -0.08, "profit_rate": -8.0,
|
||||
"quantity": 100, "broker": "키움",
|
||||
}]}
|
||||
state.screener_preview = {"items": []}
|
||||
|
||||
dedup = MagicMock()
|
||||
dedup.is_recent.return_value = False
|
||||
|
||||
settings = MagicMock()
|
||||
settings.stop_loss_pct = -0.07
|
||||
settings.take_profit_pct = 0.15
|
||||
settings.chronos_spread_threshold = 0.6
|
||||
settings.asking_bid_ratio_threshold = 0.6
|
||||
settings.confidence_threshold = 0.7
|
||||
settings.min_momentum_for_buy = "strong_up"
|
||||
|
||||
generate_signals(state, dedup, settings)
|
||||
|
||||
assert "005930" in state.signals
|
||||
assert state.signals["005930"]["action"] == "sell"
|
||||
assert state.signals["005930"]["confidence_webai"] == 1.0
|
||||
dedup.record.assert_called_with("005930", "sell", confidence=1.0)
|
||||
34
ai_trade/tests/test_rate_limit.py
Normal file
34
ai_trade/tests/test_rate_limit.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Tests for SignalDedup."""
|
||||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from ai_trade.rate_limit import SignalDedup
|
||||
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
|
||||
def test_is_recent_returns_false_for_new_ticker_action(tmp_dedup_db):
|
||||
dedup = SignalDedup(tmp_dedup_db)
|
||||
assert dedup.is_recent("005930", "buy") is False
|
||||
|
||||
|
||||
def test_is_recent_returns_true_within_24h(tmp_dedup_db):
|
||||
dedup = SignalDedup(tmp_dedup_db)
|
||||
dedup.record("005930", "buy", confidence=0.82)
|
||||
assert dedup.is_recent("005930", "buy") is True
|
||||
|
||||
|
||||
def test_is_recent_returns_false_after_24h(tmp_dedup_db, monkeypatch):
|
||||
dedup = SignalDedup(tmp_dedup_db)
|
||||
# Record with a timestamp 25 hours ago
|
||||
now = datetime.now(KST)
|
||||
fake_now = now - timedelta(hours=25)
|
||||
monkeypatch.setattr(
|
||||
"ai_trade.rate_limit._now_iso", lambda: fake_now.isoformat()
|
||||
)
|
||||
dedup.record("005930", "buy", confidence=0.82)
|
||||
# Reset to real now for is_recent check
|
||||
monkeypatch.setattr(
|
||||
"ai_trade.rate_limit._now_iso", lambda: now.isoformat()
|
||||
)
|
||||
assert dedup.is_recent("005930", "buy", within_hours=24) is False
|
||||
81
ai_trade/tests/test_scheduler.py
Normal file
81
ai_trade/tests/test_scheduler.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for scheduler interval logic."""
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from ai_trade.scheduler import _next_interval, _is_market_day, KST
|
||||
|
||||
|
||||
def _kst(year, month, day, hour, minute=0):
|
||||
return datetime(year, month, day, hour, minute, tzinfo=KST)
|
||||
|
||||
|
||||
def test_next_interval_pre_market_5min():
|
||||
now = _kst(2026, 5, 18, 8, 30) # Monday 08:30
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_market_open_1min():
|
||||
now = _kst(2026, 5, 18, 10, 0) # Monday 10:00
|
||||
assert _next_interval(now) == 60
|
||||
|
||||
|
||||
def test_next_interval_post_market_5min():
|
||||
now = _kst(2026, 5, 18, 17, 0) # Monday 17:00
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_overnight_skip_to_next_morning():
|
||||
now = _kst(2026, 5, 18, 2, 30) # Monday 02:30 (dead zone, not NXT window)
|
||||
interval = _next_interval(now)
|
||||
# Dead zone 23:30-04:30 → next 04:30 is ~2h away
|
||||
assert 2 * 3600 - 60 < interval < 2 * 3600 + 60
|
||||
|
||||
|
||||
def test_next_interval_holiday_skip():
|
||||
# 2026-05-05 어린이날 (Tuesday holiday)
|
||||
now = _kst(2026, 5, 5, 10, 0)
|
||||
assert _is_market_day(now) is False
|
||||
interval = _next_interval(now)
|
||||
# Next: 2026-05-06 (Wed) 07:00, ~21h away
|
||||
assert 20 * 3600 < interval < 22 * 3600
|
||||
|
||||
|
||||
def test_next_interval_at_market_open_boundary():
|
||||
"""09:00:00 정확 second → 60초 (market 구간 진입)."""
|
||||
now = _kst(2026, 5, 18, 9, 0) # Monday 09:00:00
|
||||
assert _next_interval(now) == 60
|
||||
|
||||
|
||||
def test_next_interval_at_market_close_boundary():
|
||||
"""15:30:00 정확 second → 300초 (post-market 구간 진입)."""
|
||||
now = _kst(2026, 5, 18, 15, 30) # Monday 15:30:00
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_at_polling_window_end_boundary():
|
||||
"""23:30:00 정확 second → dead zone skip (다음 04:30 까지)."""
|
||||
now = _kst(2026, 5, 18, 23, 30) # Monday 23:30:00 (NXT_PRE_END boundary)
|
||||
interval = _next_interval(now)
|
||||
# Dead zone 23:30-04:30 → next 04:30 is ~5h away
|
||||
assert 5 * 3600 - 60 < interval < 5 * 3600 + 60
|
||||
|
||||
|
||||
def test_next_interval_nxt_evening_5min():
|
||||
"""22:00 평일 (NXT 야간) → 300 (5분)."""
|
||||
now = _kst(2026, 5, 18, 22, 0)
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_nxt_dawn_5min():
|
||||
"""05:30 평일 (NXT 새벽) → 300 (5분)."""
|
||||
now = _kst(2026, 5, 18, 5, 30)
|
||||
assert _next_interval(now) == 300
|
||||
|
||||
|
||||
def test_next_interval_dead_zone_skip():
|
||||
"""02:00 평일 (dead zone 23:30-04:30) → 다음 04:30 까지 (~9000s)."""
|
||||
now = _kst(2026, 5, 18, 2, 0)
|
||||
interval = _next_interval(now)
|
||||
# 02:00 → 04:30 = 2.5h = 9000s
|
||||
assert 9000 - 60 < interval < 9000 + 60
|
||||
172
ai_trade/tests/test_signal_generator.py
Normal file
172
ai_trade/tests/test_signal_generator.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Tests for signal_generator."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from ai_trade.signal_generator import generate_signals
|
||||
from ai_trade.state import PollState
|
||||
|
||||
|
||||
def _settings(**overrides):
|
||||
"""Build a Settings-like object for tests (avoid env)."""
|
||||
defaults = dict(
|
||||
stop_loss_pct=-0.07,
|
||||
take_profit_pct=0.15,
|
||||
chronos_spread_threshold=0.6,
|
||||
asking_bid_ratio_threshold=0.6,
|
||||
confidence_threshold=0.7,
|
||||
min_momentum_for_buy="strong_up",
|
||||
)
|
||||
defaults.update(overrides)
|
||||
m = MagicMock()
|
||||
for k, v in defaults.items():
|
||||
setattr(m, k, v)
|
||||
return m
|
||||
|
||||
|
||||
def _make_state_with_buy_candidate(
|
||||
ticker="005930", name="삼성전자",
|
||||
chronos_median=0.02, chronos_q10=-0.01, chronos_q90=0.04, chronos_conf=0.85,
|
||||
momentum="strong_up", bid_ratio=0.7, current_price=78500,
|
||||
):
|
||||
state = PollState()
|
||||
state.screener_preview = {"items": [{"ticker": ticker, "name": name}]}
|
||||
state.chronos_predictions[ticker] = {
|
||||
"median": chronos_median, "q10": chronos_q10, "q90": chronos_q90,
|
||||
"conf": chronos_conf, "as_of": "2026-05-17T16:00:00+09:00",
|
||||
}
|
||||
state.minute_momentum[ticker] = momentum
|
||||
state.asking_price[ticker] = {
|
||||
"bid_total": int(bid_ratio * 1000),
|
||||
"ask_total": int((1 - bid_ratio) * 1000),
|
||||
"bid_ratio": bid_ratio,
|
||||
"current_price": current_price,
|
||||
"as_of": "2026-05-17T16:00:01+09:00",
|
||||
}
|
||||
return state
|
||||
|
||||
|
||||
def _make_state_with_holding(
|
||||
ticker="005930", name="삼성전자",
|
||||
pnl_pct=0.0, avg_price=75000, current_price=75000,
|
||||
):
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{
|
||||
"ticker": ticker, "name": name,
|
||||
"avg_price": avg_price, "current_price": current_price,
|
||||
"pnl_pct": pnl_pct, "profit_rate": pnl_pct * 100,
|
||||
"quantity": 100, "broker": "키움",
|
||||
}]}
|
||||
state.screener_preview = {"items": []}
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dedup_mock():
|
||||
d = MagicMock()
|
||||
d.is_recent.return_value = False
|
||||
return d
|
||||
|
||||
|
||||
def test_buy_signal_when_all_conditions_pass_and_confidence_high(dedup_mock):
|
||||
state = _make_state_with_buy_candidate()
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "buy"
|
||||
assert sig["confidence_webai"] > 0.7
|
||||
dedup_mock.record.assert_called()
|
||||
|
||||
|
||||
def test_silent_when_chronos_median_negative(dedup_mock):
|
||||
state = _make_state_with_buy_candidate(chronos_median=-0.01)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_silent_when_distribution_spread_too_wide(dedup_mock):
|
||||
# spread = q90 - q10 = 0.5 - (-0.5) = 1.0 > 0.6 → hard gate fails
|
||||
state = _make_state_with_buy_candidate(
|
||||
chronos_median=0.001, chronos_q10=-0.5, chronos_q90=0.5,
|
||||
)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_silent_when_momentum_not_strong_up(dedup_mock):
|
||||
state = _make_state_with_buy_candidate(momentum="weak_up")
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_silent_when_bid_ratio_below_threshold(dedup_mock):
|
||||
state = _make_state_with_buy_candidate(bid_ratio=0.5)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_silent_when_confidence_below_threshold(dedup_mock):
|
||||
# chronos_conf low + rank=20 → confidence < 0.7
|
||||
state = _make_state_with_buy_candidate(chronos_conf=0.3)
|
||||
# add 19 fake items to push 005930 rank to 20
|
||||
state.screener_preview["items"] = (
|
||||
[{"ticker": f"FAKE{i:03d}"} for i in range(19)]
|
||||
+ [{"ticker": "005930", "name": "삼성전자"}]
|
||||
)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
# confidence_webai = 0.3*0.5 + 1.0*0.3 + 0.05*0.2 = 0.46 < 0.7
|
||||
assert "005930" not in state.signals
|
||||
|
||||
|
||||
def test_sell_signal_when_stop_loss_triggered(dedup_mock):
|
||||
state = _make_state_with_holding(pnl_pct=-0.08, current_price=69000, avg_price=75000)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "sell"
|
||||
assert sig["confidence_webai"] == 1.0
|
||||
assert sig["pnl_pct"] == -0.08
|
||||
|
||||
|
||||
def test_sell_signal_when_take_profit_triggered(dedup_mock):
|
||||
state = _make_state_with_holding(pnl_pct=0.16, current_price=87000, avg_price=75000)
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "sell"
|
||||
assert sig["confidence_webai"] == 0.6
|
||||
|
||||
|
||||
def test_silent_when_dedup_recently_sent(dedup_mock):
|
||||
state = _make_state_with_buy_candidate()
|
||||
dedup_mock.is_recent.return_value = True
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert "005930" not in state.signals
|
||||
dedup_mock.record.assert_not_called()
|
||||
|
||||
|
||||
def test_sell_signal_triggers_on_anomaly_path(dedup_mock):
|
||||
"""Anomaly sell: median < -1%, momentum strong_down, low bid_ratio, confidence > threshold."""
|
||||
state = PollState()
|
||||
state.portfolio = {"holdings": [{
|
||||
"ticker": "005930", "name": "삼성전자",
|
||||
"avg_price": 75000, "current_price": 70000,
|
||||
"pnl_pct": -0.067, # within stop_loss tolerance (default -0.07): NOT triggering stop_loss
|
||||
"quantity": 100, "broker": "키움",
|
||||
}]}
|
||||
state.screener_preview = {"items": []}
|
||||
state.chronos_predictions["005930"] = {
|
||||
"median": -0.025, "q10": -0.05, "q90": 0.005, "conf": 0.85,
|
||||
}
|
||||
state.minute_momentum["005930"] = "strong_down"
|
||||
state.asking_price["005930"] = {"current_price": 70000, "bid_ratio": 0.30}
|
||||
# bid_ratio 0.30 < (1 - 0.6) = 0.4 → anomaly bid_ratio gate passes
|
||||
# confidence = 0.85*0.5 + 1.0*0.3 + 1.0*0.2 = 0.425 + 0.3 + 0.2 = 0.925 > 0.7
|
||||
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "sell"
|
||||
assert sig["context"]["sell_reason"] == "anomaly"
|
||||
assert sig["confidence_webai"] > 0.7
|
||||
168
ai_trade/tests/test_stock_client.py
Normal file
168
ai_trade/tests/test_stock_client.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Tests for stock_client.StockClient."""
|
||||
import asyncio
|
||||
import logging
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
from ai_trade.stock_client import StockClient
|
||||
|
||||
|
||||
BASE_URL = "https://test.stock.local"
|
||||
API_KEY = "test-secret"
|
||||
|
||||
|
||||
async def test_get_portfolio_normal_returns_dict_with_pnl_pct(mock_stock_api):
|
||||
"""정상 200 응답 + cache 저장."""
|
||||
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"holdings": [{"ticker": "005930", "pnl_pct": 0.047}],
|
||||
"cash": [],
|
||||
"summary": {},
|
||||
},
|
||||
)
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
result = await client.get_portfolio()
|
||||
assert result["holdings"][0]["pnl_pct"] == 0.047
|
||||
# Cache populated
|
||||
assert len(client._cache) >= 1
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_uses_cache_within_ttl(mock_stock_api):
|
||||
"""180s TTL 내 두번째 호출 = mock 콜 1회."""
|
||||
route = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
return_value=httpx.Response(
|
||||
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||
)
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
await client.get_portfolio()
|
||||
await client.get_portfolio() # second call within TTL
|
||||
assert route.call_count == 1
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_refetches_after_ttl_expiry(mock_stock_api, monkeypatch):
|
||||
"""TTL 만료 후 재호출 = mock 콜 2회. time.monotonic 모킹."""
|
||||
route = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
return_value=httpx.Response(
|
||||
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||
)
|
||||
)
|
||||
# Fake clock: starts at 0, jumps past portfolio TTL (180s) between calls
|
||||
fake_time = [0.0]
|
||||
monkeypatch.setattr(
|
||||
"ai_trade.stock_client.time.monotonic", lambda: fake_time[0]
|
||||
)
|
||||
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
await client.get_portfolio()
|
||||
fake_time[0] = 181.0 # 180s TTL 만료
|
||||
await client.get_portfolio()
|
||||
assert route.call_count == 2
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_retries_3_times_on_timeout(mock_stock_api, monkeypatch):
|
||||
"""timeout 2번 + 200 1번 → 최종 성공. exponential sleep 호출 검증."""
|
||||
sleep_calls = []
|
||||
|
||||
async def fake_sleep(s):
|
||||
sleep_calls.append(s)
|
||||
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
side_effect=[
|
||||
httpx.TimeoutException("timeout 1"),
|
||||
httpx.TimeoutException("timeout 2"),
|
||||
httpx.Response(
|
||||
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||
),
|
||||
]
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
result = await client.get_portfolio()
|
||||
assert result["holdings"] == []
|
||||
assert sleep_calls == [1, 2] # exponential 1s, 2s
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_429_triggers_backoff(mock_stock_api, monkeypatch):
|
||||
"""429 → 1s backoff → 200."""
|
||||
sleep_calls = []
|
||||
|
||||
async def fake_sleep(s):
|
||||
sleep_calls.append(s)
|
||||
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
side_effect=[
|
||||
httpx.Response(429, text="rate limit"),
|
||||
httpx.Response(
|
||||
200, json={"holdings": [], "cash": [], "summary": {}}
|
||||
),
|
||||
]
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
result = await client.get_portfolio()
|
||||
assert result["holdings"] == []
|
||||
assert sleep_calls == [1]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_get_portfolio_falls_back_to_stale_on_all_failures(
|
||||
mock_stock_api, monkeypatch, caplog
|
||||
):
|
||||
"""cache 에 이전 성공 응답 + 모든 retry 5xx → stale 반환 + logger.warning."""
|
||||
# No-op sleep for fast test
|
||||
async def fake_sleep(s):
|
||||
return None
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
# Patch time.monotonic BEFORE first call so cached timestamp uses fake clock
|
||||
fake_time = [0.0]
|
||||
monkeypatch.setattr(
|
||||
"ai_trade.stock_client.time.monotonic", lambda: fake_time[0]
|
||||
)
|
||||
|
||||
# First call succeeds
|
||||
route1 = mock_stock_api.get("/api/webai/portfolio").mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={"holdings": [{"ticker": "005930"}], "cash": [], "summary": {}},
|
||||
)
|
||||
)
|
||||
client = StockClient(BASE_URL, API_KEY)
|
||||
try:
|
||||
first = await client.get_portfolio()
|
||||
assert first["holdings"][0]["ticker"] == "005930"
|
||||
|
||||
# Advance fake clock past TTL (180s) so cache is stale
|
||||
fake_time[0] = 181.0
|
||||
|
||||
# Now mock to return 500s persistently
|
||||
route1.mock(return_value=httpx.Response(500, text="server error"))
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="ai_trade.stock_client"):
|
||||
result = await client.get_portfolio()
|
||||
assert result["holdings"][0]["ticker"] == "005930" # stale data returned
|
||||
assert any(
|
||||
"stale" in rec.message.lower() for rec in caplog.records
|
||||
)
|
||||
finally:
|
||||
await client.close()
|
||||
18
ai_trade/tests/test_stock_client_ttl.py
Normal file
18
ai_trade/tests/test_stock_client_ttl.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# tests/test_stock_client_ttl.py
|
||||
"""SP-A1 회귀 — _TTL이 NAS 부담 완화를 위한 값으로 설정되어 있어야 함."""
|
||||
from ai_trade.stock_client import _TTL
|
||||
|
||||
|
||||
def test_portfolio_ttl_is_180s():
|
||||
"""portfolio TTL은 180초 이상 (3분 폴링에서 1회 fetch가 3 폴링 커버)."""
|
||||
assert _TTL["portfolio"] >= 180.0
|
||||
|
||||
|
||||
def test_news_sentiment_ttl_is_600s():
|
||||
"""news-sentiment TTL은 600초 이상 (10분, 뉴스 sentiment는 자주 안 바뀜)."""
|
||||
assert _TTL["news-sentiment"] >= 600.0
|
||||
|
||||
|
||||
def test_screener_preview_ttl_is_300s():
|
||||
"""screener-preview TTL은 300초 이상 (5분, Top-20은 분 단위로 거의 안 바뀜)."""
|
||||
assert _TTL["screener-preview"] >= 300.0
|
||||
26
legacy/signal_v1/DEPRECATED.md
Normal file
26
legacy/signal_v1/DEPRECATED.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# signal_v1 — DEPRECATED
|
||||
|
||||
> **2026-05-19부터 사용 안 함.** 신규 작업 금지. 모든 트레이딩은 `web-ai/ai_trade/` (구 `signal_v2`) 에서 진행.
|
||||
|
||||
## 폐기 사유
|
||||
|
||||
- V2 (`ai_trade`) Phase 4 완료 — Chronos-2 zero-shot 1일 수익률 + 분봉 모멘텀 + 5-state classifier + sell-first 우선순위 + 매수 hard gate가 V1 (LSTM 7-features + Gemini Flash) 보다 정확도·확장성·해석가능성에서 우위
|
||||
- V1 + V2 동시 운영 시 KIS API rate limit 충돌
|
||||
- NAS 인바운드 polling 부담 (web-ai → NAS API) 의 50% 차지
|
||||
|
||||
## 향후 처리
|
||||
|
||||
- 디렉토리를 `legacy/signal_v1/`로 이동 예정 (현재 file lock 풀린 후 처리)
|
||||
- `start.bat` 진입점은 이미 `legacy/start_v1.bat`으로 이동 → 자동 시작 차단됨
|
||||
- DSM Scheduler 등 외부 trigger에 V1 startup 등록되어 있다면 해제 필요 (박재오 확인)
|
||||
|
||||
## 활용 (필요 시)
|
||||
|
||||
- 코드 참고용 (LSTM 모델 구조, Telegram bot 인터페이스, KIS 자동주문 패턴)
|
||||
- 별도 backtest 실행은 가능 (`backtest_runner.py`) — 단 운영 자동 실행 X
|
||||
|
||||
## 관련 문서
|
||||
|
||||
- 신 운영 가이드: `../ai_trade/`
|
||||
- web-ai 통합 가이드: `../CLAUDE.md`
|
||||
- V1 vs V2 진단: `../CHECK_POINT.md`
|
||||
13
requirements.txt
Normal file
13
requirements.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
# Signal V2 dependencies (added 2026-05-16, Phase 2)
|
||||
httpx>=0.27
|
||||
fastapi>=0.110
|
||||
uvicorn>=0.27
|
||||
python-dotenv>=1.0
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.23
|
||||
respx>=0.21
|
||||
websockets>=12
|
||||
# Phase 3b dependencies (Chronos-2 + ML)
|
||||
transformers>=4.40
|
||||
chronos-forecasting>=1.4
|
||||
# torch: typically already installed via V1 venv; if not, install with CUDA support manually
|
||||
124
services/docker-compose.yml
Normal file
124
services/docker-compose.yml
Normal file
@@ -0,0 +1,124 @@
|
||||
name: web-ai-services
|
||||
|
||||
services:
|
||||
insta-render:
|
||||
build:
|
||||
context: ./insta-render
|
||||
container_name: insta-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18710:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- NAS_BASE_URL=${NAS_BASE_URL:-http://192.168.45.54:18700}
|
||||
- INTERNAL_API_KEY=${INTERNAL_API_KEY:-}
|
||||
- INSTA_MEDIA_ROOT=${INSTA_MEDIA_ROOT:-/mnt/nas/webpage/data/insta}
|
||||
- INSTA_MEDIA_URL_PREFIX=${INSTA_MEDIA_URL_PREFIX:-/media/insta}
|
||||
- CARD_TEMPLATE_DIR=/app/templates
|
||||
volumes:
|
||||
- /mnt/nas/webpage/data/insta:/mnt/nas/webpage/data/insta
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
music-render:
|
||||
build:
|
||||
context: ./music-render
|
||||
container_name: music-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18711:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- NAS_BASE_URL=${NAS_BASE_URL:-http://192.168.45.54:18600}
|
||||
- INTERNAL_API_KEY=${INTERNAL_API_KEY:-}
|
||||
- SUNO_API_KEY=${SUNO_API_KEY:-}
|
||||
- MUSIC_AI_SERVER_URL=${MUSIC_AI_SERVER_URL:-http://host.docker.internal:8765}
|
||||
- MUSIC_MEDIA_ROOT=${MUSIC_MEDIA_ROOT:-/mnt/nas/webpage/data/music}
|
||||
- MUSIC_MEDIA_URL_PREFIX=${MUSIC_MEDIA_URL_PREFIX:-/media/music}
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- /mnt/nas/webpage/data/music:/mnt/nas/webpage/data/music
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
video-render:
|
||||
build:
|
||||
context: ./video-render
|
||||
container_name: video-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18712:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- NAS_BASE_URL=${NAS_BASE_URL:-http://192.168.45.54:18801}
|
||||
- INTERNAL_API_KEY=${INTERNAL_API_KEY:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
|
||||
- KLING_ACCESS_KEY=${KLING_ACCESS_KEY:-}
|
||||
- KLING_SECRET_KEY=${KLING_SECRET_KEY:-}
|
||||
- SEEDANCE_API_KEY=${SEEDANCE_API_KEY:-}
|
||||
- VIDEO_MEDIA_ROOT=${VIDEO_MEDIA_ROOT:-/mnt/nas/webpage/data/video}
|
||||
- VIDEO_MEDIA_URL_PREFIX=${VIDEO_MEDIA_URL_PREFIX:-/media/video}
|
||||
volumes:
|
||||
- /mnt/nas/webpage/data/video:/mnt/nas/webpage/data/video
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
task-watcher:
|
||||
build:
|
||||
context: ./task-watcher
|
||||
container_name: task-watcher
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18713:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- STOCK_BASE_URL=${STOCK_BASE_URL:-http://192.168.45.54:18500}
|
||||
- TRADING_START=${TRADING_START:-07:00}
|
||||
- TRADING_END=${TRADING_END:-16:30}
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
image-render:
|
||||
build:
|
||||
context: ./image-render
|
||||
container_name: image-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "18714:8000"
|
||||
environment:
|
||||
- TZ=Asia/Seoul
|
||||
- REDIS_URL=${REDIS_URL:-redis://192.168.45.54:6379}
|
||||
- NAS_BASE_URL=${NAS_BASE_URL:-http://192.168.45.54:18802}
|
||||
- INTERNAL_API_KEY=${INTERNAL_API_KEY:-}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:-}
|
||||
- GEMINI_API_KEY=${GEMINI_API_KEY:-}
|
||||
- COMFYUI_URL=${COMFYUI_URL:-http://host.docker.internal:8188}
|
||||
- FLUX_BLOCK_TRADING_HOURS=${FLUX_BLOCK_TRADING_HOURS:-1}
|
||||
- IMAGE_MEDIA_ROOT=${IMAGE_MEDIA_ROOT:-/mnt/nas/webpage/data/image}
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- /mnt/nas/webpage/data/image:/mnt/nas/webpage/data/image
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 60s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
16
services/image-render/Dockerfile
Normal file
16
services/image-render/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
||||
FROM python:3.12-slim-bookworm
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
18
services/image-render/env.example
Normal file
18
services/image-render/env.example
Normal file
@@ -0,0 +1,18 @@
|
||||
# Redis (NAS)
|
||||
REDIS_URL=redis://192.168.45.54:6379
|
||||
|
||||
# NAS image-lab webhook
|
||||
NAS_BASE_URL=http://192.168.45.54:18802
|
||||
INTERNAL_API_KEY=replace-me
|
||||
|
||||
# API provider keys (worker reports failed if missing)
|
||||
OPENAI_API_KEY=
|
||||
GEMINI_API_KEY=
|
||||
# Seedance key not used by image-render
|
||||
|
||||
# FLUX local
|
||||
COMFYUI_URL=http://host.docker.internal:8188
|
||||
FLUX_BLOCK_TRADING_HOURS=1
|
||||
|
||||
# NAS SMB mount target (image-render writes to this, NAS reads via /media/image/)
|
||||
IMAGE_MEDIA_ROOT=/mnt/nas/webpage/data/image
|
||||
36
services/image-render/main.py
Normal file
36
services/image-render/main.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""image-render FastAPI entry — health + lifespan (worker loop spawn)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
import worker
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
worker_task = asyncio.create_task(worker.worker_loop())
|
||||
logger.info("image-render lifespan 시작")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
worker_task.cancel()
|
||||
try:
|
||||
await worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("image-render lifespan 종료")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"ok": True, "service": "image-render"}
|
||||
54
services/image-render/nas_client.py
Normal file
54
services/image-render/nas_client.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""NAS webhook 어댑터 — Windows worker → NAS image-lab HTTP 위임.
|
||||
|
||||
video-render nas_client 복제 (call-time os.getenv으로 테스트 격리).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TIMEOUT = 10.0
|
||||
|
||||
|
||||
def _post(payload: Dict[str, Any]) -> None:
|
||||
nas_base_url = os.getenv("NAS_BASE_URL", "http://192.168.45.54:18802")
|
||||
internal_api_key = os.getenv("INTERNAL_API_KEY", "")
|
||||
url = f"{nas_base_url}/api/internal/image/update"
|
||||
try:
|
||||
r = httpx.post(
|
||||
url,
|
||||
headers={"X-Internal-Key": internal_api_key},
|
||||
json=payload,
|
||||
timeout=_TIMEOUT,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
logger.error("webhook %s returned %d: %s",
|
||||
payload.get("task_id"), r.status_code, r.text[:200])
|
||||
except Exception:
|
||||
logger.exception("webhook %s 호출 실패", payload.get("task_id"))
|
||||
|
||||
|
||||
def webhook_update_task(
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: int,
|
||||
message: str = "",
|
||||
image_url: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
payload: Dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
}
|
||||
if image_url is not None:
|
||||
payload["image_url"] = image_url
|
||||
if error is not None:
|
||||
payload["error"] = error
|
||||
_post(payload)
|
||||
0
services/image-render/providers/__init__.py
Normal file
0
services/image-render/providers/__init__.py
Normal file
18
services/image-render/providers/_media.py
Normal file
18
services/image-render/providers/_media.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""b64 이미지 → NAS SMB 경로 저장 → /media/image URL 반환."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
|
||||
IMAGE_MEDIA_ROOT = os.getenv("IMAGE_MEDIA_ROOT", "/mnt/nas/webpage/data/image")
|
||||
IMAGE_MEDIA_URL_PREFIX = os.getenv("IMAGE_MEDIA_URL_PREFIX", "/media/image")
|
||||
|
||||
|
||||
def save_b64_png(task_id: str, b64_data: str) -> str:
|
||||
os.makedirs(IMAGE_MEDIA_ROOT, exist_ok=True)
|
||||
fname = f"{task_id}-{uuid.uuid4().hex[:8]}.png"
|
||||
path = os.path.join(IMAGE_MEDIA_ROOT, fname)
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(b64_data))
|
||||
return f"{IMAGE_MEDIA_URL_PREFIX}/{fname}"
|
||||
79
services/image-render/providers/flux.py
Normal file
79
services/image-render/providers/flux.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""FLUX 로컬 — ComfyUI HTTP API.
|
||||
|
||||
POST {COMFYUI_URL}/prompt (workflow JSON) → prompt_id
|
||||
GET {COMFYUI_URL}/history/{prompt_id} → outputs → image filename
|
||||
GET {COMFYUI_URL}/view?filename=... → PNG bytes → b64
|
||||
|
||||
워크플로우 JSON은 `flux_workflow.json` (ComfyUI UI에서 "Save (API Format)"로 export, CLIPTextEncode 노드 text를 "%PROMPT%"로 수동 치환). 박재오 산출물.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64, json, logging, os, time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
COMFYUI_URL = os.getenv("COMFYUI_URL", "http://127.0.0.1:8188")
|
||||
WORKFLOW_PATH = os.path.join(os.path.dirname(__file__), "flux_workflow.json")
|
||||
POLL_INTERVAL = 2
|
||||
POLL_MAX = 120
|
||||
|
||||
|
||||
def _is_trading_hours() -> bool:
|
||||
kst = timezone(timedelta(hours=9))
|
||||
now = datetime.now(kst)
|
||||
if now.weekday() >= 5:
|
||||
return False
|
||||
return (now.hour, now.minute) >= (9, 0) and (now.hour, now.minute) <= (15, 30)
|
||||
|
||||
|
||||
def _load_workflow(prompt: str, size: str) -> dict:
|
||||
with open(WORKFLOW_PATH, encoding="utf-8") as f:
|
||||
wf = json.load(f)
|
||||
# CLIPTextEncode 노드의 text를 prompt로 치환 (workflow에 "%PROMPT%" placeholder 사용)
|
||||
raw = json.dumps(wf).replace("%PROMPT%", prompt.replace('"', "'"))
|
||||
return json.loads(raw)
|
||||
|
||||
|
||||
def _submit_prompt(workflow: dict) -> str:
|
||||
r = requests.post(f"{COMFYUI_URL}/prompt", json={"prompt": workflow}, timeout=30)
|
||||
r.raise_for_status()
|
||||
return r.json()["prompt_id"]
|
||||
|
||||
|
||||
def _poll_image_b64(prompt_id: str):
|
||||
for _ in range(POLL_MAX):
|
||||
h = requests.get(f"{COMFYUI_URL}/history/{prompt_id}", timeout=10)
|
||||
data = h.json().get(prompt_id)
|
||||
if data and data.get("outputs"):
|
||||
for node_out in data["outputs"].values():
|
||||
for img in node_out.get("images", []):
|
||||
view = requests.get(f"{COMFYUI_URL}/view",
|
||||
params={"filename": img["filename"], "subfolder": img.get("subfolder", ""), "type": img.get("type", "output")},
|
||||
timeout=30)
|
||||
view.raise_for_status()
|
||||
return base64.b64encode(view.content).decode()
|
||||
time.sleep(POLL_INTERVAL)
|
||||
return None
|
||||
|
||||
|
||||
def run_flux_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if os.getenv("FLUX_BLOCK_TRADING_HOURS") == "1" and _is_trading_hours():
|
||||
webhook_update_task(task_id, "failed", 0, "", error="장중 GPU 보호 — FLUX 거부 (API provider 사용 권장)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "FLUX (ComfyUI) 생성 중...")
|
||||
wf = _load_workflow(params["prompt"], params.get("size") or "1024x1024")
|
||||
pid = _submit_prompt(wf)
|
||||
b64 = _poll_image_b64(pid)
|
||||
if not b64:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="ComfyUI 타임아웃 또는 출력 없음")
|
||||
return
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("flux task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
47
services/image-render/providers/gpt_image.py
Normal file
47
services/image-render/providers/gpt_image.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""GPT Image 2.0 — OpenAI Images API.
|
||||
|
||||
POST https://api.openai.com/v1/images/generations
|
||||
body {model:"gpt-image-1", prompt, size, n:1} → data[0].b64_json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
OPENAI_URL = "https://api.openai.com/v1/images/generations"
|
||||
DEFAULT_MODEL = "gpt-image-1"
|
||||
|
||||
|
||||
def run_gpt_image_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="OPENAI_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "GPT Image 호출 중...")
|
||||
body = {
|
||||
"model": params.get("model") or DEFAULT_MODEL,
|
||||
"prompt": params["prompt"],
|
||||
"size": params.get("size") or "1024x1024",
|
||||
"n": 1,
|
||||
}
|
||||
resp = requests.post(
|
||||
OPENAI_URL,
|
||||
headers={"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}", "Content-Type": "application/json"},
|
||||
json=body,
|
||||
timeout=120,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"OpenAI {resp.status_code}: {resp.text[:200]}")
|
||||
return
|
||||
b64 = resp.json()["data"][0]["b64_json"]
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("gpt_image task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
52
services/image-render/providers/nano_banana.py
Normal file
52
services/image-render/providers/nano_banana.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Nano Banana — Gemini 2.5 Flash Image (generativelanguage API).
|
||||
|
||||
POST /v1beta/models/{MODEL}:generateContent
|
||||
→ candidates[0].content.parts[*].inlineData.data (b64 png)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging, os
|
||||
import requests
|
||||
|
||||
from nas_client import webhook_update_task
|
||||
from providers._media import save_b64_png
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta"
|
||||
DEFAULT_MODEL = "gemini-2.5-flash-image"
|
||||
|
||||
|
||||
def _extract_b64(data: dict):
|
||||
for cand in data.get("candidates", []):
|
||||
for part in cand.get("content", {}).get("parts", []):
|
||||
inline = part.get("inlineData") or part.get("inline_data")
|
||||
if inline and inline.get("data"):
|
||||
return inline["data"]
|
||||
return None
|
||||
|
||||
|
||||
def run_nano_banana_generation(task_id: str, params: dict) -> None:
|
||||
try:
|
||||
if not os.getenv("GEMINI_API_KEY"):
|
||||
webhook_update_task(task_id, "failed", 0, "", error="GEMINI_API_KEY 미설정 (Windows .env)")
|
||||
return
|
||||
webhook_update_task(task_id, "processing", 10, "Nano Banana (Gemini) 호출 중...")
|
||||
model_id = params.get("model") or DEFAULT_MODEL
|
||||
body = {"contents": [{"parts": [{"text": params["prompt"]}]}]}
|
||||
resp = requests.post(
|
||||
f"{GEMINI_BASE}/models/{model_id}:generateContent",
|
||||
headers={"x-goog-api-key": os.getenv("GEMINI_API_KEY"), "Content-Type": "application/json"},
|
||||
json=body, timeout=120,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
webhook_update_task(task_id, "failed", 0, "", error=f"Gemini {resp.status_code}: {resp.text[:200]}")
|
||||
return
|
||||
b64 = _extract_b64(resp.json())
|
||||
if not b64:
|
||||
webhook_update_task(task_id, "failed", 0, "", error="Gemini 응답에 이미지 없음")
|
||||
return
|
||||
url = save_b64_png(task_id, b64)
|
||||
webhook_update_task(task_id, "succeeded", 100, "완료", image_url=url)
|
||||
except Exception as e:
|
||||
logger.exception("nano_banana task=%s 실패", task_id)
|
||||
webhook_update_task(task_id, "failed", 0, "", error=str(e))
|
||||
9
services/image-render/requirements.txt
Normal file
9
services/image-render/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
fastapi==0.115.6
|
||||
uvicorn[standard]==0.34.0
|
||||
requests==2.32.3
|
||||
redis>=5.0
|
||||
httpx>=0.27
|
||||
openai>=1.50.0
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.24
|
||||
respx>=0.21
|
||||
0
services/image-render/tests/__init__.py
Normal file
0
services/image-render/tests/__init__.py
Normal file
21
services/image-render/tests/test_flux.py
Normal file
21
services/image-render/tests/test_flux.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import providers.flux as fx
|
||||
|
||||
def test_blocked_during_trading_hours(monkeypatch):
|
||||
monkeypatch.setenv("FLUX_BLOCK_TRADING_HOURS", "1")
|
||||
monkeypatch.setattr(fx, "_is_trading_hours", lambda: True)
|
||||
calls = []
|
||||
monkeypatch.setattr(fx, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
fx.run_flux_generation("t1", {"prompt": "a cat"})
|
||||
assert calls[-1][0][1] == "failed"
|
||||
assert "장중" in calls[-1][1]["error"]
|
||||
|
||||
def test_success_polls_history_and_saves(monkeypatch):
|
||||
monkeypatch.setattr(fx, "_is_trading_hours", lambda: False)
|
||||
calls = []
|
||||
monkeypatch.setattr(fx, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
monkeypatch.setattr(fx, "_load_workflow", lambda prompt, size: {"3": {}})
|
||||
monkeypatch.setattr(fx, "_submit_prompt", lambda wf: "pid-1")
|
||||
monkeypatch.setattr(fx, "_poll_image_b64", lambda pid: "ZmFrZQ==")
|
||||
monkeypatch.setattr(fx, "save_b64_png", lambda tid, b64: "/media/image/t1.png")
|
||||
fx.run_flux_generation("t1", {"prompt": "a cat"})
|
||||
assert [c for c in calls if c[0][1] == "succeeded"]
|
||||
32
services/image-render/tests/test_gpt_image.py
Normal file
32
services/image-render/tests/test_gpt_image.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import providers.gpt_image as gi
|
||||
|
||||
|
||||
def test_missing_key_reports_failed(monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
calls = []
|
||||
monkeypatch.setattr(gi, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
gi.run_gpt_image_generation("t1", {"prompt": "a cat"})
|
||||
# 마지막 호출이 failed
|
||||
assert calls[-1][0][1] == "failed"
|
||||
|
||||
|
||||
def test_success_saves_and_reports_url(monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||
calls = []
|
||||
monkeypatch.setattr(gi, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
monkeypatch.setattr(gi, "save_b64_png", lambda tid, b64: "/media/image/t1.png")
|
||||
|
||||
class FakeResp:
|
||||
status_code = 200
|
||||
|
||||
def json(self):
|
||||
return {"data": [{"b64_json": "ZmFrZQ=="}]}
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(gi.requests, "post", lambda *a, **k: FakeResp())
|
||||
|
||||
gi.run_gpt_image_generation("t1", {"prompt": "a cat"})
|
||||
succeeded = [c for c in calls if c[0][1] == "succeeded"]
|
||||
assert succeeded and succeeded[-1][1]["image_url"] == "/media/image/t1.png"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user