feat(screener): ai_news pipeline (top-100 parallel, fail-soft, upsert)
This commit is contained in:
150
stock-lab/app/screener/ai_news/pipeline.py
Normal file
150
stock-lab/app/screener/ai_news/pipeline.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""ai_news refresh pipeline — 시총 상위 N종목 병렬 처리."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from . import scraper as _scraper
|
||||
from . import analyzer as _analyzer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TOP_N = 100
|
||||
DEFAULT_CONCURRENCY = 10
|
||||
DEFAULT_NEWS_PER_TICKER = 5
|
||||
DEFAULT_RATE_LIMIT_SEC = 0.2
|
||||
|
||||
|
||||
def _top_market_cap_tickers(conn: sqlite3.Connection, n: int) -> List[str]:
|
||||
rows = conn.execute(
|
||||
"SELECT ticker FROM krx_master "
|
||||
"WHERE market_cap IS NOT NULL AND is_preferred=0 AND is_spac=0 "
|
||||
"ORDER BY market_cap DESC LIMIT ?",
|
||||
(n,),
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
|
||||
def _make_http():
|
||||
return httpx.AsyncClient(timeout=10.0, headers=_scraper.NAVER_HEADERS)
|
||||
|
||||
|
||||
def _make_llm():
|
||||
"""Anthropic AsyncClient — env에 ANTHROPIC_API_KEY 필수."""
|
||||
from anthropic import AsyncAnthropic
|
||||
return AsyncAnthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
|
||||
|
||||
|
||||
async def _process_one(
|
||||
ticker: str, name: str, sem: asyncio.Semaphore,
|
||||
http_client, llm, news_per_ticker: int, rate_limit_sec: float, model: str,
|
||||
) -> Dict[str, Any]:
|
||||
async with sem:
|
||||
if rate_limit_sec > 0:
|
||||
await asyncio.sleep(rate_limit_sec)
|
||||
news = await _scraper.fetch_news(http_client, ticker, n=news_per_ticker)
|
||||
if not news:
|
||||
return {
|
||||
"ticker": ticker, "score_raw": 0.0, "reason": "no news",
|
||||
"news_count": 0, "tokens_input": 0, "tokens_output": 0,
|
||||
"model": model,
|
||||
}
|
||||
return await _analyzer.score_sentiment(
|
||||
llm, ticker, news, name=name, model=model,
|
||||
)
|
||||
|
||||
|
||||
def _upsert_news_sentiment(
|
||||
conn: sqlite3.Connection, asof: dt.date, rows: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
iso = asof.isoformat()
|
||||
data = [
|
||||
(
|
||||
r["ticker"], iso, r["score_raw"], r["reason"], r["news_count"],
|
||||
r["tokens_input"], r["tokens_output"], r["model"],
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
conn.executemany(
|
||||
"""INSERT INTO news_sentiment
|
||||
(ticker, date, score_raw, reason, news_count,
|
||||
tokens_input, tokens_output, model)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ticker, date) DO UPDATE SET
|
||||
score_raw=excluded.score_raw,
|
||||
reason=excluded.reason,
|
||||
news_count=excluded.news_count,
|
||||
tokens_input=excluded.tokens_input,
|
||||
tokens_output=excluded.tokens_output,
|
||||
model=excluded.model
|
||||
""",
|
||||
data,
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
async def refresh_daily(
|
||||
conn: sqlite3.Connection,
|
||||
asof: dt.date,
|
||||
*,
|
||||
top_n: int = DEFAULT_TOP_N,
|
||||
concurrency: int = DEFAULT_CONCURRENCY,
|
||||
news_per_ticker: int = DEFAULT_NEWS_PER_TICKER,
|
||||
rate_limit_sec: float = DEFAULT_RATE_LIMIT_SEC,
|
||||
model: str = _analyzer.DEFAULT_MODEL,
|
||||
) -> Dict[str, Any]:
|
||||
"""Returns summary dict with top_pos/top_neg/token totals/failures."""
|
||||
started = time.time()
|
||||
tickers = _top_market_cap_tickers(conn, n=top_n)
|
||||
name_map = {
|
||||
r[0]: r[1] for r in conn.execute(
|
||||
f"SELECT ticker, name FROM krx_master WHERE ticker IN "
|
||||
f"({','.join('?' * len(tickers))})", tickers,
|
||||
).fetchall()
|
||||
} if tickers else {}
|
||||
|
||||
sem = asyncio.Semaphore(concurrency)
|
||||
|
||||
async with _make_http() as http_client, _make_llm() as llm:
|
||||
tasks = [
|
||||
_process_one(
|
||||
t, name_map.get(t, t), sem, http_client, llm,
|
||||
news_per_ticker, rate_limit_sec, model,
|
||||
)
|
||||
for t in tickers
|
||||
]
|
||||
raw_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
successes: List[Dict[str, Any]] = []
|
||||
failures: List[str] = []
|
||||
for r in raw_results:
|
||||
if isinstance(r, BaseException):
|
||||
failures.append(repr(r))
|
||||
elif isinstance(r, dict):
|
||||
successes.append(r)
|
||||
|
||||
if successes:
|
||||
_upsert_news_sentiment(conn, asof, successes)
|
||||
|
||||
top_pos = sorted(successes, key=lambda r: -r["score_raw"])[:5]
|
||||
top_neg = sorted(successes, key=lambda r: r["score_raw"])[:5]
|
||||
|
||||
return {
|
||||
"asof": asof.isoformat(),
|
||||
"updated": len(successes),
|
||||
"failures": failures,
|
||||
"duration_sec": round(time.time() - started, 2),
|
||||
"tokens_input": sum(r["tokens_input"] for r in successes),
|
||||
"tokens_output": sum(r["tokens_output"] for r in successes),
|
||||
"top_pos": top_pos,
|
||||
"top_neg": top_neg,
|
||||
"model": model,
|
||||
}
|
||||
110
stock-lab/tests/test_ai_news_pipeline.py
Normal file
110
stock-lab/tests/test_ai_news_pipeline.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import datetime as dt
|
||||
import sqlite3
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from app.screener.ai_news import pipeline
|
||||
from app.screener.schema import ensure_screener_schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn():
|
||||
c = sqlite3.connect(":memory:")
|
||||
c.row_factory = sqlite3.Row
|
||||
ensure_screener_schema(c)
|
||||
# 시총 상위 3종목 시드
|
||||
c.execute("INSERT INTO krx_master (ticker, name, market, market_cap, updated_at) "
|
||||
"VALUES (?, ?, 'KOSPI', ?, datetime('now'))", ("005930", "삼성전자", 9_000_000))
|
||||
c.execute("INSERT INTO krx_master (ticker, name, market, market_cap, updated_at) "
|
||||
"VALUES (?, ?, 'KOSPI', ?, datetime('now'))", ("000660", "SK하이닉스", 8_000_000))
|
||||
c.execute("INSERT INTO krx_master (ticker, name, market, market_cap, updated_at) "
|
||||
"VALUES (?, ?, 'KOSPI', ?, datetime('now'))", ("373220", "LG에너지솔루션", 7_000_000))
|
||||
c.commit()
|
||||
yield c
|
||||
c.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_daily_happy_path(conn):
|
||||
"""3종목 mini integration — 각 종목별로 scraper/analyzer mock."""
|
||||
asof = dt.date(2026, 5, 13)
|
||||
fake_news = [{"title": "헤드라인"}]
|
||||
|
||||
async def fake_fetch(client, ticker, n):
|
||||
return fake_news
|
||||
|
||||
scores_by_ticker = {
|
||||
"005930": 7.5, "000660": 4.0, "373220": -6.0,
|
||||
}
|
||||
async def fake_score(llm, ticker, news, *, name=None, model="m"):
|
||||
return {
|
||||
"ticker": ticker, "score_raw": scores_by_ticker[ticker],
|
||||
"reason": f"r{ticker}", "news_count": 1,
|
||||
"tokens_input": 100, "tokens_output": 20, "model": model,
|
||||
}
|
||||
|
||||
with patch.object(pipeline, "_scraper") as ms, \
|
||||
patch.object(pipeline, "_analyzer") as ma, \
|
||||
patch.object(pipeline, "_make_llm") as ml, \
|
||||
patch.object(pipeline, "_make_http") as mh:
|
||||
ms.fetch_news = fake_fetch
|
||||
ma.score_sentiment = fake_score
|
||||
ml.return_value.__aenter__.return_value = AsyncMock()
|
||||
ml.return_value.__aexit__.return_value = None
|
||||
mh.return_value.__aenter__.return_value = AsyncMock()
|
||||
mh.return_value.__aexit__.return_value = None
|
||||
result = await pipeline.refresh_daily(conn, asof, concurrency=3, rate_limit_sec=0)
|
||||
|
||||
assert result["asof"] == "2026-05-13"
|
||||
assert result["updated"] == 3
|
||||
assert result["failures"] == []
|
||||
assert len(result["top_pos"]) == 3
|
||||
assert result["top_pos"][0]["ticker"] == "005930" # 가장 큰 점수
|
||||
assert result["top_neg"][0]["ticker"] == "373220" # 가장 작은 점수
|
||||
assert result["tokens_input"] == 300
|
||||
assert result["tokens_output"] == 60
|
||||
|
||||
# DB upsert 확인
|
||||
rows = conn.execute("SELECT ticker, score_raw FROM news_sentiment WHERE date=?",
|
||||
("2026-05-13",)).fetchall()
|
||||
assert len(rows) == 3
|
||||
by_ticker = {r["ticker"]: r["score_raw"] for r in rows}
|
||||
assert by_ticker["005930"] == 7.5
|
||||
assert by_ticker["373220"] == -6.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_daily_failures_isolated(conn):
|
||||
"""한 종목이 예외 던져도 나머지 종목은 정상 처리."""
|
||||
asof = dt.date(2026, 5, 13)
|
||||
|
||||
async def fake_fetch(client, ticker, n):
|
||||
return [{"title": "h"}]
|
||||
|
||||
async def fake_score(llm, ticker, news, *, name=None, model="m"):
|
||||
if ticker == "000660":
|
||||
raise RuntimeError("llm exploded")
|
||||
return {
|
||||
"ticker": ticker, "score_raw": 5.0, "reason": "r", "news_count": 1,
|
||||
"tokens_input": 100, "tokens_output": 20, "model": model,
|
||||
}
|
||||
|
||||
with patch.object(pipeline, "_scraper") as ms, \
|
||||
patch.object(pipeline, "_analyzer") as ma, \
|
||||
patch.object(pipeline, "_make_llm") as ml, \
|
||||
patch.object(pipeline, "_make_http") as mh:
|
||||
ms.fetch_news = fake_fetch
|
||||
ma.score_sentiment = fake_score
|
||||
ml.return_value.__aenter__.return_value = AsyncMock()
|
||||
ml.return_value.__aexit__.return_value = None
|
||||
mh.return_value.__aenter__.return_value = AsyncMock()
|
||||
mh.return_value.__aexit__.return_value = None
|
||||
result = await pipeline.refresh_daily(conn, asof, concurrency=3, rate_limit_sec=0)
|
||||
|
||||
assert result["updated"] == 2
|
||||
assert len(result["failures"]) == 1
|
||||
|
||||
|
||||
def test_top_market_cap_tickers(conn):
|
||||
out = pipeline._top_market_cap_tickers(conn, n=2)
|
||||
assert out == ["005930", "000660"]
|
||||
Reference in New Issue
Block a user