Compare commits
14 Commits
cb70226f42
...
d2f7030446
| Author | SHA1 | Date | |
|---|---|---|---|
| d2f7030446 | |||
| 43ee610780 | |||
| f79c5c26df | |||
| 7108e5e4f5 | |||
| 1e6638a64b | |||
| 32308bede6 | |||
| ac6409605c | |||
| e4d02b8059 | |||
| 94a034ef38 | |||
| 2a11d05f4a | |||
| c2e77a7310 | |||
| bea27a75cf | |||
| 39adfc5fc5 | |||
| 1a848faac4 |
211
README.md
Normal file
211
README.md
Normal file
@@ -0,0 +1,211 @@
|
||||
# web-ai
|
||||
|
||||
Windows AI 머신(AMD 9800X3D + RTX 5070 Ti 16GB)에서 동작하는 두 영역의 서비스:
|
||||
|
||||
1. **ai_trade** — Confidence Signal Pipeline V2. NAS stock 백엔드와 KIS Open API를 결합해 매수/매도 신호를 생성하는 FastAPI 워커.
|
||||
2. **services** — NAS↔Windows 분산 렌더링 워커(인스타 카드 / 음악 / 영상 / 이미지) + task-watcher.
|
||||
|
||||
상위 워크스페이스 컨텍스트는 `../CLAUDE.md`, 본 디렉토리 상세는 `CLAUDE.md`, 운영 체크포인트는 `CHECK_POINT.md` 참조.
|
||||
|
||||
---
|
||||
|
||||
## 디렉토리 구조
|
||||
|
||||
| 경로 | 역할 | 포트 |
|
||||
|------|------|------|
|
||||
| `ai_trade/` | 자동매매 메인. Chronos-bolt(또는 Chronos-2) + 분봉 모멘텀 + KIS WebSocket 호가 + 매수/매도 신호 생성기. | `:8001` |
|
||||
| `services/_shared/` | 4개 render worker 공통 모듈 (`ReliableQueue` — BLMOVE + ack/fail + recovery). | — |
|
||||
| `services/insta-render/` | Instagram 카드 Playwright 렌더 워커. NAS Redis `queue:insta-render` 소비. | `:18710` |
|
||||
| `services/music-render/` | Suno + MusicGen 음악 생성 워커. `queue:music-render` 소비. | `:18711` |
|
||||
| `services/video-render/` | sora / veo / kling / seedance 4 provider 영상 생성 게이트웨이. `queue:video-render` 소비. | `:18712` |
|
||||
| `services/image-render/` | gpt_image / nano_banana / flux(ComfyUI 로컬) 3 provider. `queue:image-render` 소비. | `:18714` |
|
||||
| `services/task-watcher/` | 박재오 작업 시간대에 `queue:paused` 토글 → 워커 일시 정지. | `:18713` |
|
||||
| `legacy/signal_v1/` | ⚠ **DEPRECATED** (2026-05-19). LSTM 봇. 자동 실행 차단됨. | OFF |
|
||||
|
||||
---
|
||||
|
||||
## ai_trade — Confidence Signal Pipeline V2
|
||||
|
||||
NAS stock 백엔드(`:18500`)에서 portfolio / news_sentiment / screener를 pull하고, KIS REST/WebSocket으로 분봉·호가를 보강한 뒤 Chronos 예측과 5분봉 모멘텀 분류로 매수/매도 신호를 생성한다.
|
||||
|
||||
### 매수 (screener Top-N + portfolio)
|
||||
|
||||
모두 충족 시 confidence 계산 → threshold 초과 시 emit:
|
||||
|
||||
1. `chronos.median > 0`
|
||||
2. `chronos.q90 - chronos.q10 < 0.6` (absolute spread)
|
||||
3. `minute_momentum == strong_up`
|
||||
4. `asking_price.bid_ratio >= 0.6`
|
||||
|
||||
종합 confidence = `chronos_conf * 0.5 + minute_score * 0.3 + screener_norm * 0.2`. `> 0.7` 시 emit.
|
||||
|
||||
### 매도 (portfolio only, 우선순위 stop_loss → anomaly → take_profit)
|
||||
|
||||
- **stop_loss**: `pnl_pct < -7%` 즉시 (confidence=1.0)
|
||||
- **anomaly**: `chronos.median < -1%` + `strong_down` + `bid_ratio < 0.4` + 종합 conf > 0.7
|
||||
- **take_profit**: `pnl_pct > 15%` 검토 (confidence=0.6)
|
||||
|
||||
### 핵심 파일
|
||||
|
||||
| 파일 | 책임 |
|
||||
|------|------|
|
||||
| `main.py` | FastAPI app + lifespan (의존성 wiring) + poll_loop task 생성 |
|
||||
| `config.py` | `Settings` dataclass — 환경변수 로드 |
|
||||
| `state.py` | `PollState` (process-wide singleton) — portfolio·screener·signals 등 + `get_active_signals` / `purge_expired_signals` |
|
||||
| `stock_client.py` | NAS stock 백엔드 pull (X-WebAI-Key + 메모리 캐시) |
|
||||
| `kis_client.py` | KIS REST 분봉/호가 + asyncio.Lock 직렬화 + 지수 backoff |
|
||||
| `kis_websocket.py` | KIS WebSocket 호가 + approval_key + 재연결 |
|
||||
| `chronos_predictor.py` | HuggingFace Chronos zero-shot 분위수 예측 (FP32 강제) |
|
||||
| `minute_momentum.py` | 5분봉 → strong_up / weak_up / neutral / weak_down / strong_down |
|
||||
| `signal_generator.py` | 매수/매도 룰 엔진. cycle_id + expires_at 부착 |
|
||||
| `pull_worker.py` | asyncio cron — 시간대별 분기 + post-close 트리거 + signal 생성 + expired purge |
|
||||
| `scheduler.py` | 폴링 윈도우 판정 (KST 캘린더 + 휴장일) |
|
||||
| `rate_limit.py` | 초당 N회 token bucket + `SignalDedup` SQLite WAL |
|
||||
|
||||
### 시작
|
||||
|
||||
```bat
|
||||
cd ai_trade
|
||||
start.bat
|
||||
```
|
||||
|
||||
→ `Uvicorn running on http://0.0.0.0:8001`, `poll_loop started`.
|
||||
|
||||
휴장일/장 외 시간엔 poll_loop만 idle.
|
||||
|
||||
### 헬스 / 로그
|
||||
|
||||
```powershell
|
||||
curl http://localhost:8001/health
|
||||
Get-Content logs\ai_trade.log -Wait
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## services — NAS↔Windows 분산 워커
|
||||
|
||||
NAS측 lab 서비스(insta-lab / music-lab / video-lab / image-render NAS측)가 `queue:<worker>-render` 에 LPUSH로 작업을 enqueue. Windows worker가 BLMOVE로 atomic dequeue 후 처리, 완료 시 NAS internal webhook으로 결과 통지.
|
||||
|
||||
### 신뢰성 패턴 (`_shared.ReliableQueue`)
|
||||
|
||||
- **dequeue**: `BLMOVE main → processing:<queue>:<worker_id>` (atomic).
|
||||
- **ack**: `LREM processing 1 raw` (성공).
|
||||
- **fail**: `LREM processing` → `attempts++` 후 main 재큐 또는 `max_attempts` 도달 시 `dead_letter:<queue>` 이동.
|
||||
- **recover**: startup 시 자신의 processing list orphan을 main queue로 (attempts 증가).
|
||||
|
||||
### 시작 (NAS, WSL2 Docker)
|
||||
|
||||
```bash
|
||||
cd services
|
||||
docker compose up -d insta-render music-render video-render image-render task-watcher
|
||||
```
|
||||
|
||||
build context는 `services/` 루트. 각 Dockerfile은 `_shared` 모듈을 함께 COPY하고 `PYTHONPATH=/app`.
|
||||
|
||||
### 운영 조작
|
||||
|
||||
```bash
|
||||
# 워커 일시 정지 / 재개
|
||||
redis-cli -h 192.168.45.54 SET queue:paused 1
|
||||
redis-cli -h 192.168.45.54 DEL queue:paused
|
||||
|
||||
# 큐 / dead-letter 점검
|
||||
redis-cli -h 192.168.45.54 LLEN queue:insta-render
|
||||
redis-cli -h 192.168.45.54 LLEN dead_letter:queue:insta-render
|
||||
redis-cli -h 192.168.45.54 KEYS 'processing:*'
|
||||
```
|
||||
|
||||
### 환경 변수
|
||||
|
||||
| 변수 | 용도 |
|
||||
|------|------|
|
||||
| `REDIS_URL` | NAS Redis (`redis://192.168.45.54:6379`) |
|
||||
| `NAS_BASE_URL` | NAS 대상 서비스 URL (insta-lab `:18700`, music-lab `:18600`, video-lab `:18801`, image-render NAS측 `:18802`) |
|
||||
| `INTERNAL_API_KEY` | NAS internal webhook 인증 |
|
||||
| `WORKER_ID` | (권장) `<service>-prod-1` 등 영속 ID. hostname 기반 default는 컨테이너 재기동 시 바뀌어 orphan 추적 불가 |
|
||||
| `OPENAI_API_KEY` / `GEMINI_API_KEY` / `KLING_*` / `SEEDANCE_API_KEY` / `SUNO_API_KEY` | 각 provider 인증 |
|
||||
| `COMFYUI_URL` | image-render FLUX 로컬 ComfyUI (`http://host.docker.internal:8188`) |
|
||||
| `FLUX_BLOCK_TRADING_HOURS` | `1` 이면 장중(09:00~15:30) FLUX 차단 (Chronos GPU 보호) |
|
||||
|
||||
---
|
||||
|
||||
## 환경 변수 (ai_trade)
|
||||
|
||||
| 변수 | 기본 | 설명 |
|
||||
|------|------|------|
|
||||
| `STOCK_API_URL` | (필수) | NAS stock 백엔드 base URL |
|
||||
| `WEBAI_API_KEY` | (필수) | stock 백엔드 호출 시 X-WebAI-Key |
|
||||
| `SIGNAL_V2_PORT` | `8001` | uvicorn 포트 |
|
||||
| `KIS_ENV_TYPE` | `virtual` | `virtual` / `real` |
|
||||
| `KIS_REAL_APP_KEY` / `KIS_REAL_APP_SECRET` / `KIS_REAL_ACCOUNT` | — | KIS 실계좌 |
|
||||
| `KIS_VIRTUAL_APP_KEY` / `KIS_VIRTUAL_APP_SECRET` / `KIS_VIRTUAL_ACCOUNT` | — | KIS 모의계좌 |
|
||||
| `V1_TOKEN_PATH` | `legacy/signal_v1/data/kis_token.json` | KIS 토큰 파일 (V1 토큰 read-only 공유) |
|
||||
| `CHRONOS_MODEL` | `amazon/chronos-2` | Chronos 모델 ID |
|
||||
| `STOP_LOSS_PCT` | `-0.07` | 손절 임계 |
|
||||
| `TAKE_PROFIT_PCT` | `0.15` | 익절 임계 |
|
||||
| `CHRONOS_SPREAD_THRESHOLD` | `0.6` | 매수 hard gate spread 상한 |
|
||||
| `ASKING_BID_RATIO_THRESHOLD` | `0.6` | 매수 hard gate 호가 비율 |
|
||||
| `CONFIDENCE_THRESHOLD` | `0.7` | 매수 종합 confidence 하한 |
|
||||
| `MIN_MOMENTUM_FOR_BUY` | `strong_up` | 매수 hard gate 모멘텀 단계 |
|
||||
| `SIGNAL_TTL_SECONDS` | `300` | emit signal expires_at TTL |
|
||||
|
||||
`.env` 는 web-ai 루트 (이 디렉토리)에 둔다. **절대 커밋 금지.**
|
||||
|
||||
---
|
||||
|
||||
## 테스트
|
||||
|
||||
```bash
|
||||
# ai_trade
|
||||
python -m pytest ai_trade/tests -q
|
||||
|
||||
# services/_shared 공통 모듈
|
||||
cd services/_shared && python -m pytest tests/ -q
|
||||
|
||||
# 각 worker
|
||||
cd services/insta-render && python -m pytest tests/ -q
|
||||
cd services/music-render && python -m pytest tests/ -q
|
||||
cd services/video-render && python -m pytest tests/ -q
|
||||
cd services/image-render && python -m pytest tests/ -q
|
||||
```
|
||||
|
||||
**`.venv` 한글 사용자 경로 깨짐**으로 시스템 Python(`C:\Users\jaeoh\AppData\Local\Programs\Python\Python312\python.exe`) 사용 권장. 또는 `py -3.12 -m pytest …`.
|
||||
|
||||
---
|
||||
|
||||
## 알려진 함정
|
||||
|
||||
1. **KIS rate limit (EGW00201)** — V1+V2 동시 실행 시 충돌. V1은 `legacy/`로 격리. ai_trade는 `asyncio.Lock`으로 throttle 직렬화 (`kis_client.py`).
|
||||
2. **`.venv` 한글 경로** — 시스템 Python 사용.
|
||||
3. **Chronos FP16 overflow** — 한국 주가 5만원+ 시 inf. FP32 강제됨.
|
||||
4. **post-close 트리거** — 상태기반(`last_post_close_date`)으로 변경됨. 16:00 이후 + 오늘 미실행이면 trigger.
|
||||
5. **services worker_id** — env로 명시 권장. hostname 기반 default는 컨테이너 재기동 시 바뀌어 orphan 분실 위험.
|
||||
6. **dead-letter 누적** — `redis-cli LLEN dead_letter:*` 정기 점검 필요.
|
||||
7. **Dockerfile build context** — `services/` 루트 (각 worker 디렉토리 아님). compose 변경 동반.
|
||||
|
||||
---
|
||||
|
||||
## Phase 진행 상태 (Confidence Signal Pipeline V2)
|
||||
|
||||
| Phase | 내용 | 상태 |
|
||||
|-------|------|------|
|
||||
| 0 | Architecture & contract spec | ✅ |
|
||||
| 1 | stock 백엔드 WebAI API 보강 (NAS) | ✅ |
|
||||
| 1.5 | V1 → `signal_v1/` rename → `legacy/` 격리 | ✅ |
|
||||
| 2 | ai_trade pull worker + signal API client + scheduler | ✅ |
|
||||
| 3a | KIS REST 분봉 + WebSocket 호가 + NXT 스케줄 | ✅ |
|
||||
| 3b | Chronos-bolt-base 추론 + 5분봉 모멘텀 분류기 | ✅ |
|
||||
| 4 | Signal Generator + 로깅 | ✅ |
|
||||
| 4.5 | 코드 리뷰 F1-F6 hotfix (토큰 경로 / throttle Lock / post-close 상태기반 / Chronos abs / state.signals lifecycle / render queue 신뢰성) | ✅ |
|
||||
| 5 | agent-office `/signal` + Ollama Qwen3 14B + 이중 텔레그램 | ⏳ |
|
||||
| 6 | signal_v1 deprecation (legacy 완료, 아카이브만 남음) | 일부 ✅ |
|
||||
| 7 | 운영 모니터링 + 4주 IC 검증 | ⏳ |
|
||||
|
||||
상세 spec/plan은 `../web-ui/docs/superpowers/specs/` / `../web-ui/docs/superpowers/plans/` (별도 repo).
|
||||
|
||||
---
|
||||
|
||||
## 라이선스 / 사용
|
||||
|
||||
비공개. 박재오 개인 웹 플랫폼.
|
||||
@@ -10,6 +10,10 @@ import numpy as np
|
||||
logger = logging.getLogger(__name__)
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
# F4: signal_generator hard gate와 동일한 absolute spread threshold.
|
||||
# zero-shot median≈0에서 conf가 0으로 폭락하던 relative 산식 (spread/abs(median)) 대체.
|
||||
_SPREAD_THRESHOLD = 0.6
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChronosPrediction:
|
||||
@@ -103,8 +107,8 @@ class ChronosPredictor:
|
||||
median = float((q50_price - last_close) / last_close)
|
||||
q10 = float((q10_price - last_close) / last_close)
|
||||
q90 = float((q90_price - last_close) / last_close)
|
||||
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||
spread = q90 - q10 # F4: absolute spread
|
||||
conf = float(max(0.0, min(1.0, 1.0 - spread / _SPREAD_THRESHOLD)))
|
||||
results[ticker] = ChronosPrediction(
|
||||
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||
)
|
||||
@@ -124,8 +128,8 @@ class ChronosPredictor:
|
||||
median = float(np.quantile(returns, 0.5))
|
||||
q10 = float(np.quantile(returns, 0.1))
|
||||
q90 = float(np.quantile(returns, 0.9))
|
||||
spread = (q90 - q10) / max(abs(median), 0.001)
|
||||
conf = float(max(0.0, min(1.0, 1.0 - spread / 2.0)))
|
||||
spread = q90 - q10 # F4: absolute spread
|
||||
conf = float(max(0.0, min(1.0, 1.0 - spread / _SPREAD_THRESHOLD)))
|
||||
results[ticker] = ChronosPrediction(
|
||||
median=median, q10=q10, q90=q90, conf=conf, as_of=now_iso,
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ class Settings:
|
||||
v1_token_path: Path = field(
|
||||
default_factory=lambda: Path(
|
||||
os.getenv("V1_TOKEN_PATH",
|
||||
str(Path(__file__).parent.parent / "signal_v1" / "data" / "kis_token.json"))
|
||||
str(Path(__file__).parent.parent / "legacy" / "signal_v1" / "data" / "kis_token.json"))
|
||||
)
|
||||
)
|
||||
chronos_model: str = field(default_factory=lambda: os.getenv("CHRONOS_MODEL", "amazon/chronos-2"))
|
||||
@@ -53,6 +53,9 @@ class Settings:
|
||||
min_momentum_for_buy: str = field(
|
||||
default_factory=lambda: os.getenv("MIN_MOMENTUM_FOR_BUY", "strong_up")
|
||||
)
|
||||
signal_ttl_seconds: int = field(
|
||||
default_factory=lambda: int(os.getenv("SIGNAL_TTL_SECONDS", "300"))
|
||||
)
|
||||
|
||||
@property
|
||||
def kis_is_virtual(self) -> bool:
|
||||
|
||||
@@ -38,6 +38,7 @@ class KISClient:
|
||||
self._client = httpx.AsyncClient(timeout=timeout)
|
||||
self._token_cache: tuple[str, float] | None = None # (token, file_mtime)
|
||||
self._last_throttle_at = 0.0
|
||||
self._throttle_lock = asyncio.Lock()
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
@@ -56,10 +57,13 @@ class KISClient:
|
||||
return token
|
||||
|
||||
async def _throttle(self) -> None:
|
||||
elapsed = time.monotonic() - self._last_throttle_at
|
||||
if elapsed < _THROTTLE_INTERVAL:
|
||||
await asyncio.sleep(_THROTTLE_INTERVAL - elapsed)
|
||||
self._last_throttle_at = time.monotonic()
|
||||
# F2: Lock으로 직렬화. 없으면 asyncio.gather 동시 호출 시 race로
|
||||
# 같은 elapsed 계산 후 동시에 깨어나 KIS 초당 2회(EGW00201) 위반.
|
||||
async with self._throttle_lock:
|
||||
elapsed = time.monotonic() - self._last_throttle_at
|
||||
if elapsed < _THROTTLE_INTERVAL:
|
||||
await asyncio.sleep(_THROTTLE_INTERVAL - elapsed)
|
||||
self._last_throttle_at = time.monotonic()
|
||||
|
||||
def _common_headers(self, tr_id: str) -> dict[str, str]:
|
||||
token = self._read_v1_token()
|
||||
|
||||
@@ -24,6 +24,7 @@ async def poll_loop(
|
||||
) -> None:
|
||||
"""FastAPI lifespan 에서 asyncio.create_task 로 시작."""
|
||||
logger.info("poll_loop started")
|
||||
last_post_close_date = None # F3: state-based post-close trigger
|
||||
while not shutdown.is_set():
|
||||
now = datetime.now(KST)
|
||||
if _is_market_day(now) and _is_polling_window(now):
|
||||
@@ -36,10 +37,14 @@ async def poll_loop(
|
||||
update_minute_momentum_for_all(state)
|
||||
except Exception:
|
||||
logger.exception("minute momentum update failed")
|
||||
# Post-close trigger (16:00 KST)
|
||||
if _is_post_close_trigger(now) and chronos is not None and kis_client is not None:
|
||||
# Post-close trigger (F3: 상태기반 — 16:00 이후 + 오늘 미실행)
|
||||
if (
|
||||
_is_post_close_trigger(now, last_post_close_date)
|
||||
and chronos is not None and kis_client is not None
|
||||
):
|
||||
try:
|
||||
await _run_post_close_cycle(kis_client, chronos, state)
|
||||
last_post_close_date = now.date()
|
||||
except Exception:
|
||||
logger.exception("post-close cycle failed")
|
||||
# Phase 4: generate signals
|
||||
@@ -49,6 +54,11 @@ async def poll_loop(
|
||||
generate_signals(state, dedup, settings)
|
||||
except Exception:
|
||||
logger.exception("generate_signals failed")
|
||||
# F5: cycle 끝에 expired signal purge (consumer 미사용 케이스 보호)
|
||||
try:
|
||||
state.purge_expired_signals(datetime.now(KST))
|
||||
except Exception:
|
||||
logger.exception("purge_expired_signals failed")
|
||||
interval = _next_interval(now)
|
||||
try:
|
||||
await asyncio.wait_for(shutdown.wait(), timeout=interval)
|
||||
|
||||
@@ -76,12 +76,21 @@ def _seconds_until_nxt_or_market_open(now: datetime) -> float:
|
||||
return 86400.0
|
||||
|
||||
|
||||
def _is_post_close_trigger(now: datetime) -> bool:
|
||||
"""16:00 KST ±1분 (post-close cycle 트리거). 평일/영업일만."""
|
||||
def _is_post_close_trigger(now: datetime, last_post_close_date) -> bool:
|
||||
"""F3 — 16:00 KST 이후 오늘 아직 post-close cycle 안 돌렸으면 True (상태기반).
|
||||
|
||||
이전엔 16:00:00-16:00:59 1분 윈도우라 5분 sleep + 비결정적 cycle 시작시각
|
||||
조합으로 영영 못 잡는 경우 발생 (예: cycle이 15:31에 시작되면 16:01에 깸).
|
||||
|
||||
Args:
|
||||
now: 현재 KST datetime.
|
||||
last_post_close_date: 마지막 post-close 실행 영업일 date (None=미실행).
|
||||
"""
|
||||
if not _is_market_day(now):
|
||||
return False
|
||||
t = now.time()
|
||||
return time(16, 0) <= t < time(16, 1)
|
||||
if now.time() < time(16, 0):
|
||||
return False
|
||||
return last_post_close_date != now.date()
|
||||
|
||||
|
||||
def _seconds_until_next_market_open(now: datetime) -> float:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,7 +20,12 @@ MOMENTUM_SCORES = {
|
||||
|
||||
|
||||
def generate_signals(state, dedup, settings) -> None:
|
||||
"""Phase 4 entry — state-mutating. Evaluation order: sell first (priority), then buy. A ticker receiving a sell signal in this cycle is excluded from buy evaluation to avoid silent overwrite."""
|
||||
"""Phase 4 entry — state-mutating. F5: cycle_id += 1 (호출마다, emit 여부 무관).
|
||||
|
||||
Evaluation order: sell first (priority), then buy. A ticker receiving a sell
|
||||
signal in this cycle is excluded from buy evaluation to avoid silent overwrite.
|
||||
"""
|
||||
state.signal_cycle_id += 1
|
||||
_evaluate_sell_signals(state, dedup, settings)
|
||||
_evaluate_buy_signals(state, dedup, settings)
|
||||
|
||||
@@ -45,9 +50,10 @@ def _evaluate_buy_signals(state, dedup, settings) -> None:
|
||||
if dedup.is_recent(ticker, "buy", within_hours=24):
|
||||
logger.debug("buy %s skipped: dedup 24h", ticker)
|
||||
continue
|
||||
state.signals[ticker] = _build_buy_signal(state, ticker, name, rank, confidence)
|
||||
state.signals[ticker] = _build_buy_signal(state, ticker, name, rank, confidence, settings)
|
||||
dedup.record(ticker, "buy", confidence=confidence)
|
||||
logger.info("signal emit %s buy conf=%.3f rank=%s", ticker, confidence, rank)
|
||||
logger.info("signal emit %s buy conf=%.3f rank=%s cycle=%d",
|
||||
ticker, confidence, rank, state.signal_cycle_id)
|
||||
|
||||
|
||||
def _buy_candidates(state) -> list[tuple[str, str, int | None]]:
|
||||
@@ -96,8 +102,11 @@ def _compute_buy_confidence(state, ticker: str, rank: int | None) -> float:
|
||||
return chronos_conf * 0.5 + minute_score * 0.3 + screener_norm * 0.2
|
||||
|
||||
|
||||
def _build_buy_signal(state, ticker: str, name: str, rank: int | None, confidence: float) -> dict:
|
||||
def _build_buy_signal(state, ticker: str, name: str, rank: int | None, confidence: float, settings) -> dict:
|
||||
ap = state.asking_price[ticker]
|
||||
as_of_dt = datetime.now(KST)
|
||||
ttl = getattr(settings, "signal_ttl_seconds", 300)
|
||||
expires_at = (as_of_dt + timedelta(seconds=ttl)).isoformat()
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"name": name,
|
||||
@@ -107,7 +116,9 @@ def _build_buy_signal(state, ticker: str, name: str, rank: int | None, confidenc
|
||||
"avg_price": None,
|
||||
"pnl_pct": None,
|
||||
"context": _build_context(state, ticker, rank),
|
||||
"as_of": datetime.now(KST).isoformat(),
|
||||
"as_of": as_of_dt.isoformat(),
|
||||
"cycle_id": state.signal_cycle_id,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
|
||||
@@ -132,23 +143,24 @@ def _evaluate_sell_signals(state, dedup, settings) -> None:
|
||||
continue
|
||||
state.signals[ticker] = sell
|
||||
dedup.record(ticker, "sell", confidence=sell["confidence_webai"])
|
||||
logger.info("signal emit %s sell conf=%.3f reason=%s",
|
||||
logger.info("signal emit %s sell conf=%.3f reason=%s cycle=%d",
|
||||
ticker, sell["confidence_webai"],
|
||||
sell.get("context", {}).get("sell_reason"))
|
||||
sell.get("context", {}).get("sell_reason"),
|
||||
state.signal_cycle_id)
|
||||
|
||||
|
||||
def _try_stop_loss(state, holding: dict, settings) -> dict | None:
|
||||
pnl = holding.get("pnl_pct")
|
||||
if pnl is None or pnl >= settings.stop_loss_pct:
|
||||
return None
|
||||
return _build_sell_signal(state, holding, confidence=1.0, reason="stop_loss")
|
||||
return _build_sell_signal(state, holding, confidence=1.0, reason="stop_loss", settings=settings)
|
||||
|
||||
|
||||
def _try_take_profit(state, holding: dict, settings) -> dict | None:
|
||||
pnl = holding.get("pnl_pct")
|
||||
if pnl is None or pnl <= settings.take_profit_pct:
|
||||
return None
|
||||
return _build_sell_signal(state, holding, confidence=0.6, reason="take_profit")
|
||||
return _build_sell_signal(state, holding, confidence=0.6, reason="take_profit", settings=settings)
|
||||
|
||||
|
||||
def _try_anomaly(state, holding: dict, settings) -> dict | None:
|
||||
@@ -168,11 +180,14 @@ def _try_anomaly(state, holding: dict, settings) -> dict | None:
|
||||
confidence = pred["conf"] * 0.5 + minute_score * 0.3 + 1.0 * 0.2
|
||||
if confidence <= settings.confidence_threshold:
|
||||
return None
|
||||
return _build_sell_signal(state, holding, confidence=confidence, reason="anomaly")
|
||||
return _build_sell_signal(state, holding, confidence=confidence, reason="anomaly", settings=settings)
|
||||
|
||||
|
||||
def _build_sell_signal(state, holding: dict, confidence: float, reason: str) -> dict:
|
||||
def _build_sell_signal(state, holding: dict, confidence: float, reason: str, settings=None) -> dict:
|
||||
ticker = holding["ticker"]
|
||||
as_of_dt = datetime.now(KST)
|
||||
ttl = getattr(settings, "signal_ttl_seconds", 300) if settings else 300
|
||||
expires_at = (as_of_dt + timedelta(seconds=ttl)).isoformat()
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"name": holding.get("name", ticker),
|
||||
@@ -182,7 +197,9 @@ def _build_sell_signal(state, holding: dict, confidence: float, reason: str) ->
|
||||
"avg_price": holding.get("avg_price"),
|
||||
"pnl_pct": holding.get("pnl_pct"),
|
||||
"context": _build_context(state, ticker, rank=None, sell_reason=reason),
|
||||
"as_of": datetime.now(KST).isoformat(),
|
||||
"as_of": as_of_dt.isoformat(),
|
||||
"cycle_id": state.signal_cycle_id,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""PollState — process-wide singleton."""
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -15,8 +16,44 @@ class PollState:
|
||||
chronos_predictions: dict[str, dict] = field(default_factory=dict)
|
||||
minute_momentum: dict[str, str] = field(default_factory=dict)
|
||||
signals: dict[str, dict] = field(default_factory=dict)
|
||||
# F5 lifecycle
|
||||
signal_cycle_id: int = 0
|
||||
last_updated: dict[str, str] = field(default_factory=dict)
|
||||
fetch_errors: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def get_active_signals(self, now: datetime) -> list[dict]:
|
||||
"""expires_at > now 인 신호만 반환. expires_at 없거나 파싱 실패는 expired 취급."""
|
||||
active: list[dict] = []
|
||||
for sig in self.signals.values():
|
||||
expires_at = sig.get("expires_at")
|
||||
if not expires_at:
|
||||
continue
|
||||
try:
|
||||
exp_dt = datetime.fromisoformat(expires_at)
|
||||
except ValueError:
|
||||
continue
|
||||
if exp_dt > now:
|
||||
active.append(sig)
|
||||
return active
|
||||
|
||||
def purge_expired_signals(self, now: datetime) -> int:
|
||||
"""만료된 signal 제거. expires_at 없거나 파싱 실패도 제거. 제거 개수 반환."""
|
||||
to_drop = []
|
||||
for ticker, sig in self.signals.items():
|
||||
expires_at = sig.get("expires_at")
|
||||
if not expires_at:
|
||||
to_drop.append(ticker)
|
||||
continue
|
||||
try:
|
||||
exp_dt = datetime.fromisoformat(expires_at)
|
||||
except ValueError:
|
||||
to_drop.append(ticker)
|
||||
continue
|
||||
if exp_dt <= now:
|
||||
to_drop.append(ticker)
|
||||
for t in to_drop:
|
||||
del self.signals[t]
|
||||
return len(to_drop)
|
||||
|
||||
|
||||
state = PollState()
|
||||
|
||||
@@ -90,3 +90,54 @@ def test_return_computed_from_price_relative_to_last_close(mock_pipeline, mock_t
|
||||
daily = {"005930": _daily_ohlcv(list(range(41, 101)))} # last = 100
|
||||
result = predictor.predict_batch(daily)
|
||||
assert abs(result["005930"].median - 0.10) < 0.001
|
||||
|
||||
|
||||
# ----- F4: absolute spread 기반 confidence -----
|
||||
|
||||
def test_confidence_high_when_spread_near_zero(mock_pipeline, mock_torch_cpu):
|
||||
"""F4 — median≈0 + spread≈0 일 때 conf≈1 (현 relative 산식의 회귀 케이스).
|
||||
|
||||
한국 주가 100000원, q10=q50=q90=100000 → median=0, spread=0.
|
||||
Relative 산식 (spread/abs(median))은 0/0.001 보호선이라 spread=0이면 conf=1로
|
||||
동작하지만, median≈0 + 미세 spread(예 1원) 케이스에서 폭증 → conf=0.
|
||||
Absolute 산식은 그런 폭증 없음.
|
||||
"""
|
||||
quantiles = _mk_quantiles_tensor(100000.0, 100000.0, 100000.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100000] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert result["005930"].conf > 0.95, (
|
||||
f"median≈0 + spread≈0인데 conf={result['005930'].conf} (F4 회귀)"
|
||||
)
|
||||
|
||||
|
||||
def test_confidence_half_at_spread_03(mock_pipeline, mock_torch_cpu):
|
||||
"""F4 — spread 0.30일 때 conf ≈ 0.5 (1 - 0.3/0.6)."""
|
||||
# q10=85000 → -0.15, q90=115000 → 0.15, q50=100000 → 0.0
|
||||
# spread = 0.30, conf = 1 - 0.30/0.60 = 0.50
|
||||
quantiles = _mk_quantiles_tensor(85000.0, 100000.0, 115000.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100000] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
conf = result["005930"].conf
|
||||
assert 0.45 < conf < 0.55, f"spread=0.30에서 conf={conf} (expected ≈0.5)"
|
||||
|
||||
|
||||
def test_confidence_zero_at_threshold_spread(mock_pipeline, mock_torch_cpu):
|
||||
"""F4 — spread가 _SPREAD_THRESHOLD(0.6)이면 conf=0."""
|
||||
quantiles = _mk_quantiles_tensor(70000.0, 100000.0, 130000.0)
|
||||
mock_pipeline.predict_quantiles.return_value = (quantiles, None)
|
||||
|
||||
from ai_trade.chronos_predictor import ChronosPredictor
|
||||
predictor = ChronosPredictor(model_name="mock-model")
|
||||
daily = {"005930": _daily_ohlcv([100000] * 60)}
|
||||
result = predictor.predict_batch(daily)
|
||||
assert result["005930"].conf < 0.05, (
|
||||
f"spread=threshold에서 conf={result['005930'].conf} (expected ≈0)"
|
||||
)
|
||||
|
||||
22
ai_trade/tests/test_config_token_path.py
Normal file
22
ai_trade/tests/test_config_token_path.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""F1 — V1_TOKEN_PATH default가 legacy/signal_v1/ 경유인지 검증."""
|
||||
from pathlib import Path
|
||||
|
||||
from ai_trade.config import Settings
|
||||
|
||||
|
||||
def test_v1_token_default_path_uses_legacy_dir(monkeypatch):
|
||||
"""env에 V1_TOKEN_PATH 없으면 legacy/signal_v1/data/kis_token.json"""
|
||||
monkeypatch.delenv("V1_TOKEN_PATH", raising=False)
|
||||
settings = Settings()
|
||||
expected_suffix = Path("legacy") / "signal_v1" / "data" / "kis_token.json"
|
||||
assert str(settings.v1_token_path).endswith(str(expected_suffix)), (
|
||||
f"expected default to end with {expected_suffix}, got {settings.v1_token_path}"
|
||||
)
|
||||
|
||||
|
||||
def test_v1_token_env_override_wins(monkeypatch, tmp_path):
|
||||
"""env로 명시한 경로가 default를 덮어씀."""
|
||||
custom = tmp_path / "custom_token.json"
|
||||
monkeypatch.setenv("V1_TOKEN_PATH", str(custom))
|
||||
settings = Settings()
|
||||
assert settings.v1_token_path == custom
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for KISClient (REST)."""
|
||||
import asyncio
|
||||
import json
|
||||
import time as time_module
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
@@ -159,3 +161,30 @@ async def test_get_daily_ohlcv_returns_60_bars(kis_client_factory):
|
||||
assert "datetime" in bars[0]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_throttle_serializes_concurrent_gather(kis_client_factory):
|
||||
"""F2 — 5개 동시 요청이 asyncio.gather로 들어와도 0.5초 간격으로 직렬화.
|
||||
|
||||
초당 2회 = 0.5초 간격. 5개 요청 시 최소 (5-1)*0.5 = 2.0초.
|
||||
Lock 없으면 race condition으로 거의 동시에 나가 0.5초대로 끝남.
|
||||
"""
|
||||
sample = {"output2": []}
|
||||
respx.get(
|
||||
"https://openapivts.koreainvestment.com:29443"
|
||||
"/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||
).mock(return_value=httpx.Response(200, json=sample))
|
||||
|
||||
client = kis_client_factory()
|
||||
try:
|
||||
start = time_module.monotonic()
|
||||
await asyncio.gather(*[client.get_minute_ohlcv(f"00593{i}") for i in range(5)])
|
||||
elapsed = time_module.monotonic() - start
|
||||
# 5 throttle = 최소 (5-1)*0.5 = 2.0s, tolerance 0.3s
|
||||
assert elapsed >= 1.7, (
|
||||
f"throttle race condition: 5 concurrent calls took only {elapsed:.2f}s, "
|
||||
f"expected >=1.7s (0.5s * 4 inter-call gaps)"
|
||||
)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
@@ -122,6 +122,7 @@ def test_poll_loop_calls_generate_signals_after_cycle(monkeypatch):
|
||||
settings.asking_bid_ratio_threshold = 0.6
|
||||
settings.confidence_threshold = 0.7
|
||||
settings.min_momentum_for_buy = "strong_up"
|
||||
settings.signal_ttl_seconds = 300
|
||||
|
||||
generate_signals(state, dedup, settings)
|
||||
|
||||
@@ -129,3 +130,112 @@ def test_poll_loop_calls_generate_signals_after_cycle(monkeypatch):
|
||||
assert state.signals["005930"]["action"] == "sell"
|
||||
assert state.signals["005930"]["confidence_webai"] == 1.0
|
||||
dedup.record.assert_called_with("005930", "sell", confidence=1.0)
|
||||
|
||||
|
||||
async def test_post_close_fires_at_1601_when_not_yet_today(monkeypatch):
|
||||
"""F3 — 16:01에 깬 cycle도 오늘 post_close 안 돌렸으면 호출됨 (회귀 방지)."""
|
||||
from datetime import datetime as _dt
|
||||
from zoneinfo import ZoneInfo as _ZI
|
||||
import asyncio as _asyncio
|
||||
|
||||
from ai_trade import pull_worker
|
||||
|
||||
_kst = _ZI("Asia/Seoul")
|
||||
now_at_1601 = _dt(2026, 5, 18, 16, 1, tzinfo=_kst)
|
||||
|
||||
class FrozenDateTime:
|
||||
@staticmethod
|
||||
def now(tz=None):
|
||||
return now_at_1601
|
||||
|
||||
monkeypatch.setattr(pull_worker, "datetime", FrozenDateTime)
|
||||
monkeypatch.setattr(pull_worker, "_is_market_day", lambda n: True)
|
||||
monkeypatch.setattr(pull_worker, "_is_polling_window", lambda n: True)
|
||||
monkeypatch.setattr(pull_worker, "_next_interval", lambda n: 0.01)
|
||||
monkeypatch.setattr(pull_worker, "_run_polling_cycle", AsyncMock())
|
||||
monkeypatch.setattr(pull_worker, "update_minute_momentum_for_all", lambda s: None)
|
||||
post_close = AsyncMock()
|
||||
monkeypatch.setattr(pull_worker, "_run_post_close_cycle", post_close)
|
||||
|
||||
state = MagicMock()
|
||||
chronos = MagicMock()
|
||||
kis = MagicMock()
|
||||
shutdown = _asyncio.Event()
|
||||
|
||||
async def _stop_soon():
|
||||
await _asyncio.sleep(0.05)
|
||||
shutdown.set()
|
||||
|
||||
_asyncio.create_task(_stop_soon())
|
||||
await pull_worker.poll_loop(
|
||||
client=MagicMock(),
|
||||
state=state,
|
||||
shutdown=shutdown,
|
||||
kis_client=kis,
|
||||
chronos=chronos,
|
||||
dedup=None,
|
||||
settings=None,
|
||||
)
|
||||
|
||||
assert post_close.await_count >= 1, "post-close가 16:01에 호출되지 않음 (F3 회귀)"
|
||||
|
||||
|
||||
async def test_poll_loop_purges_expired_signals(monkeypatch):
|
||||
"""F5 — 매 cycle 끝에 expired signal이 제거됨."""
|
||||
from datetime import datetime as _dt
|
||||
from zoneinfo import ZoneInfo as _ZI
|
||||
import asyncio as _asyncio
|
||||
|
||||
from ai_trade import pull_worker
|
||||
from ai_trade.state import PollState
|
||||
|
||||
_kst = _ZI("Asia/Seoul")
|
||||
now = _dt(2026, 5, 18, 10, 0, tzinfo=_kst)
|
||||
|
||||
class FrozenDT:
|
||||
@staticmethod
|
||||
def now(tz=None):
|
||||
return now
|
||||
|
||||
state = PollState()
|
||||
state.signals = {
|
||||
"OLD": {
|
||||
"ticker": "OLD",
|
||||
"expires_at": _dt(2026, 5, 18, 9, 0, tzinfo=_kst).isoformat(),
|
||||
"cycle_id": 1,
|
||||
},
|
||||
"FRESH": {
|
||||
"ticker": "FRESH",
|
||||
"expires_at": _dt(2026, 5, 18, 10, 30, tzinfo=_kst).isoformat(),
|
||||
"cycle_id": 1,
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(pull_worker, "datetime", FrozenDT)
|
||||
monkeypatch.setattr(pull_worker, "_is_market_day", lambda n: True)
|
||||
monkeypatch.setattr(pull_worker, "_is_polling_window", lambda n: True)
|
||||
monkeypatch.setattr(pull_worker, "_next_interval", lambda n: 0.01)
|
||||
monkeypatch.setattr(pull_worker, "_run_polling_cycle", AsyncMock())
|
||||
monkeypatch.setattr(pull_worker, "update_minute_momentum_for_all", lambda s: None)
|
||||
monkeypatch.setattr(pull_worker, "_is_post_close_trigger", lambda *a, **k: False)
|
||||
|
||||
shutdown = _asyncio.Event()
|
||||
|
||||
async def stop_soon():
|
||||
await _asyncio.sleep(0.05)
|
||||
shutdown.set()
|
||||
|
||||
_asyncio.create_task(stop_soon())
|
||||
|
||||
await pull_worker.poll_loop(
|
||||
client=MagicMock(),
|
||||
state=state,
|
||||
shutdown=shutdown,
|
||||
kis_client=MagicMock(),
|
||||
chronos=MagicMock(),
|
||||
dedup=None,
|
||||
settings=None,
|
||||
)
|
||||
|
||||
assert "OLD" not in state.signals
|
||||
assert "FRESH" in state.signals
|
||||
|
||||
@@ -79,3 +79,41 @@ def test_next_interval_dead_zone_skip():
|
||||
interval = _next_interval(now)
|
||||
# 02:00 → 04:30 = 2.5h = 9000s
|
||||
assert 9000 - 60 < interval < 9000 + 60
|
||||
|
||||
|
||||
# ----- F3 post-close 상태기반 트리거 -----
|
||||
|
||||
from datetime import date as _date # noqa: E402
|
||||
from ai_trade.scheduler import _is_post_close_trigger # noqa: E402
|
||||
|
||||
|
||||
def test_post_close_trigger_fires_at_1601_if_not_yet_today():
|
||||
"""F3 — 16:01에 깬 cycle도 오늘 아직 안 돌렸으면 trigger."""
|
||||
now = _kst(2026, 5, 18, 16, 1)
|
||||
assert _is_post_close_trigger(now, last_post_close_date=None) is True
|
||||
|
||||
|
||||
def test_post_close_trigger_skips_if_already_today():
|
||||
"""F3 — 이미 오늘 돌렸으면 trigger 안 함."""
|
||||
now = _kst(2026, 5, 18, 16, 5)
|
||||
today = _date(2026, 5, 18)
|
||||
assert _is_post_close_trigger(now, last_post_close_date=today) is False
|
||||
|
||||
|
||||
def test_post_close_trigger_skips_before_1600():
|
||||
"""F3 — 16:00 전에는 trigger 안 함."""
|
||||
now = _kst(2026, 5, 18, 15, 59)
|
||||
assert _is_post_close_trigger(now, last_post_close_date=None) is False
|
||||
|
||||
|
||||
def test_post_close_trigger_fires_next_day_after_reset():
|
||||
"""F3 — 다음 영업일이 되면 다시 trigger."""
|
||||
now = _kst(2026, 5, 19, 16, 0)
|
||||
yesterday = _date(2026, 5, 18)
|
||||
assert _is_post_close_trigger(now, last_post_close_date=yesterday) is True
|
||||
|
||||
|
||||
def test_post_close_trigger_skips_on_holiday():
|
||||
"""F3 — 휴장일에는 trigger 안 함 (2026-05-05 어린이날)."""
|
||||
now = _kst(2026, 5, 5, 16, 30)
|
||||
assert _is_post_close_trigger(now, last_post_close_date=None) is False
|
||||
|
||||
@@ -16,6 +16,7 @@ def _settings(**overrides):
|
||||
asking_bid_ratio_threshold=0.6,
|
||||
confidence_threshold=0.7,
|
||||
min_momentum_for_buy="strong_up",
|
||||
signal_ttl_seconds=300,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
m = MagicMock()
|
||||
@@ -170,3 +171,48 @@ def test_sell_signal_triggers_on_anomaly_path(dedup_mock):
|
||||
assert sig["action"] == "sell"
|
||||
assert sig["context"]["sell_reason"] == "anomaly"
|
||||
assert sig["confidence_webai"] > 0.7
|
||||
|
||||
|
||||
# ----- F5: cycle_id + expires_at 부착 -----
|
||||
|
||||
def test_emit_attaches_cycle_id_and_expires_at(dedup_mock):
|
||||
"""F5 — emit signal에 cycle_id (state.signal_cycle_id) + expires_at 부착."""
|
||||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
_kst = ZoneInfo("Asia/Seoul")
|
||||
|
||||
state = _make_state_with_buy_candidate()
|
||||
before = datetime.now(_kst)
|
||||
generate_signals(state, dedup_mock, _settings(signal_ttl_seconds=300))
|
||||
after = datetime.now(_kst)
|
||||
|
||||
sig = state.signals["005930"]
|
||||
assert sig["cycle_id"] == 1
|
||||
assert "expires_at" in sig
|
||||
exp_dt = datetime.fromisoformat(sig["expires_at"])
|
||||
assert before + timedelta(seconds=295) < exp_dt < after + timedelta(seconds=305)
|
||||
|
||||
|
||||
def test_cycle_id_increments_each_call(dedup_mock):
|
||||
"""F5 — generate_signals 호출마다 cycle_id += 1 (emit 여부 무관)."""
|
||||
state = _make_state_with_buy_candidate()
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert state.signal_cycle_id == 1
|
||||
# 2번째 호출 — dedup이 막아도 cycle_id는 증가
|
||||
dedup_mock.is_recent.return_value = True
|
||||
generate_signals(state, dedup_mock, _settings())
|
||||
assert state.signal_cycle_id == 2
|
||||
|
||||
|
||||
def test_sell_signal_also_carries_cycle_id_and_expires_at(dedup_mock):
|
||||
"""F5 — sell signal도 동일하게 부착."""
|
||||
from datetime import datetime
|
||||
state = _make_state_with_holding(pnl_pct=-0.08, current_price=68000)
|
||||
generate_signals(state, dedup_mock, _settings(signal_ttl_seconds=120))
|
||||
|
||||
assert "005930" in state.signals
|
||||
sig = state.signals["005930"]
|
||||
assert sig["action"] == "sell"
|
||||
assert sig["cycle_id"] == 1
|
||||
# parse expires_at as ISO — must succeed
|
||||
datetime.fromisoformat(sig["expires_at"])
|
||||
|
||||
66
ai_trade/tests/test_state_signals_lifecycle.py
Normal file
66
ai_trade/tests/test_state_signals_lifecycle.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""F5 — state.signals lifecycle (expires_at + cycle_id)."""
|
||||
from datetime import datetime, timedelta
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from ai_trade.state import PollState
|
||||
|
||||
KST = ZoneInfo("Asia/Seoul")
|
||||
|
||||
|
||||
def test_initial_signal_cycle_id_is_zero():
|
||||
state = PollState()
|
||||
assert state.signal_cycle_id == 0
|
||||
|
||||
|
||||
def test_get_active_signals_excludes_expired():
|
||||
state = PollState()
|
||||
now = datetime(2026, 5, 25, 10, 0, tzinfo=KST)
|
||||
future = (now + timedelta(seconds=300)).isoformat()
|
||||
past = (now - timedelta(seconds=60)).isoformat()
|
||||
state.signals = {
|
||||
"A": {"ticker": "A", "expires_at": future, "cycle_id": 1, "action": "buy"},
|
||||
"B": {"ticker": "B", "expires_at": past, "cycle_id": 1, "action": "buy"},
|
||||
}
|
||||
active = state.get_active_signals(now)
|
||||
tickers = [s["ticker"] for s in active]
|
||||
assert "A" in tickers
|
||||
assert "B" not in tickers
|
||||
|
||||
|
||||
def test_get_active_signals_treats_missing_expires_as_expired():
|
||||
"""expires_at 없는 legacy 신호는 expired로 간주."""
|
||||
state = PollState()
|
||||
now = datetime(2026, 5, 25, 10, 0, tzinfo=KST)
|
||||
state.signals = {"C": {"ticker": "C", "action": "buy"}}
|
||||
assert state.get_active_signals(now) == []
|
||||
|
||||
|
||||
def test_purge_expired_signals_removes_expired():
|
||||
state = PollState()
|
||||
now = datetime(2026, 5, 25, 10, 0, tzinfo=KST)
|
||||
future = (now + timedelta(seconds=300)).isoformat()
|
||||
past = (now - timedelta(seconds=60)).isoformat()
|
||||
state.signals = {
|
||||
"A": {"ticker": "A", "expires_at": future, "cycle_id": 1},
|
||||
"B": {"ticker": "B", "expires_at": past, "cycle_id": 1},
|
||||
}
|
||||
removed = state.purge_expired_signals(now)
|
||||
assert "A" in state.signals
|
||||
assert "B" not in state.signals
|
||||
assert removed == 1
|
||||
|
||||
|
||||
# ----- SIGNAL_TTL_SECONDS env -----
|
||||
|
||||
def test_signal_ttl_seconds_default(monkeypatch):
|
||||
monkeypatch.delenv("SIGNAL_TTL_SECONDS", raising=False)
|
||||
from ai_trade.config import Settings
|
||||
s = Settings()
|
||||
assert s.signal_ttl_seconds == 300
|
||||
|
||||
|
||||
def test_signal_ttl_seconds_env_override(monkeypatch):
|
||||
monkeypatch.setenv("SIGNAL_TTL_SECONDS", "60")
|
||||
from ai_trade.config import Settings
|
||||
s = Settings()
|
||||
assert s.signal_ttl_seconds == 60
|
||||
0
services/_shared/__init__.py
Normal file
0
services/_shared/__init__.py
Normal file
2
services/_shared/pytest.ini
Normal file
2
services/_shared/pytest.ini
Normal file
@@ -0,0 +1,2 @@
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
135
services/_shared/reliable_queue.py
Normal file
135
services/_shared/reliable_queue.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""F6 — Reliable Redis queue with processing list + recovery + retry.
|
||||
|
||||
Pattern:
|
||||
- BLMOVE main → processing (atomic dequeue)
|
||||
- ack: LREM processing (1 occurrence)
|
||||
- fail: LREM processing + (re-enqueue with attempts++ OR move to dead-letter)
|
||||
- recover: startup-time orphan recovery (worker's processing list → main queue)
|
||||
|
||||
Producer side stays unchanged: LPUSH queue:<x> <json payload>.
|
||||
Worker side: dequeue() → process → ack(raw) on success or fail(raw, payload) on error.
|
||||
Startup: await queue.recover() to re-enqueue orphans.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def default_worker_id(queue_key: str) -> str:
|
||||
"""env WORKER_ID > hostname-pid."""
|
||||
explicit = os.getenv("WORKER_ID")
|
||||
if explicit:
|
||||
return explicit
|
||||
return f"{queue_key}-{socket.gethostname()}-{os.getpid()}"
|
||||
|
||||
|
||||
class ReliableQueue:
|
||||
"""BLMOVE-backed atomic dequeue + processing list + retry/dead-letter."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis,
|
||||
queue_key: str,
|
||||
worker_id: Optional[str] = None,
|
||||
max_attempts: int = 3,
|
||||
):
|
||||
self._redis = redis
|
||||
self._queue_key = queue_key
|
||||
self._worker_id = worker_id or default_worker_id(queue_key)
|
||||
self._processing_key = f"processing:{queue_key}:{self._worker_id}"
|
||||
self._dead_letter_key = f"dead_letter:{queue_key}"
|
||||
self._max_attempts = max_attempts
|
||||
|
||||
@property
|
||||
def worker_id(self) -> str:
|
||||
return self._worker_id
|
||||
|
||||
@property
|
||||
def processing_key(self) -> str:
|
||||
return self._processing_key
|
||||
|
||||
async def dequeue(self, timeout: int = 5) -> Optional[tuple[dict, bytes]]:
|
||||
"""Atomically move 1 item from main queue tail to processing head.
|
||||
|
||||
Returns (parsed_dict, raw_bytes) or None on timeout/parse-error.
|
||||
Caller MUST call ack(raw) on success or fail(raw, payload) on error.
|
||||
"""
|
||||
raw = await self._redis.blmove(
|
||||
self._queue_key, self._processing_key,
|
||||
timeout, "RIGHT", "LEFT",
|
||||
)
|
||||
if raw is None:
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
"invalid payload on dequeue, moving to dead-letter: %r", raw[:200]
|
||||
)
|
||||
await self._redis.lrem(self._processing_key, 1, raw)
|
||||
await self._redis.lpush(self._dead_letter_key, raw)
|
||||
return None
|
||||
return payload, raw
|
||||
|
||||
async def ack(self, raw: bytes) -> None:
|
||||
"""Successful processing — remove from processing list."""
|
||||
removed = await self._redis.lrem(self._processing_key, 1, raw)
|
||||
if removed == 0:
|
||||
logger.warning("ack on missing payload (already removed?): %r", raw[:100])
|
||||
|
||||
async def fail(self, raw: bytes, payload: dict) -> None:
|
||||
"""Failed processing — remove from processing list and re-enqueue or dead-letter."""
|
||||
await self._redis.lrem(self._processing_key, 1, raw)
|
||||
attempts = int(payload.get("attempts", 0)) + 1
|
||||
if attempts >= self._max_attempts:
|
||||
payload["attempts"] = attempts
|
||||
await self._redis.lpush(self._dead_letter_key, json.dumps(payload).encode())
|
||||
logger.error(
|
||||
"task moved to dead-letter after %d attempts: task_id=%s",
|
||||
attempts, payload.get("task_id"),
|
||||
)
|
||||
return
|
||||
payload["attempts"] = attempts
|
||||
await self._redis.lpush(self._queue_key, json.dumps(payload).encode())
|
||||
logger.info(
|
||||
"task re-enqueued (attempt %d/%d): task_id=%s",
|
||||
attempts, self._max_attempts, payload.get("task_id"),
|
||||
)
|
||||
|
||||
async def recover(self) -> int:
|
||||
"""Startup: move all orphans from this worker's processing list back to main queue.
|
||||
|
||||
Increments attempts counter (orphan == implicit failure). Returns count.
|
||||
"""
|
||||
count = 0
|
||||
while True:
|
||||
raw = await self._redis.lpop(self._processing_key)
|
||||
if raw is None:
|
||||
break
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
await self._redis.lpush(self._dead_letter_key, raw)
|
||||
count += 1
|
||||
continue
|
||||
payload["attempts"] = int(payload.get("attempts", 0)) + 1
|
||||
if payload["attempts"] >= self._max_attempts:
|
||||
await self._redis.lpush(
|
||||
self._dead_letter_key, json.dumps(payload).encode()
|
||||
)
|
||||
else:
|
||||
await self._redis.lpush(
|
||||
self._queue_key, json.dumps(payload).encode()
|
||||
)
|
||||
count += 1
|
||||
if count:
|
||||
logger.info(
|
||||
"recovered %d orphaned items for worker %s", count, self._worker_id
|
||||
)
|
||||
return count
|
||||
1
services/_shared/requirements.txt
Normal file
1
services/_shared/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
redis>=5.0.0
|
||||
0
services/_shared/tests/__init__.py
Normal file
0
services/_shared/tests/__init__.py
Normal file
84
services/_shared/tests/test_reliable_queue.py
Normal file
84
services/_shared/tests/test_reliable_queue.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""F6 — ReliableQueue: atomic dequeue + recovery + retry."""
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import fakeredis.aioredis
|
||||
import pytest
|
||||
|
||||
# Make `_shared` importable when tests run from services/_shared
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
||||
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def redis():
|
||||
r = fakeredis.aioredis.FakeRedis(decode_responses=False)
|
||||
yield r
|
||||
await r.flushall()
|
||||
await r.aclose()
|
||||
|
||||
|
||||
async def test_dequeue_atomically_moves_to_processing(redis):
|
||||
"""BLMOVE: queue → processing 원자적 이동."""
|
||||
q = ReliableQueue(redis, queue_key="queue:test", worker_id="w1")
|
||||
await redis.lpush("queue:test", json.dumps({"task_id": "t1"}).encode())
|
||||
result = await q.dequeue(timeout=1)
|
||||
assert result is not None
|
||||
payload, raw = result
|
||||
assert payload["task_id"] == "t1"
|
||||
assert await redis.llen("queue:test") == 0
|
||||
assert await redis.llen("processing:queue:test:w1") == 1
|
||||
|
||||
|
||||
async def test_dequeue_returns_none_on_timeout(redis):
|
||||
q = ReliableQueue(redis, queue_key="queue:test", worker_id="w1")
|
||||
result = await q.dequeue(timeout=1)
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_ack_removes_from_processing(redis):
|
||||
q = ReliableQueue(redis, queue_key="queue:test", worker_id="w1")
|
||||
await redis.lpush("queue:test", json.dumps({"task_id": "t1"}).encode())
|
||||
_, raw = await q.dequeue(timeout=1)
|
||||
await q.ack(raw)
|
||||
assert await redis.llen("processing:queue:test:w1") == 0
|
||||
|
||||
|
||||
async def test_recover_returns_orphaned_to_main_queue(redis):
|
||||
"""startup recovery: 잔존 processing list 항목을 main queue로 되돌림."""
|
||||
orphan = json.dumps({"task_id": "t1", "attempts": 0}).encode()
|
||||
await redis.lpush("processing:queue:test:w1", orphan)
|
||||
q = ReliableQueue(redis, queue_key="queue:test", worker_id="w1")
|
||||
recovered = await q.recover()
|
||||
assert recovered == 1
|
||||
assert await redis.llen("processing:queue:test:w1") == 0
|
||||
payload, _ = await q.dequeue(timeout=1)
|
||||
assert payload["task_id"] == "t1"
|
||||
assert payload["attempts"] == 1 # incremented on recover
|
||||
|
||||
|
||||
async def test_fail_below_max_attempts_returns_to_main_queue(redis):
|
||||
q = ReliableQueue(redis, queue_key="queue:test", worker_id="w1", max_attempts=3)
|
||||
await redis.lpush("queue:test", json.dumps({"task_id": "t1", "attempts": 0}).encode())
|
||||
payload, raw = await q.dequeue(timeout=1)
|
||||
await q.fail(raw, payload)
|
||||
assert await redis.llen("processing:queue:test:w1") == 0
|
||||
assert await redis.llen("queue:test") == 1
|
||||
requeued_raw = await redis.lindex("queue:test", 0)
|
||||
requeued = json.loads(requeued_raw)
|
||||
assert requeued["attempts"] == 1
|
||||
|
||||
|
||||
async def test_fail_at_max_attempts_moves_to_dead_letter(redis):
|
||||
q = ReliableQueue(redis, queue_key="queue:test", worker_id="w1", max_attempts=3)
|
||||
await redis.lpush(
|
||||
"queue:test", json.dumps({"task_id": "t1", "attempts": 2}).encode()
|
||||
)
|
||||
payload, raw = await q.dequeue(timeout=1)
|
||||
await q.fail(raw, payload)
|
||||
# attempts 2 → 3 (== max) → dead-letter
|
||||
assert await redis.llen("queue:test") == 0
|
||||
assert await redis.llen("processing:queue:test:w1") == 0
|
||||
assert await redis.llen("dead_letter:queue:test") == 1
|
||||
@@ -3,7 +3,8 @@ name: web-ai-services
|
||||
services:
|
||||
insta-render:
|
||||
build:
|
||||
context: ./insta-render
|
||||
context: .
|
||||
dockerfile: insta-render/Dockerfile
|
||||
container_name: insta-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
@@ -26,7 +27,8 @@ services:
|
||||
|
||||
music-render:
|
||||
build:
|
||||
context: ./music-render
|
||||
context: .
|
||||
dockerfile: music-render/Dockerfile
|
||||
container_name: music-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
@@ -52,7 +54,8 @@ services:
|
||||
|
||||
video-render:
|
||||
build:
|
||||
context: ./video-render
|
||||
context: .
|
||||
dockerfile: video-render/Dockerfile
|
||||
container_name: video-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
@@ -98,7 +101,8 @@ services:
|
||||
|
||||
image-render:
|
||||
build:
|
||||
context: ./image-render
|
||||
context: .
|
||||
dockerfile: image-render/Dockerfile
|
||||
container_name: image-render
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
|
||||
@@ -7,10 +7,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
COPY image-render/requirements.txt /app/
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
# F6: 공통 ReliableQueue 모듈 (services/_shared)
|
||||
COPY _shared /app/_shared
|
||||
COPY image-render/. /app/
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
|
||||
5
services/image-render/conftest.py
Normal file
5
services/image-render/conftest.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Make services/ root importable so `from _shared.reliable_queue import ...` works during tests."""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
83
services/image-render/providers/flux_workflow.json
Normal file
83
services/image-render/providers/flux_workflow.json
Normal file
@@ -0,0 +1,83 @@
|
||||
{
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {"title": "Empty Latent Image"}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "%PROMPT%",
|
||||
"clip": ["11", 0]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {"title": "Positive Prompt"}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": ["13", 0],
|
||||
"vae": ["10", 0]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {"title": "VAE Decode"}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "flux",
|
||||
"images": ["8", 0]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {"title": "Save Image"}
|
||||
},
|
||||
"10": {
|
||||
"inputs": {
|
||||
"vae_name": "ae.safetensors"
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {"title": "Load VAE"}
|
||||
},
|
||||
"11": {
|
||||
"inputs": {
|
||||
"clip_name1": "clip_l.safetensors",
|
||||
"clip_name2": "t5xxl_fp8_e4m3fn.safetensors",
|
||||
"type": "flux"
|
||||
},
|
||||
"class_type": "DualCLIPLoader",
|
||||
"_meta": {"title": "Dual CLIP Loader"}
|
||||
},
|
||||
"12": {
|
||||
"inputs": {
|
||||
"unet_name": "flux1-schnell-fp8.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {"title": "Load Diffusion Model"}
|
||||
},
|
||||
"13": {
|
||||
"inputs": {
|
||||
"seed": 0,
|
||||
"steps": 4,
|
||||
"cfg": 1.0,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"denoise": 1.0,
|
||||
"model": ["12", 0],
|
||||
"positive": ["6", 0],
|
||||
"negative": ["33", 0],
|
||||
"latent_image": ["5", 0]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {"title": "KSampler"}
|
||||
},
|
||||
"33": {
|
||||
"inputs": {
|
||||
"text": "",
|
||||
"clip": ["11", 0]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {"title": "Negative Prompt (empty for Schnell)"}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,8 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import worker
|
||||
|
||||
|
||||
@@ -13,3 +18,52 @@ def test_dispatch_unknown_job_type_reports_failed(monkeypatch):
|
||||
monkeypatch.setattr(worker, "webhook_update_task", lambda *a, **k: calls.append((a, k)))
|
||||
worker._dispatch({"job_type": "midjourney_generation", "task_id": "t9", "params": {}})
|
||||
assert calls[-1][0][1] == "failed"
|
||||
|
||||
|
||||
# ----- F6: ReliableQueue poll_once -----
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_acks_on_success(monkeypatch):
|
||||
payload = {"task_id": "t1", "job_type": "gpt_image_generation", "params": {}}
|
||||
raw = json.dumps(payload).encode()
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=(payload, raw))
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
monkeypatch.setattr(worker, "_dispatch", MagicMock())
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is True
|
||||
fake_queue.ack.assert_awaited_once_with(raw)
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_calls_fail_on_dispatch_exception(monkeypatch):
|
||||
payload = {"task_id": "t2", "job_type": "gpt_image_generation", "params": {}}
|
||||
raw = json.dumps(payload).encode()
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=(payload, raw))
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
|
||||
def _boom(p):
|
||||
raise RuntimeError("dispatch crash")
|
||||
|
||||
monkeypatch.setattr(worker, "_dispatch", _boom)
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is True
|
||||
fake_queue.fail.assert_awaited_once_with(raw, payload)
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_returns_false_on_timeout(monkeypatch):
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=None)
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
monkeypatch.setattr(worker, "_dispatch", MagicMock())
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is False
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Redis BLPOP worker — queue:image-render → job_type dispatch → NAS webhook.
|
||||
"""Redis ReliableQueue worker — F6 신뢰성 패턴 (BLMOVE + ack/fail + recovery).
|
||||
|
||||
queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
|
||||
video-render worker.py 패턴 — string-based dispatch + getattr (테스트 patch 호환).
|
||||
string-based dispatch + getattr (테스트 patch 호환).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -17,6 +17,7 @@ from nas_client import webhook_update_task
|
||||
from providers.gpt_image import run_gpt_image_generation
|
||||
from providers.nano_banana import run_nano_banana_generation
|
||||
from providers.flux import run_flux_generation
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,25 +53,42 @@ def _dispatch(payload: dict) -> None:
|
||||
fn(task_id, params)
|
||||
|
||||
|
||||
async def poll_once(queue: ReliableQueue) -> bool:
|
||||
"""F6 — 1 cycle: dequeue → _dispatch → ack/fail. Returns True if a job handled."""
|
||||
result = await queue.dequeue(timeout=5)
|
||||
if result is None:
|
||||
return False
|
||||
payload, raw = result
|
||||
try:
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except Exception:
|
||||
logger.exception("dispatch unhandled exception task_id=%s",
|
||||
payload.get("task_id"))
|
||||
await queue.fail(raw, payload)
|
||||
return True
|
||||
await queue.ack(raw)
|
||||
return True
|
||||
|
||||
|
||||
async def worker_loop():
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
logger.info("image-render worker started (queue=%s)", QUEUE_KEY)
|
||||
queue = ReliableQueue(redis, queue_key=QUEUE_KEY)
|
||||
logger.info("image-render worker started worker_id=%s queue=%s",
|
||||
queue.worker_id, QUEUE_KEY)
|
||||
try:
|
||||
recovered = await queue.recover()
|
||||
if recovered:
|
||||
logger.info("recovered %d orphaned items at startup", recovered)
|
||||
except Exception:
|
||||
logger.exception("startup recover failed")
|
||||
|
||||
while True:
|
||||
try:
|
||||
paused = await redis.get(PAUSED_KEY)
|
||||
if paused == b"1":
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
item = await redis.blpop(QUEUE_KEY, timeout=5)
|
||||
if item is None:
|
||||
continue
|
||||
_, raw = item
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("invalid queue payload: %r", raw[:200])
|
||||
continue
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
await poll_once(queue)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("worker_loop cancelled")
|
||||
raise
|
||||
|
||||
@@ -12,11 +12,14 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libcairo2 libasound2 libatspi2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
COPY insta-render/requirements.txt /app/
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
RUN playwright install chromium
|
||||
|
||||
COPY . .
|
||||
# F6: 공통 ReliableQueue 모듈 (services/_shared)
|
||||
COPY _shared /app/_shared
|
||||
COPY insta-render/. /app/
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
|
||||
5
services/insta-render/conftest.py
Normal file
5
services/insta-render/conftest.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Make services/ root importable so `from _shared.reliable_queue import ...` works during tests."""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
@@ -112,11 +112,88 @@ async def test_process_one_render_failure_reports_failed(monkeypatch, fake_slate
|
||||
worker.NAS_BASE_URL = "http://nas.test"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
await worker._process_one(client, {
|
||||
"task_id": "t-3",
|
||||
"params": {"slate_id": 99},
|
||||
})
|
||||
# F6: _process_one은 webhook(failed) 호출 후 raise — poll_once가 fail(raw)로 retry/dead-letter.
|
||||
with pytest.raises(RuntimeError, match="Chromium"):
|
||||
await worker._process_one(client, {
|
||||
"task_id": "t-3",
|
||||
"params": {"slate_id": 99},
|
||||
})
|
||||
|
||||
last = calls[-1]
|
||||
assert last["status"] == "failed"
|
||||
assert "Chromium" in last["error"]
|
||||
|
||||
|
||||
# ----- F6: ReliableQueue (ack on success, fail on exception) -----
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_acks_on_success(monkeypatch):
|
||||
"""F6 — 성공 시 queue.ack(raw) 호출 + fail 안 부름."""
|
||||
fake_payload = {
|
||||
"task_id": "t-ok",
|
||||
"params": {"slate_id": 7, "theme": "default"},
|
||||
}
|
||||
fake_raw = json.dumps(fake_payload).encode()
|
||||
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=(fake_payload, fake_raw))
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
|
||||
process_mock = AsyncMock()
|
||||
monkeypatch.setattr(worker, "_process_one", process_mock)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
handled = await worker.poll_once(fake_queue, client)
|
||||
|
||||
assert handled is True
|
||||
process_mock.assert_awaited_once()
|
||||
fake_queue.ack.assert_awaited_once_with(fake_raw)
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_calls_fail_on_exception(monkeypatch):
|
||||
"""F6 — _process_one 예외 시 queue.fail(raw, payload) 호출."""
|
||||
fake_payload = {
|
||||
"task_id": "t-err",
|
||||
"params": {"slate_id": 9, "theme": "default"},
|
||||
}
|
||||
fake_raw = json.dumps(fake_payload).encode()
|
||||
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=(fake_payload, fake_raw))
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
|
||||
async def boom(client, payload):
|
||||
raise RuntimeError("simulated dispatch failure")
|
||||
|
||||
monkeypatch.setattr(worker, "_process_one", boom)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
handled = await worker.poll_once(fake_queue, client)
|
||||
|
||||
assert handled is True
|
||||
fake_queue.fail.assert_awaited_once_with(fake_raw, fake_payload)
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_returns_false_on_timeout(monkeypatch):
|
||||
"""F6 — dequeue가 None 반환(타임아웃)이면 False 리턴, ack/fail 안 부름."""
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=None)
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
|
||||
process_mock = AsyncMock()
|
||||
monkeypatch.setattr(worker, "_process_one", process_mock)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
handled = await worker.poll_once(fake_queue, client)
|
||||
|
||||
assert handled is False
|
||||
process_mock.assert_not_awaited()
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""Redis BLPOP worker — queue:insta-render → render_slate → NAS webhook.
|
||||
"""Redis ReliableQueue worker — F6 신뢰성 패턴 (BLMOVE + ack/fail + recovery).
|
||||
|
||||
queue:paused가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
@@ -14,6 +13,7 @@ import httpx
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from card_renderer import render_slate
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,7 +57,10 @@ async def _fetch_slate(client: httpx.AsyncClient, slate_id: int) -> dict:
|
||||
|
||||
|
||||
async def _process_one(client: httpx.AsyncClient, payload: dict) -> None:
|
||||
"""단일 작업 처리: fetch slate → render → webhook."""
|
||||
"""단일 작업 처리: fetch slate → render → webhook. 예외 발생 시 webhook(failed) 호출 후 raise.
|
||||
|
||||
F6: webhook 통신 외 예외는 poll_once가 fail(raw, payload)로 retry/dead-letter 처리.
|
||||
"""
|
||||
task_id = payload["task_id"]
|
||||
params = payload.get("params", {})
|
||||
slate_id = params.get("slate_id")
|
||||
@@ -69,7 +72,6 @@ async def _process_one(client: httpx.AsyncClient, payload: dict) -> None:
|
||||
slate = await _fetch_slate(client, slate_id)
|
||||
await _post_update(client, task_id, "processing", 50)
|
||||
paths = await render_slate(slate, slate_id, template=template)
|
||||
# 결과 URL은 첫 페이지의 nginx 경로
|
||||
first_url = f"{INSTA_MEDIA_URL_PREFIX}/{slate_id}/01.png"
|
||||
await _post_update(
|
||||
client, task_id, "succeeded", 100, result_path=first_url
|
||||
@@ -78,29 +80,46 @@ async def _process_one(client: httpx.AsyncClient, payload: dict) -> None:
|
||||
except Exception as e:
|
||||
logger.exception("render task=%s 실패", task_id)
|
||||
await _post_update(client, task_id, "failed", 0, error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
async def poll_once(queue: ReliableQueue, client: httpx.AsyncClient) -> bool:
|
||||
"""1 cycle: dequeue → _process_one → ack/fail. Returns True if a job handled."""
|
||||
result = await queue.dequeue(timeout=5)
|
||||
if result is None:
|
||||
return False
|
||||
payload, raw = result
|
||||
try:
|
||||
await _process_one(client, payload)
|
||||
except Exception:
|
||||
await queue.fail(raw, payload)
|
||||
return True
|
||||
await queue.ack(raw)
|
||||
return True
|
||||
|
||||
|
||||
async def worker_loop():
|
||||
"""무한 루프 — paused 체크 → BLPOP → process_one."""
|
||||
"""무한 루프 — paused 체크 → ReliableQueue.dequeue → process_one → ack/fail."""
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
queue = ReliableQueue(redis, queue_key=QUEUE_KEY)
|
||||
async with httpx.AsyncClient() as client:
|
||||
logger.info("insta-render worker started (queue=%s)", QUEUE_KEY)
|
||||
logger.info("insta-render worker started worker_id=%s queue=%s",
|
||||
queue.worker_id, QUEUE_KEY)
|
||||
# F6: startup recovery — 이전 crash 시 잔존 orphan 재큐
|
||||
try:
|
||||
recovered = await queue.recover()
|
||||
if recovered:
|
||||
logger.info("recovered %d orphaned items at startup", recovered)
|
||||
except Exception:
|
||||
logger.exception("startup recover failed")
|
||||
|
||||
while True:
|
||||
try:
|
||||
paused = await redis.get(PAUSED_KEY)
|
||||
if paused == b"1":
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
item = await redis.blpop(QUEUE_KEY, timeout=1)
|
||||
if item is None:
|
||||
continue
|
||||
_, raw = item
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("invalid queue payload: %r", raw[:200])
|
||||
continue
|
||||
await _process_one(client, payload)
|
||||
await poll_once(queue, client)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("worker_loop cancelled")
|
||||
raise
|
||||
|
||||
@@ -8,10 +8,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
COPY music-render/requirements.txt /app/
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
# F6: 공통 ReliableQueue 모듈 (services/_shared)
|
||||
COPY _shared /app/_shared
|
||||
COPY music-render/. /app/
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
|
||||
5
services/music-render/conftest.py
Normal file
5
services/music-render/conftest.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Make services/ root importable so `from _shared.reliable_queue import ...` works during tests."""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
@@ -107,3 +107,63 @@ def test_dispatch_add_instrumental_calls_run_add_instrumental():
|
||||
with patch("worker.run_add_instrumental") as m:
|
||||
worker._dispatch(payload)
|
||||
m.assert_called_once_with("t13", {"upload_url": "u"})
|
||||
|
||||
|
||||
# ----- F6: ReliableQueue poll_once -----
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_acks_on_success(monkeypatch):
|
||||
"""F6 — _dispatch 정상 return → queue.ack(raw)."""
|
||||
payload = {"task_id": "t1", "job_type": "suno_generation", "params": {}}
|
||||
raw = json.dumps(payload).encode()
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=(payload, raw))
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
|
||||
monkeypatch.setattr(worker, "_dispatch", MagicMock())
|
||||
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is True
|
||||
fake_queue.ack.assert_awaited_once_with(raw)
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_calls_fail_on_dispatch_exception(monkeypatch):
|
||||
"""F6 — _dispatch unhandled exception → queue.fail(raw, payload)."""
|
||||
payload = {"task_id": "t2", "job_type": "suno_generation", "params": {}}
|
||||
raw = json.dumps(payload).encode()
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=(payload, raw))
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
|
||||
def _boom(p):
|
||||
raise RuntimeError("dispatch crash")
|
||||
|
||||
monkeypatch.setattr(worker, "_dispatch", _boom)
|
||||
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is True
|
||||
fake_queue.fail.assert_awaited_once_with(raw, payload)
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_returns_false_on_timeout(monkeypatch):
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=None)
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
dispatch_mock = MagicMock()
|
||||
monkeypatch.setattr(worker, "_dispatch", dispatch_mock)
|
||||
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is False
|
||||
dispatch_mock.assert_not_called()
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Redis BLPOP worker — queue:music-render → job_type 디스패치 → NAS webhook.
|
||||
"""Redis ReliableQueue worker — F6 신뢰성 패턴 (BLMOVE + ack/fail + recovery).
|
||||
|
||||
queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
|
||||
"""
|
||||
@@ -20,6 +20,7 @@ from providers.suno import (
|
||||
run_add_instrumental, run_video_generate,
|
||||
)
|
||||
from providers.local import run_local_generation
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -67,26 +68,44 @@ def _dispatch(payload: dict) -> None:
|
||||
fn(task_id, params)
|
||||
|
||||
|
||||
async def poll_once(queue: ReliableQueue) -> bool:
|
||||
"""F6 — 1 cycle: dequeue → _dispatch → ack/fail. Returns True if a job handled."""
|
||||
result = await queue.dequeue(timeout=5)
|
||||
if result is None:
|
||||
return False
|
||||
payload, raw = result
|
||||
try:
|
||||
# sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except Exception:
|
||||
logger.exception("dispatch unhandled exception task_id=%s",
|
||||
payload.get("task_id"))
|
||||
await queue.fail(raw, payload)
|
||||
return True
|
||||
await queue.ack(raw)
|
||||
return True
|
||||
|
||||
|
||||
async def worker_loop():
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
logger.info("music-render worker started (queue=%s)", QUEUE_KEY)
|
||||
queue = ReliableQueue(redis, queue_key=QUEUE_KEY)
|
||||
logger.info("music-render worker started worker_id=%s queue=%s",
|
||||
queue.worker_id, QUEUE_KEY)
|
||||
# F6: startup recovery
|
||||
try:
|
||||
recovered = await queue.recover()
|
||||
if recovered:
|
||||
logger.info("recovered %d orphaned items at startup", recovered)
|
||||
except Exception:
|
||||
logger.exception("startup recover failed")
|
||||
|
||||
while True:
|
||||
try:
|
||||
paused = await redis.get(PAUSED_KEY)
|
||||
if paused == b"1":
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
item = await redis.blpop(QUEUE_KEY, timeout=1)
|
||||
if item is None:
|
||||
continue
|
||||
_, raw = item
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("invalid queue payload: %r", raw[:200])
|
||||
continue
|
||||
# sync provider 함수 — thread로 실행해서 이벤트 루프 블로킹 방지
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
await poll_once(queue)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("worker_loop cancelled")
|
||||
raise
|
||||
|
||||
@@ -7,10 +7,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
COPY video-render/requirements.txt /app/
|
||||
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
# F6: 공통 ReliableQueue 모듈 (services/_shared)
|
||||
COPY _shared /app/_shared
|
||||
COPY video-render/. /app/
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
|
||||
|
||||
5
services/video-render/conftest.py
Normal file
5
services/video-render/conftest.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Make services/ root importable so `from _shared.reliable_queue import ...` works during tests."""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
@@ -41,3 +41,56 @@ def test_dispatch_unknown_job_type_logs_error():
|
||||
args = m.call_args[0]
|
||||
assert args[0] == "t5"
|
||||
assert args[1] == "failed"
|
||||
|
||||
|
||||
# ----- F6: ReliableQueue poll_once -----
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_acks_on_success(monkeypatch):
|
||||
payload = {"task_id": "t1", "job_type": "sora_generation", "params": {}}
|
||||
raw = json.dumps(payload).encode()
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=(payload, raw))
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
monkeypatch.setattr(worker, "_dispatch", MagicMock())
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is True
|
||||
fake_queue.ack.assert_awaited_once_with(raw)
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_calls_fail_on_dispatch_exception(monkeypatch):
|
||||
payload = {"task_id": "t2", "job_type": "sora_generation", "params": {}}
|
||||
raw = json.dumps(payload).encode()
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=(payload, raw))
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
|
||||
def _boom(p):
|
||||
raise RuntimeError("dispatch crash")
|
||||
|
||||
monkeypatch.setattr(worker, "_dispatch", _boom)
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is True
|
||||
fake_queue.fail.assert_awaited_once_with(raw, payload)
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_returns_false_on_timeout(monkeypatch):
|
||||
fake_queue = AsyncMock()
|
||||
fake_queue.dequeue = AsyncMock(return_value=None)
|
||||
fake_queue.ack = AsyncMock()
|
||||
fake_queue.fail = AsyncMock()
|
||||
monkeypatch.setattr(worker, "_dispatch", MagicMock())
|
||||
handled = await worker.poll_once(fake_queue)
|
||||
assert handled is False
|
||||
fake_queue.ack.assert_not_awaited()
|
||||
fake_queue.fail.assert_not_awaited()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Redis BLPOP worker — queue:video-render → job_type 디스패치 → NAS webhook.
|
||||
"""Redis ReliableQueue worker — F6 신뢰성 패턴 (BLMOVE + ack/fail + recovery).
|
||||
|
||||
queue:paused 가 set이면 대기 (task-watcher가 박재오 활동 감지 시 set).
|
||||
Plan-B-Music worker.py 패턴 — string-based dispatch + getattr (테스트 patch 호환).
|
||||
string-based dispatch + getattr (테스트 patch 호환).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -18,6 +18,7 @@ from providers.sora import run_sora_generation
|
||||
from providers.veo import run_veo_generation
|
||||
from providers.kling import run_kling_generation
|
||||
from providers.seedance import run_seedance_generation
|
||||
from _shared.reliable_queue import ReliableQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,25 +54,42 @@ def _dispatch(payload: dict) -> None:
|
||||
fn(task_id, params)
|
||||
|
||||
|
||||
async def poll_once(queue: ReliableQueue) -> bool:
|
||||
"""F6 — 1 cycle: dequeue → _dispatch → ack/fail. Returns True if a job handled."""
|
||||
result = await queue.dequeue(timeout=5)
|
||||
if result is None:
|
||||
return False
|
||||
payload, raw = result
|
||||
try:
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
except Exception:
|
||||
logger.exception("dispatch unhandled exception task_id=%s",
|
||||
payload.get("task_id"))
|
||||
await queue.fail(raw, payload)
|
||||
return True
|
||||
await queue.ack(raw)
|
||||
return True
|
||||
|
||||
|
||||
async def worker_loop():
|
||||
redis = aioredis.from_url(REDIS_URL, decode_responses=False)
|
||||
logger.info("video-render worker started (queue=%s)", QUEUE_KEY)
|
||||
queue = ReliableQueue(redis, queue_key=QUEUE_KEY)
|
||||
logger.info("video-render worker started worker_id=%s queue=%s",
|
||||
queue.worker_id, QUEUE_KEY)
|
||||
try:
|
||||
recovered = await queue.recover()
|
||||
if recovered:
|
||||
logger.info("recovered %d orphaned items at startup", recovered)
|
||||
except Exception:
|
||||
logger.exception("startup recover failed")
|
||||
|
||||
while True:
|
||||
try:
|
||||
paused = await redis.get(PAUSED_KEY)
|
||||
if paused == b"1":
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
item = await redis.blpop(QUEUE_KEY, timeout=1)
|
||||
if item is None:
|
||||
continue
|
||||
_, raw = item
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("invalid queue payload: %r", raw[:200])
|
||||
continue
|
||||
await asyncio.to_thread(_dispatch, payload)
|
||||
await poll_once(queue)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("worker_loop cancelled")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user