Compare commits
16 Commits
d6081ba2d3
...
feat/secur
| Author | SHA1 | Date | |
|---|---|---|---|
| 49c5c57be5 | |||
| 6053e69afc | |||
| 1e5e1bcdff | |||
| 64fbbb7958 | |||
| cfbb72051f | |||
| bf5897fc85 | |||
| ad6c744f2c | |||
| aad9bfbe8b | |||
| 42bd53ee7b | |||
| 86694ae4fe | |||
| 41225b3337 | |||
| 6bb5c2fb40 | |||
| bd1773e29e | |||
| 685320f3cf | |||
| b3982c8f72 | |||
| 002c0893f8 |
@@ -51,9 +51,14 @@ PGID=1000
|
|||||||
# Windows AI Server (NAS 입장에서 바라본 Windows PC IP)
|
# Windows AI Server (NAS 입장에서 바라본 Windows PC IP)
|
||||||
WINDOWS_AI_SERVER_URL=http://192.168.45.59:8000
|
WINDOWS_AI_SERVER_URL=http://192.168.45.59:8000
|
||||||
|
|
||||||
# Admin API Key (trade/order 등 민감 엔드포인트 보호, 미설정 시 인증 비활성화)
|
# Admin API Key — /api/trade/* 등 민감 엔드포인트 보호.
|
||||||
|
# 운영 .env에는 반드시 값을 채워야 함. 빈 값이면 503 응답으로 거부됨 (CODE_REVIEW F2).
|
||||||
ADMIN_API_KEY=
|
ADMIN_API_KEY=
|
||||||
|
|
||||||
|
# 개발 모드: 위 ADMIN_API_KEY 비워둔 채로 trade/admin 엔드포인트 호출 허용.
|
||||||
|
# 운영 환경에서는 절대 true로 두지 말 것. 기본 false (보호 활성).
|
||||||
|
ALLOW_UNAUTHENTICATED_ADMIN=false
|
||||||
|
|
||||||
# Anthropic API Key (AI Coach 프록시 + 뉴스 요약 Claude provider)
|
# Anthropic API Key (AI Coach 프록시 + 뉴스 요약 Claude provider)
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
ANTHROPIC_MODEL=claude-haiku-4-5-20251001
|
ANTHROPIC_MODEL=claude-haiku-4-5-20251001
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ class InstaAgent(BaseAgent):
|
|||||||
requires_approval=False)
|
requires_approval=False)
|
||||||
await self.transition("working", "뉴스 수집·키워드 추출", task_id)
|
await self.transition("working", "뉴스 수집·키워드 추출", task_id)
|
||||||
try:
|
try:
|
||||||
|
prefs = await service_proxy.insta_get_preferences()
|
||||||
|
add_log(self.agent_id, f"insta preferences: {prefs}", "info", task_id)
|
||||||
await self._run_collect_and_extract()
|
await self._run_collect_and_extract()
|
||||||
kws = await service_proxy.insta_list_keywords(used=False)
|
kws = await service_proxy.insta_list_keywords(used=False)
|
||||||
if auto_select:
|
if auto_select:
|
||||||
@@ -147,6 +149,12 @@ class InstaAgent(BaseAgent):
|
|||||||
return {"ok": False, "message": "keyword_id 필수"}
|
return {"ok": False, "message": "keyword_id 필수"}
|
||||||
await self._render_and_push(kid)
|
await self._render_and_push(kid)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
if command == "collect_trends":
|
||||||
|
await messaging.send_raw("🌐 외부 트렌드 수집 시작")
|
||||||
|
created = await service_proxy.insta_collect_trends()
|
||||||
|
st = await self._wait_task(created["task_id"], step="trends_collect", timeout_sec=300)
|
||||||
|
await messaging.send_raw(f"✅ 트렌드 수집 완료: {st.get('message', '')}")
|
||||||
|
return {"ok": True, "result": st}
|
||||||
return {"ok": False, "message": f"Unknown command: {command}"}
|
return {"ok": False, "message": f"Unknown command: {command}"}
|
||||||
|
|
||||||
async def on_callback(self, action: str, params: dict) -> dict:
|
async def on_callback(self, action: str, params: dict) -> dict:
|
||||||
|
|||||||
@@ -29,6 +29,12 @@ async def _run_insta_schedule():
|
|||||||
if agent:
|
if agent:
|
||||||
await agent.on_schedule()
|
await agent.on_schedule()
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_insta_trends_collect():
|
||||||
|
agent = AGENT_REGISTRY.get("insta")
|
||||||
|
if agent:
|
||||||
|
await agent.on_command("collect_trends", {})
|
||||||
|
|
||||||
async def _run_lotto_schedule():
|
async def _run_lotto_schedule():
|
||||||
agent = AGENT_REGISTRY.get("lotto")
|
agent = AGENT_REGISTRY.get("lotto")
|
||||||
if agent:
|
if agent:
|
||||||
@@ -68,6 +74,7 @@ def init_scheduler():
|
|||||||
id="stock_ai_news_sentiment",
|
id="stock_ai_news_sentiment",
|
||||||
)
|
)
|
||||||
scheduler.add_job(_run_insta_schedule, "cron", hour=9, minute=30, id="insta_pipeline")
|
scheduler.add_job(_run_insta_schedule, "cron", hour=9, minute=30, id="insta_pipeline")
|
||||||
|
scheduler.add_job(_run_insta_trends_collect, "cron", hour=9, minute=0, id="insta_trends_collect")
|
||||||
scheduler.add_job(_run_lotto_schedule, "cron", day_of_week="mon", hour=9, minute=0, id="lotto_curate")
|
scheduler.add_job(_run_lotto_schedule, "cron", day_of_week="mon", hour=9, minute=0, id="lotto_curate")
|
||||||
scheduler.add_job(_run_youtube_research, "cron", hour=9, minute=0, id="youtube_research")
|
scheduler.add_job(_run_youtube_research, "cron", hour=9, minute=0, id="youtube_research")
|
||||||
scheduler.add_job(_send_youtube_weekly_report, "cron", day_of_week="mon", hour=8, minute=0, id="youtube_weekly_report")
|
scheduler.add_job(_send_youtube_weekly_report, "cron", day_of_week="mon", hour=8, minute=0, id="youtube_weekly_report")
|
||||||
|
|||||||
@@ -167,6 +167,41 @@ async def insta_get_asset_bytes(slate_id: int, page: int) -> bytes:
|
|||||||
return resp.content
|
return resp.content
|
||||||
|
|
||||||
|
|
||||||
|
async def insta_collect_trends(categories: Optional[list] = None) -> Dict[str, Any]:
|
||||||
|
payload = {"categories": categories} if categories else {}
|
||||||
|
resp = await _client.post(f"{INSTA_LAB_URL}/api/insta/trends/collect", json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def insta_list_trends(source: Optional[str] = None,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
days: int = 1) -> List[Dict[str, Any]]:
|
||||||
|
params: Dict[str, Any] = {"days": days}
|
||||||
|
if source:
|
||||||
|
params["source"] = source
|
||||||
|
if category:
|
||||||
|
params["category"] = category
|
||||||
|
resp = await _client.get(f"{INSTA_LAB_URL}/api/insta/trends", params=params)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json().get("items", [])
|
||||||
|
|
||||||
|
|
||||||
|
async def insta_get_preferences() -> Dict[str, float]:
|
||||||
|
resp = await _client.get(f"{INSTA_LAB_URL}/api/insta/preferences")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return {p["category"]: p["weight"] for p in resp.json().get("categories", [])}
|
||||||
|
|
||||||
|
|
||||||
|
async def insta_put_preferences(weights: Dict[str, float]) -> Dict[str, Any]:
|
||||||
|
resp = await _client.put(
|
||||||
|
f"{INSTA_LAB_URL}/api/insta/preferences",
|
||||||
|
json={"categories": weights},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
# --- realestate-lab ---
|
# --- realestate-lab ---
|
||||||
|
|
||||||
async def realestate_collect() -> Dict[str, Any]:
|
async def realestate_collect() -> Dict[str, Any]:
|
||||||
|
|||||||
73
agent-office/tests/test_insta_agent_trends.py
Normal file
73
agent-office/tests/test_insta_agent_trends.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
_fd, _TMP = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(_fd)
|
||||||
|
os.unlink(_TMP)
|
||||||
|
os.environ["AGENT_OFFICE_DB_PATH"] = _TMP
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.insta import InstaAgent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _init_db():
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
if os.path.exists(_TMP):
|
||||||
|
os.remove(_TMP)
|
||||||
|
from app.db import init_db
|
||||||
|
init_db()
|
||||||
|
yield
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_command_collect_trends_dispatches(monkeypatch):
|
||||||
|
agent = InstaAgent()
|
||||||
|
fake_collect = AsyncMock(return_value={"task_id": "tcollect"})
|
||||||
|
fake_status = AsyncMock(return_value={"status": "succeeded", "result_id": 8,
|
||||||
|
"message": "naver:5, google:3"})
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.agents.insta.service_proxy.insta_collect_trends", fake_collect)
|
||||||
|
monkeypatch.setattr("app.agents.insta.service_proxy.insta_task_status", fake_status)
|
||||||
|
monkeypatch.setattr("app.agents.insta.messaging.send_raw", AsyncMock(return_value={"ok": True}))
|
||||||
|
|
||||||
|
result = await agent.on_command("collect_trends", {})
|
||||||
|
assert result["ok"] is True
|
||||||
|
fake_collect.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_schedule_loads_preferences(monkeypatch):
|
||||||
|
"""on_schedule이 preferences를 가져오는지 확인."""
|
||||||
|
agent = InstaAgent()
|
||||||
|
|
||||||
|
fake_collect = AsyncMock(return_value={"task_id": "t1"})
|
||||||
|
fake_extract = AsyncMock(return_value={"task_id": "t2"})
|
||||||
|
fake_status = AsyncMock(side_effect=[
|
||||||
|
{"status": "succeeded", "result_id": 0},
|
||||||
|
{"status": "succeeded", "result_id": 0},
|
||||||
|
])
|
||||||
|
fake_keywords = AsyncMock(return_value=[
|
||||||
|
{"id": 1, "keyword": "K", "category": "economy", "score": 0.9},
|
||||||
|
])
|
||||||
|
fake_prefs = AsyncMock(return_value={"economy": 0.6, "psychology": 0.4})
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.agents.insta.service_proxy.insta_collect", fake_collect)
|
||||||
|
monkeypatch.setattr("app.agents.insta.service_proxy.insta_extract", fake_extract)
|
||||||
|
monkeypatch.setattr("app.agents.insta.service_proxy.insta_task_status", fake_status)
|
||||||
|
monkeypatch.setattr("app.agents.insta.service_proxy.insta_list_keywords", fake_keywords)
|
||||||
|
monkeypatch.setattr("app.agents.insta.service_proxy.insta_get_preferences", fake_prefs)
|
||||||
|
monkeypatch.setattr("app.agents.insta.messaging.send_raw", AsyncMock(return_value={"ok": True}))
|
||||||
|
|
||||||
|
agent.state = "idle"
|
||||||
|
await agent.on_schedule()
|
||||||
|
|
||||||
|
fake_prefs.assert_awaited()
|
||||||
@@ -100,6 +100,7 @@ services:
|
|||||||
- ANTHROPIC_MODEL_SONNET=${ANTHROPIC_MODEL_SONNET:-claude-sonnet-4-6}
|
- ANTHROPIC_MODEL_SONNET=${ANTHROPIC_MODEL_SONNET:-claude-sonnet-4-6}
|
||||||
- NAVER_CLIENT_ID=${NAVER_CLIENT_ID:-}
|
- NAVER_CLIENT_ID=${NAVER_CLIENT_ID:-}
|
||||||
- NAVER_CLIENT_SECRET=${NAVER_CLIENT_SECRET:-}
|
- NAVER_CLIENT_SECRET=${NAVER_CLIENT_SECRET:-}
|
||||||
|
- YOUTUBE_DATA_API_KEY=${YOUTUBE_DATA_API_KEY:-}
|
||||||
- INSTA_DATA_PATH=/app/data
|
- INSTA_DATA_PATH=/app/data
|
||||||
- CARD_TEMPLATE_DIR=/app/app/templates
|
- CARD_TEMPLATE_DIR=/app/app/templates
|
||||||
- CORS_ALLOW_ORIGINS=${CORS_ALLOW_ORIGINS:-http://localhost:3007,http://localhost:8080}
|
- CORS_ALLOW_ORIGINS=${CORS_ALLOW_ORIGINS:-http://localhost:3007,http://localhost:8080}
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
# --timeout 600 --retries 5: NAS 느린 네트워크/CPU에서 pip 다운로드 timeout 방지
|
||||||
|
RUN pip install --no-cache-dir --timeout 600 --retries 5 -r requirements.txt
|
||||||
RUN playwright install chromium
|
RUN playwright install chromium
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
|
|
||||||
NAVER_CLIENT_ID = os.getenv("NAVER_CLIENT_ID", "")
|
NAVER_CLIENT_ID = os.getenv("NAVER_CLIENT_ID", "")
|
||||||
NAVER_CLIENT_SECRET = os.getenv("NAVER_CLIENT_SECRET", "")
|
NAVER_CLIENT_SECRET = os.getenv("NAVER_CLIENT_SECRET", "")
|
||||||
|
YOUTUBE_DATA_API_KEY = os.getenv("YOUTUBE_DATA_API_KEY", "")
|
||||||
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
||||||
ANTHROPIC_MODEL_HAIKU = os.getenv("ANTHROPIC_MODEL_HAIKU", "claude-haiku-4-5-20251001")
|
ANTHROPIC_MODEL_HAIKU = os.getenv("ANTHROPIC_MODEL_HAIKU", "claude-haiku-4-5-20251001")
|
||||||
ANTHROPIC_MODEL_SONNET = os.getenv("ANTHROPIC_MODEL_SONNET", "claude-sonnet-4-6")
|
ANTHROPIC_MODEL_SONNET = os.getenv("ANTHROPIC_MODEL_SONNET", "claude-sonnet-4-6")
|
||||||
|
|||||||
@@ -101,6 +101,29 @@ def init_db() -> None:
|
|||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# source column for trending_keywords (idempotent ALTER)
|
||||||
|
cols = [r[1] for r in conn.execute("PRAGMA table_info(trending_keywords)").fetchall()]
|
||||||
|
if "source" not in cols:
|
||||||
|
conn.execute("ALTER TABLE trending_keywords ADD COLUMN source TEXT NOT NULL DEFAULT 'manual'")
|
||||||
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_tk_source ON trending_keywords(source, suggested_at DESC)")
|
||||||
|
|
||||||
|
# account_preferences — 카테고리 가중치
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS account_preferences (
|
||||||
|
category TEXT PRIMARY KEY,
|
||||||
|
weight REAL NOT NULL DEFAULT 1.0,
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
# seed defaults if table empty
|
||||||
|
existing = conn.execute("SELECT COUNT(*) FROM account_preferences").fetchone()[0]
|
||||||
|
if existing == 0:
|
||||||
|
for cat in ("economy", "psychology", "celebrity"):
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO account_preferences(category, weight) VALUES(?,?)",
|
||||||
|
(cat, 1.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── news_articles ────────────────────────────────────────────────
|
# ── news_articles ────────────────────────────────────────────────
|
||||||
def add_news_article(row: Dict[str, Any]) -> int:
|
def add_news_article(row: Dict[str, Any]) -> int:
|
||||||
@@ -132,8 +155,12 @@ def list_news_articles(category: Optional[str] = None, days: int = 1) -> List[Di
|
|||||||
def add_trending_keyword(row: Dict[str, Any]) -> int:
|
def add_trending_keyword(row: Dict[str, Any]) -> int:
|
||||||
with _conn() as conn:
|
with _conn() as conn:
|
||||||
cur = conn.execute(
|
cur = conn.execute(
|
||||||
"INSERT INTO trending_keywords(keyword, category, score, articles_count) VALUES(?,?,?,?)",
|
"INSERT INTO trending_keywords(keyword, category, score, articles_count, source) VALUES(?,?,?,?,?)",
|
||||||
(row["keyword"], row["category"], float(row.get("score", 0.0)), int(row.get("articles_count", 0))),
|
(
|
||||||
|
row["keyword"], row["category"],
|
||||||
|
float(row.get("score", 0.0)), int(row.get("articles_count", 0)),
|
||||||
|
row.get("source", "manual"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return cur.lastrowid
|
return cur.lastrowid
|
||||||
|
|
||||||
@@ -276,3 +303,50 @@ def get_prompt_template(name: str) -> Optional[Dict[str, Any]]:
|
|||||||
with _conn() as conn:
|
with _conn() as conn:
|
||||||
row = conn.execute("SELECT * FROM prompt_templates WHERE name=?", (name,)).fetchone()
|
row = conn.execute("SELECT * FROM prompt_templates WHERE name=?", (name,)).fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
# ── external trends ─────────────────────────────────────────────
|
||||||
|
def add_external_trend(row: Dict[str, Any]) -> int:
|
||||||
|
"""`source` 필수 — naver_popular | google_trends. trending_keywords에 인서트."""
|
||||||
|
if "source" not in row:
|
||||||
|
raise ValueError("add_external_trend requires 'source' field")
|
||||||
|
return add_trending_keyword(row)
|
||||||
|
|
||||||
|
|
||||||
|
def list_trends(source: Optional[str] = None, category: Optional[str] = None,
|
||||||
|
days: int = 1) -> List[Dict[str, Any]]:
|
||||||
|
sql = "SELECT * FROM trending_keywords WHERE suggested_at >= datetime('now', ?)"
|
||||||
|
params: List[Any] = [f"-{int(days)} days"]
|
||||||
|
if source and source != "all":
|
||||||
|
sql += " AND source=?"
|
||||||
|
params.append(source)
|
||||||
|
if category:
|
||||||
|
sql += " AND category=?"
|
||||||
|
params.append(category)
|
||||||
|
sql += " ORDER BY suggested_at DESC, score DESC"
|
||||||
|
with _conn() as conn:
|
||||||
|
rows = conn.execute(sql, params).fetchall()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
# ── account_preferences ─────────────────────────────────────────
|
||||||
|
def get_preferences() -> List[Dict[str, Any]]:
|
||||||
|
with _conn() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT category, weight, updated_at FROM account_preferences ORDER BY category ASC"
|
||||||
|
).fetchall()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_preferences(weights: Dict[str, float]) -> None:
|
||||||
|
"""전체 upsert. 기존에 있던 카테고리는 weight 갱신, 신규는 INSERT.
|
||||||
|
명시되지 않은 기존 카테고리는 그대로 둔다 (삭제 X). 삭제 필요 시 별도 API로."""
|
||||||
|
with _conn() as conn:
|
||||||
|
for cat, w in weights.items():
|
||||||
|
conn.execute("""
|
||||||
|
INSERT INTO account_preferences(category, weight)
|
||||||
|
VALUES(?,?)
|
||||||
|
ON CONFLICT(category) DO UPDATE SET
|
||||||
|
weight=excluded.weight,
|
||||||
|
updated_at=strftime('%Y-%m-%dT%H:%M:%fZ','now')
|
||||||
|
""", (cat, float(w)))
|
||||||
|
|||||||
@@ -81,3 +81,22 @@ def extract_for_category(category: str, limit: int = KEYWORDS_PER_CATEGORY) -> L
|
|||||||
})
|
})
|
||||||
saved.append({"id": kid, **kw, "category": category})
|
saved.append({"id": kid, **kw, "category": category})
|
||||||
return saved
|
return saved
|
||||||
|
|
||||||
|
|
||||||
|
def extract_with_weights(weights: Dict[str, float], total_limit: int) -> List[Dict[str, Any]]:
|
||||||
|
"""카테고리 가중치 비율대로 키워드를 분배 추출."""
|
||||||
|
from .config import DEFAULT_CATEGORY_SEEDS
|
||||||
|
if not weights or sum(weights.values()) == 0:
|
||||||
|
cats = list(DEFAULT_CATEGORY_SEEDS.keys())
|
||||||
|
weights = {c: 1.0 for c in cats}
|
||||||
|
|
||||||
|
total_weight = sum(weights.values())
|
||||||
|
out: List[Dict[str, Any]] = []
|
||||||
|
for category, w in weights.items():
|
||||||
|
if w <= 0:
|
||||||
|
continue
|
||||||
|
per_cat = round(total_limit * w / total_weight)
|
||||||
|
if per_cat <= 0:
|
||||||
|
continue
|
||||||
|
out.extend(extract_for_category(category, limit=per_cat))
|
||||||
|
return out
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from .config import (
|
|||||||
CORS_ALLOW_ORIGINS, NAVER_CLIENT_ID, ANTHROPIC_API_KEY,
|
CORS_ALLOW_ORIGINS, NAVER_CLIENT_ID, ANTHROPIC_API_KEY,
|
||||||
INSTA_DATA_PATH, DB_PATH, DEFAULT_CATEGORY_SEEDS, KEYWORDS_PER_CATEGORY,
|
INSTA_DATA_PATH, DB_PATH, DEFAULT_CATEGORY_SEEDS, KEYWORDS_PER_CATEGORY,
|
||||||
)
|
)
|
||||||
from . import db, news_collector, keyword_extractor, card_writer, card_renderer
|
from . import db, news_collector, keyword_extractor, card_writer, card_renderer, trend_collector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@@ -99,11 +99,16 @@ class ExtractRequest(BaseModel):
|
|||||||
categories: Optional[list[str]] = None
|
categories: Optional[list[str]] = None
|
||||||
|
|
||||||
|
|
||||||
async def _bg_extract(task_id: str, categories: list[str]):
|
async def _bg_extract(task_id: str, categories: Optional[list[str]] = None):
|
||||||
try:
|
try:
|
||||||
db.update_task(task_id, "processing", 10, "추출 중")
|
db.update_task(task_id, "processing", 10, "추출 중")
|
||||||
for cat in categories:
|
prefs_rows = db.get_preferences()
|
||||||
keyword_extractor.extract_for_category(cat, limit=KEYWORDS_PER_CATEGORY)
|
weights = {p["category"]: p["weight"] for p in prefs_rows}
|
||||||
|
if categories:
|
||||||
|
# 사용자가 카테고리 명시한 경우만 그 서브셋으로 균등 가중치 (override)
|
||||||
|
weights = {c: 1.0 for c in categories}
|
||||||
|
total = KEYWORDS_PER_CATEGORY * max(1, len([w for w in weights.values() if w > 0]))
|
||||||
|
keyword_extractor.extract_with_weights(weights, total_limit=total)
|
||||||
db.update_task(task_id, "succeeded", 100, "완료", result_id=0)
|
db.update_task(task_id, "succeeded", 100, "완료", result_id=0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("extract failed")
|
logger.exception("extract failed")
|
||||||
@@ -119,7 +124,13 @@ def extract_keywords(req: ExtractRequest, bg: BackgroundTasks):
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/api/insta/keywords")
|
@app.get("/api/insta/keywords")
|
||||||
def list_keywords(category: Optional[str] = None, used: Optional[bool] = None):
|
def list_keywords(
|
||||||
|
category: Optional[str] = None,
|
||||||
|
used: Optional[bool] = None,
|
||||||
|
source: Optional[str] = None,
|
||||||
|
):
|
||||||
|
if source:
|
||||||
|
return {"items": db.list_trends(source=source, category=category, days=30)}
|
||||||
return {"items": db.list_trending_keywords(category=category, used=used)}
|
return {"items": db.list_trending_keywords(category=category, used=used)}
|
||||||
|
|
||||||
|
|
||||||
@@ -243,3 +254,52 @@ def get_prompt(name: str):
|
|||||||
def upsert_prompt(name: str, body: TemplateBody):
|
def upsert_prompt(name: str, body: TemplateBody):
|
||||||
db.upsert_prompt_template(name, body.template, body.description)
|
db.upsert_prompt_template(name, body.template, body.description)
|
||||||
return db.get_prompt_template(name)
|
return db.get_prompt_template(name)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trends ───────────────────────────────────────────────────────
|
||||||
|
class TrendsCollectRequest(BaseModel):
|
||||||
|
categories: Optional[list[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _bg_collect_trends(task_id: str, categories: list[str]):
|
||||||
|
try:
|
||||||
|
db.update_task(task_id, "processing", 10, "외부 트렌드 수집 중")
|
||||||
|
result = trend_collector.collect_all(categories)
|
||||||
|
msg = f"naver:{result['naver_popular']}, youtube:{result['youtube_trending']}"
|
||||||
|
db.update_task(task_id, "succeeded", 100, msg, result_id=sum(result.values()))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("trends collect failed")
|
||||||
|
db.update_task(task_id, "failed", 0, "", error=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/insta/trends/collect")
|
||||||
|
def collect_trends(req: TrendsCollectRequest, bg: BackgroundTasks):
|
||||||
|
cats = req.categories or list(DEFAULT_CATEGORY_SEEDS.keys())
|
||||||
|
tid = db.create_task("trends_collect", {"categories": cats})
|
||||||
|
bg.add_task(_bg_collect_trends, tid, cats)
|
||||||
|
return {"task_id": tid, "categories": cats}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/insta/trends")
|
||||||
|
def list_trends_endpoint(
|
||||||
|
source: Optional[str] = None,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
days: int = Query(1, ge=1, le=90),
|
||||||
|
):
|
||||||
|
return {"items": db.list_trends(source=source, category=category, days=days)}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Preferences ──────────────────────────────────────────────────
|
||||||
|
class PreferencesBody(BaseModel):
|
||||||
|
categories: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/insta/preferences")
|
||||||
|
def get_preferences_endpoint():
|
||||||
|
return {"categories": db.get_preferences()}
|
||||||
|
|
||||||
|
|
||||||
|
@app.put("/api/insta/preferences")
|
||||||
|
def put_preferences_endpoint(body: PreferencesBody):
|
||||||
|
db.upsert_preferences(body.categories)
|
||||||
|
return {"categories": db.get_preferences()}
|
||||||
|
|||||||
250
insta-lab/app/trend_collector.py
Normal file
250
insta-lab/app/trend_collector.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""외부 트렌드 수집 — NAVER 인기 + YouTube 인기 영상 + LLM 카테고리 분류.
|
||||||
|
|
||||||
|
NAVER: 카테고리별 시드 키워드로 인기 검색 → 빈도 상위 추출.
|
||||||
|
YouTube: Google Trends 비공식 endpoint(RSS / dailytrends JSON)가 모두 404 폐기되어
|
||||||
|
대체로 YouTube Data API v3 (`videos.list?chart=mostPopular®ionCode=KR`) 사용.
|
||||||
|
무료 일일 quota 10000, 한국 region 지원, 인기 영상 50개 제목에서 트렌드 추출.
|
||||||
|
LLM 분류 결과는 24h in-memory 캐시.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from anthropic import Anthropic
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
NAVER_CLIENT_ID, NAVER_CLIENT_SECRET, DEFAULT_CATEGORY_SEEDS,
|
||||||
|
ANTHROPIC_API_KEY, ANTHROPIC_MODEL_HAIKU, YOUTUBE_DATA_API_KEY,
|
||||||
|
)
|
||||||
|
from . import db
|
||||||
|
from .news_collector import _clean
|
||||||
|
from .keyword_extractor import _count_nouns, _top_candidates
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
NEWS_URL = "https://openapi.naver.com/v1/search/news.json"
|
||||||
|
_NAVER_HEADERS = {
|
||||||
|
"X-Naver-Client-Id": NAVER_CLIENT_ID,
|
||||||
|
"X-Naver-Client-Secret": NAVER_CLIENT_SECRET,
|
||||||
|
}
|
||||||
|
|
||||||
|
YOUTUBE_TRENDING_URL = "https://www.googleapis.com/youtube/v3/videos"
|
||||||
|
# YouTube 제목 정제: 대괄호·이모지·과도한 길이 제거 후 카드 주제로 적합한 키워드 형태
|
||||||
|
_TITLE_BRACKET_RE = re.compile(r"[\[【「『\(][^\]】」』\)]{0,30}[\]】」』\)]")
|
||||||
|
_EMOJI_RE = re.compile(
|
||||||
|
r"["
|
||||||
|
r"\U0001F300-\U0001FAFF" # symbols & pictographs, etc.
|
||||||
|
r"\U00002600-\U000027BF" # misc symbols, dingbats
|
||||||
|
r"\U0001F1E6-\U0001F1FF" # regional indicator
|
||||||
|
r"]"
|
||||||
|
)
|
||||||
|
_TITLE_MAX_LEN = 60
|
||||||
|
|
||||||
|
_PLACEHOLDER_SEEDS = {"...", "…", "tbd", "todo", "placeholder", "example"}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_seed(s: str) -> bool:
|
||||||
|
"""프롬프트 템플릿에 placeholder/빈 값이 들어가 NAVER에 400을 유발하는 일을 막는 가드."""
|
||||||
|
if not s:
|
||||||
|
return False
|
||||||
|
s = s.strip()
|
||||||
|
if len(s) < 2:
|
||||||
|
return False
|
||||||
|
if s.lower() in _PLACEHOLDER_SEEDS:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _seeds_for(category: str) -> List[str]:
|
||||||
|
"""category_seeds 프롬프트 템플릿이 있으면 사용, 없거나 모두 invalid면 config DEFAULT 폴백."""
|
||||||
|
pt = db.get_prompt_template("category_seeds")
|
||||||
|
if pt and pt.get("template"):
|
||||||
|
try:
|
||||||
|
data = json.loads(pt["template"])
|
||||||
|
if category in data:
|
||||||
|
filtered = [s for s in (data[category] or []) if _is_valid_seed(s)]
|
||||||
|
if filtered:
|
||||||
|
return filtered
|
||||||
|
logger.warning("category_seeds[%s]에 유효한 시드 없음 → DEFAULT 폴백", category)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("category_seeds JSON 파싱 실패 → DEFAULT 폴백: %s", e)
|
||||||
|
return list(DEFAULT_CATEGORY_SEEDS.get(category, []))
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_naver_popular(category: str, per_seed: int = 30, top_n: int = 10) -> List[Dict[str, Any]]:
|
||||||
|
"""카테고리 시드 키워드들로 NAVER news.json `sort=sim` 호출,
|
||||||
|
응답 기사 묶음에서 빈도어 추출 후 상위 N개 반환."""
|
||||||
|
seeds = _seeds_for(category)
|
||||||
|
if not seeds:
|
||||||
|
return []
|
||||||
|
blob_parts: List[str] = []
|
||||||
|
for seed in seeds:
|
||||||
|
try:
|
||||||
|
resp = requests.get(
|
||||||
|
NEWS_URL,
|
||||||
|
headers=_NAVER_HEADERS,
|
||||||
|
params={"query": seed, "display": per_seed, "sort": "sim"},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
for item in resp.json().get("items", []):
|
||||||
|
blob_parts.append(_clean(item.get("title", "")))
|
||||||
|
blob_parts.append(_clean(item.get("description", "")))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("fetch_naver_popular seed=%s err=%s", seed, e)
|
||||||
|
continue
|
||||||
|
text = "\n".join(blob_parts)
|
||||||
|
counts = _count_nouns(text)
|
||||||
|
candidates = _top_candidates(counts, n=top_n)
|
||||||
|
if not candidates:
|
||||||
|
return []
|
||||||
|
max_count = candidates[0][1] or 1
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"keyword": k,
|
||||||
|
"category": category,
|
||||||
|
"source": "naver_popular",
|
||||||
|
"score": round(min(1.0, c / max_count), 4),
|
||||||
|
"articles_count": c,
|
||||||
|
}
|
||||||
|
for k, c in candidates
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def collect_naver_popular_for(categories: List[str]) -> int:
|
||||||
|
total = 0
|
||||||
|
for cat in categories:
|
||||||
|
trends = fetch_naver_popular(cat)
|
||||||
|
for t in trends:
|
||||||
|
db.add_external_trend(t)
|
||||||
|
total += 1
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM 분류 캐시 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CACHE_TTL_SEC = 24 * 3600
|
||||||
|
_category_cache: Dict[str, tuple] = {} # keyword -> (category, expires_ts)
|
||||||
|
|
||||||
|
|
||||||
|
def _llm_classify_one(keyword: str) -> str:
|
||||||
|
"""Claude Haiku 1회 호출로 단일 키워드 분류."""
|
||||||
|
if not ANTHROPIC_API_KEY:
|
||||||
|
return "uncategorized"
|
||||||
|
seeds_template = db.get_prompt_template("category_seeds")
|
||||||
|
if seeds_template and seeds_template.get("template"):
|
||||||
|
try:
|
||||||
|
allowed = sorted(json.loads(seeds_template["template"]).keys())
|
||||||
|
except Exception:
|
||||||
|
allowed = sorted(DEFAULT_CATEGORY_SEEDS.keys())
|
||||||
|
else:
|
||||||
|
allowed = sorted(DEFAULT_CATEGORY_SEEDS.keys())
|
||||||
|
allowed.append("uncategorized")
|
||||||
|
|
||||||
|
client = Anthropic(api_key=ANTHROPIC_API_KEY)
|
||||||
|
msg = client.messages.create(
|
||||||
|
model=ANTHROPIC_MODEL_HAIKU,
|
||||||
|
max_tokens=20,
|
||||||
|
messages=[{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
f"다음 한국어 트렌딩 키워드를 카테고리 중 하나로 분류해라. "
|
||||||
|
f"카테고리: {allowed}. 키워드: '{keyword}'. "
|
||||||
|
f"카테고리명 한 단어만 출력. 다른 텍스트 금지."
|
||||||
|
),
|
||||||
|
}],
|
||||||
|
)
|
||||||
|
raw = msg.content[0].text.strip().lower()
|
||||||
|
for cat in allowed:
|
||||||
|
if cat.lower() in raw:
|
||||||
|
return cat
|
||||||
|
return "uncategorized"
|
||||||
|
|
||||||
|
|
||||||
|
def classify_keyword(keyword: str) -> str:
|
||||||
|
now = time.time()
|
||||||
|
cached = _category_cache.get(keyword)
|
||||||
|
if cached and cached[1] > now:
|
||||||
|
return cached[0]
|
||||||
|
cat = _llm_classify_one(keyword)
|
||||||
|
_category_cache[keyword] = (cat, now + _CACHE_TTL_SEC)
|
||||||
|
return cat
|
||||||
|
|
||||||
|
|
||||||
|
# ── YouTube Trending ──────────────────────────────────────────────────────────
|
||||||
|
# YouTube Data API v3 videos.list?chart=mostPopular®ionCode=KR
|
||||||
|
# 한국 인기 영상 50개 제목에서 카드 주제로 적합한 키워드 추출.
|
||||||
|
|
||||||
|
def _clean_yt_title(title: str) -> str:
|
||||||
|
"""[공식]·【속보】·🔥 등 제거 후 60자 이내로 자른다."""
|
||||||
|
if not title:
|
||||||
|
return ""
|
||||||
|
cleaned = _TITLE_BRACKET_RE.sub("", title)
|
||||||
|
cleaned = _EMOJI_RE.sub("", cleaned)
|
||||||
|
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||||||
|
return cleaned[:_TITLE_MAX_LEN]
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_youtube_trending() -> List[Dict[str, Any]]:
|
||||||
|
"""YouTube Data API v3 mostPopular (한국, 50개). API 키 없거나 호출 실패 시 빈 리스트."""
|
||||||
|
if not YOUTUBE_DATA_API_KEY:
|
||||||
|
logger.info("YOUTUBE_DATA_API_KEY 미설정 — youtube_trending skip")
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
resp = requests.get(
|
||||||
|
YOUTUBE_TRENDING_URL,
|
||||||
|
params={
|
||||||
|
"part": "snippet",
|
||||||
|
"chart": "mostPopular",
|
||||||
|
"regionCode": "KR",
|
||||||
|
"maxResults": 50,
|
||||||
|
"key": YOUTUBE_DATA_API_KEY,
|
||||||
|
},
|
||||||
|
timeout=15,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
videos = resp.json().get("items", []) or []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("YouTube trending fetch failed: %s", e)
|
||||||
|
return []
|
||||||
|
|
||||||
|
items: List[Dict[str, Any]] = []
|
||||||
|
seen = set()
|
||||||
|
total = max(1, len(videos))
|
||||||
|
for idx, v in enumerate(videos):
|
||||||
|
title = (v.get("snippet") or {}).get("title", "")
|
||||||
|
kw = _clean_yt_title(title)
|
||||||
|
if not kw or kw in seen:
|
||||||
|
continue
|
||||||
|
seen.add(kw)
|
||||||
|
try:
|
||||||
|
cat = classify_keyword(kw)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("classify_keyword(%s) 실패: %s", kw, e)
|
||||||
|
cat = "uncategorized"
|
||||||
|
rank_score = round(max(0.0, 1.0 - (idx / total)), 4)
|
||||||
|
items.append({
|
||||||
|
"keyword": kw,
|
||||||
|
"category": cat,
|
||||||
|
"source": "youtube_trending",
|
||||||
|
"score": rank_score,
|
||||||
|
"articles_count": 0,
|
||||||
|
})
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def collect_youtube_trending() -> int:
|
||||||
|
items = fetch_youtube_trending()
|
||||||
|
for it in items:
|
||||||
|
db.add_external_trend(it)
|
||||||
|
return len(items)
|
||||||
|
|
||||||
|
|
||||||
|
def collect_all(categories: List[str]) -> Dict[str, int]:
|
||||||
|
naver_n = collect_naver_popular_for(categories)
|
||||||
|
yt_n = collect_youtube_trending()
|
||||||
|
return {"naver_popular": naver_n, "youtube_trending": yt_n}
|
||||||
@@ -24,7 +24,7 @@ def tmp_db(monkeypatch):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_init_db_creates_six_tables(tmp_db):
|
def test_init_db_creates_seven_tables(tmp_db):
|
||||||
with db_module._conn() as conn:
|
with db_module._conn() as conn:
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
||||||
@@ -33,6 +33,7 @@ def test_init_db_creates_six_tables(tmp_db):
|
|||||||
assert names == sorted([
|
assert names == sorted([
|
||||||
"news_articles", "trending_keywords", "card_slates",
|
"news_articles", "trending_keywords", "card_slates",
|
||||||
"card_assets", "generation_tasks", "prompt_templates",
|
"card_assets", "generation_tasks", "prompt_templates",
|
||||||
|
"account_preferences",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
71
insta-lab/tests/test_extract_with_weights.py
Normal file
71
insta-lab/tests/test_extract_with_weights.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app import db as db_module
|
||||||
|
from app import keyword_extractor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_db(monkeypatch):
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
monkeypatch.setattr(db_module, "DB_PATH", path)
|
||||||
|
db_module.init_db()
|
||||||
|
yield path
|
||||||
|
gc.collect()
|
||||||
|
for ext in ("", "-wal", "-shm"):
|
||||||
|
try:
|
||||||
|
os.remove(path + ext)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_with_weights_proportional(tmp_db, monkeypatch):
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_extract(category, limit):
|
||||||
|
calls.append((category, limit))
|
||||||
|
return [{"id": i, "keyword": f"{category}{i}", "category": category, "score": 0.5}
|
||||||
|
for i in range(limit)]
|
||||||
|
|
||||||
|
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
|
||||||
|
out = keyword_extractor.extract_with_weights(
|
||||||
|
{"economy": 0.6, "psychology": 0.3, "celebrity": 0.1}, total_limit=10,
|
||||||
|
)
|
||||||
|
by_cat = {c: l for c, l in calls}
|
||||||
|
assert by_cat == {"economy": 6, "psychology": 3, "celebrity": 1}
|
||||||
|
assert len(out) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_with_weights_skips_zero(tmp_db, monkeypatch):
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_extract(category, limit):
|
||||||
|
calls.append((category, limit))
|
||||||
|
return []
|
||||||
|
|
||||||
|
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
|
||||||
|
keyword_extractor.extract_with_weights(
|
||||||
|
{"economy": 1.0, "celebrity": 0.0}, total_limit=10,
|
||||||
|
)
|
||||||
|
cats_called = [c for c, _ in calls]
|
||||||
|
assert "celebrity" not in cats_called
|
||||||
|
assert "economy" in cats_called
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_with_weights_fallback_to_equal(tmp_db, monkeypatch):
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_extract(category, limit):
|
||||||
|
calls.append((category, limit))
|
||||||
|
return []
|
||||||
|
|
||||||
|
monkeypatch.setattr(keyword_extractor, "extract_for_category", fake_extract)
|
||||||
|
keyword_extractor.extract_with_weights({}, total_limit=9)
|
||||||
|
by_cat = {c: l for c, l in calls}
|
||||||
|
assert set(by_cat.keys()) == {"economy", "psychology", "celebrity"}
|
||||||
|
assert all(l == 3 for l in by_cat.values())
|
||||||
83
insta-lab/tests/test_main_trends.py
Normal file
83
insta-lab/tests/test_main_trends.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app import db as db_module
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(monkeypatch):
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
monkeypatch.setattr(db_module, "DB_PATH", path)
|
||||||
|
db_module.init_db()
|
||||||
|
from app import main
|
||||||
|
monkeypatch.setattr(main, "DB_PATH", path)
|
||||||
|
with TestClient(main.app) as c:
|
||||||
|
yield c
|
||||||
|
gc.collect()
|
||||||
|
for ext in ("", "-wal", "-shm"):
|
||||||
|
try:
|
||||||
|
os.remove(path + ext)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_preferences_returns_defaults(client):
|
||||||
|
resp = client.get("/api/insta/preferences")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
cats = {p["category"]: p["weight"] for p in resp.json()["categories"]}
|
||||||
|
assert cats == {"economy": 1.0, "psychology": 1.0, "celebrity": 1.0}
|
||||||
|
|
||||||
|
|
||||||
|
def test_put_preferences_upsert(client):
|
||||||
|
resp = client.put("/api/insta/preferences",
|
||||||
|
json={"categories": {"economy": 0.7, "psychology": 0.2, "tech": 0.5}})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
cats = {p["category"]: p["weight"] for p in resp.json()["categories"]}
|
||||||
|
assert cats["economy"] == 0.7
|
||||||
|
assert cats["tech"] == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_trends_filter(client):
|
||||||
|
db_module.add_external_trend({"keyword": "A", "category": "economy",
|
||||||
|
"source": "naver_popular", "score": 1.0})
|
||||||
|
db_module.add_external_trend({"keyword": "B", "category": "celebrity",
|
||||||
|
"source": "google_trends", "score": 0.8})
|
||||||
|
resp = client.get("/api/insta/trends?source=naver_popular")
|
||||||
|
items = resp.json()["items"]
|
||||||
|
assert {it["keyword"] for it in items} == {"A"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_trends_kicks_background(client, monkeypatch):
|
||||||
|
from app import main, trend_collector
|
||||||
|
|
||||||
|
captured = {"called": False}
|
||||||
|
|
||||||
|
def fake_collect_all(cats):
|
||||||
|
captured["called"] = True
|
||||||
|
return {"naver_popular": 3, "youtube_trending": 2}
|
||||||
|
|
||||||
|
monkeypatch.setattr(trend_collector, "collect_all", fake_collect_all)
|
||||||
|
resp = client.post("/api/insta/trends/collect", json={})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
task_id = resp.json()["task_id"]
|
||||||
|
for _ in range(20):
|
||||||
|
st = client.get(f"/api/insta/tasks/{task_id}").json()
|
||||||
|
if st["status"] in ("succeeded", "failed"):
|
||||||
|
break
|
||||||
|
assert st["status"] == "succeeded"
|
||||||
|
assert captured["called"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_keywords_filters_by_source(client):
|
||||||
|
db_module.add_trending_keyword({"keyword": "M", "category": "economy",
|
||||||
|
"score": 0.4, "articles_count": 1, "source": "manual"})
|
||||||
|
db_module.add_external_trend({"keyword": "N", "category": "economy",
|
||||||
|
"source": "naver_popular", "score": 0.9})
|
||||||
|
resp = client.get("/api/insta/keywords?source=manual")
|
||||||
|
items = resp.json()["items"]
|
||||||
|
assert {it["keyword"] for it in items} == {"M"}
|
||||||
77
insta-lab/tests/test_preferences_crud.py
Normal file
77
insta-lab/tests/test_preferences_crud.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app import db as db_module
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_db(monkeypatch):
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
monkeypatch.setattr(db_module, "DB_PATH", path)
|
||||||
|
db_module.init_db()
|
||||||
|
yield path
|
||||||
|
gc.collect()
|
||||||
|
for ext in ("", "-wal", "-shm"):
|
||||||
|
try:
|
||||||
|
os.remove(path + ext)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_db_creates_account_preferences(tmp_db):
|
||||||
|
with db_module._conn() as conn:
|
||||||
|
rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
|
||||||
|
names = {r[0] for r in rows}
|
||||||
|
assert "account_preferences" in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_db_seeds_default_weights(tmp_db):
|
||||||
|
prefs = db_module.get_preferences()
|
||||||
|
cats = {p["category"]: p["weight"] for p in prefs}
|
||||||
|
assert cats["economy"] == pytest.approx(1.0)
|
||||||
|
assert cats["psychology"] == pytest.approx(1.0)
|
||||||
|
assert cats["celebrity"] == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_preferences_replaces_weights(tmp_db):
|
||||||
|
db_module.upsert_preferences({"economy": 0.6, "psychology": 0.3, "celebrity": 0.1, "tech": 0.5})
|
||||||
|
prefs = {p["category"]: p["weight"] for p in db_module.get_preferences()}
|
||||||
|
assert prefs["economy"] == pytest.approx(0.6)
|
||||||
|
assert prefs["tech"] == pytest.approx(0.5)
|
||||||
|
assert "celebrity" in prefs and prefs["celebrity"] == pytest.approx(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_trending_keywords_source_column_exists(tmp_db):
|
||||||
|
with db_module._conn() as conn:
|
||||||
|
cols = [r[1] for r in conn.execute("PRAGMA table_info(trending_keywords)").fetchall()]
|
||||||
|
assert "source" in cols
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_trending_keyword_default_source(tmp_db):
|
||||||
|
kid = db_module.add_trending_keyword({
|
||||||
|
"keyword": "K", "category": "economy", "score": 0.5, "articles_count": 3,
|
||||||
|
})
|
||||||
|
with db_module._conn() as conn:
|
||||||
|
row = conn.execute("SELECT source FROM trending_keywords WHERE id=?", (kid,)).fetchone()
|
||||||
|
assert row[0] == "manual"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_external_trend_stores_source(tmp_db):
|
||||||
|
tid = db_module.add_external_trend({
|
||||||
|
"keyword": "급등주", "category": "economy", "source": "naver_popular", "score": 0.9,
|
||||||
|
})
|
||||||
|
rows = db_module.list_trends(source="naver_popular")
|
||||||
|
assert any(r["id"] == tid and r["keyword"] == "급등주" for r in rows)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_trends_filters_by_source_and_category(tmp_db):
|
||||||
|
db_module.add_external_trend({"keyword": "A", "category": "economy", "source": "naver_popular", "score": 1.0})
|
||||||
|
db_module.add_external_trend({"keyword": "B", "category": "celebrity", "source": "google_trends", "score": 1.0})
|
||||||
|
only_naver = db_module.list_trends(source="naver_popular")
|
||||||
|
assert {r["keyword"] for r in only_naver} == {"A"}
|
||||||
|
only_celeb_google = db_module.list_trends(source="google_trends", category="celebrity")
|
||||||
|
assert {r["keyword"] for r in only_celeb_google} == {"B"}
|
||||||
160
insta-lab/tests/test_trend_collector.py
Normal file
160
insta-lab/tests/test_trend_collector.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app import db as db_module
|
||||||
|
from app import trend_collector
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_db(monkeypatch):
|
||||||
|
fd, path = tempfile.mkstemp(suffix=".db")
|
||||||
|
os.close(fd)
|
||||||
|
monkeypatch.setattr(db_module, "DB_PATH", path)
|
||||||
|
db_module.init_db()
|
||||||
|
yield path
|
||||||
|
gc.collect()
|
||||||
|
for ext in ("", "-wal", "-shm"):
|
||||||
|
try:
|
||||||
|
os.remove(path + ext)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
NAVER_RESPONSE = {
|
||||||
|
"items": [
|
||||||
|
{"title": "<b>기준금리</b> 인상", "link": "https://n.news.naver.com/a/1", "description": "한국은행 발표"},
|
||||||
|
{"title": "환율 급등", "link": "https://n.news.naver.com/a/2", "description": "달러 강세"},
|
||||||
|
{"title": "기준금리 추가 인상", "link": "https://n.news.naver.com/a/3", "description": "추가 발표"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_naver_popular_extracts_top_terms(tmp_db, monkeypatch):
|
||||||
|
fake_resp = MagicMock()
|
||||||
|
fake_resp.json.return_value = NAVER_RESPONSE
|
||||||
|
fake_resp.raise_for_status.return_value = None
|
||||||
|
|
||||||
|
with patch.object(trend_collector.requests, "get", return_value=fake_resp):
|
||||||
|
trends = trend_collector.fetch_naver_popular("economy", per_seed=10, top_n=5)
|
||||||
|
|
||||||
|
keywords = [t["keyword"] for t in trends]
|
||||||
|
assert "기준금리" in keywords
|
||||||
|
for t in trends:
|
||||||
|
assert t["category"] == "economy"
|
||||||
|
assert t["source"] == "naver_popular"
|
||||||
|
assert 0.0 <= t["score"] <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_naver_writes_to_db(tmp_db, monkeypatch):
|
||||||
|
fake_resp = MagicMock()
|
||||||
|
fake_resp.json.return_value = NAVER_RESPONSE
|
||||||
|
fake_resp.raise_for_status.return_value = None
|
||||||
|
with patch.object(trend_collector.requests, "get", return_value=fake_resp):
|
||||||
|
n = trend_collector.collect_naver_popular_for(["economy"])
|
||||||
|
assert n > 0
|
||||||
|
rows = db_module.list_trends(source="naver_popular")
|
||||||
|
assert len(rows) > 0
|
||||||
|
assert all(r["source"] == "naver_popular" for r in rows)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_keyword_with_cache(monkeypatch):
|
||||||
|
calls = {"n": 0}
|
||||||
|
|
||||||
|
def fake_claude(keyword: str) -> str:
|
||||||
|
calls["n"] += 1
|
||||||
|
return "economy"
|
||||||
|
|
||||||
|
monkeypatch.setattr(trend_collector, "_llm_classify_one", fake_claude)
|
||||||
|
trend_collector._category_cache.clear()
|
||||||
|
|
||||||
|
c1 = trend_collector.classify_keyword("기준금리")
|
||||||
|
c2 = trend_collector.classify_keyword("기준금리")
|
||||||
|
assert c1 == c2 == "economy"
|
||||||
|
assert calls["n"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_youtube_trending_parses_and_cleans_titles(tmp_db, monkeypatch):
|
||||||
|
"""YouTube Data API mostPopular 응답 → 제목 정제 + 분류."""
|
||||||
|
monkeypatch.setattr(trend_collector, "YOUTUBE_DATA_API_KEY", "fake_key")
|
||||||
|
payload = {
|
||||||
|
"items": [
|
||||||
|
{"snippet": {"title": "[속보] 기준금리 인상 단행 🔥"}},
|
||||||
|
{"snippet": {"title": "(공식) BTS 컴백 무대 🎤"}},
|
||||||
|
{"snippet": {"title": "스트레스 관리 5가지 방법"}},
|
||||||
|
# 중복 제목 — 중복 제거 확인
|
||||||
|
{"snippet": {"title": "[속보] 기준금리 인상 단행 🔥"}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
fake_resp = MagicMock()
|
||||||
|
fake_resp.json.return_value = payload
|
||||||
|
fake_resp.raise_for_status.return_value = None
|
||||||
|
monkeypatch.setattr(trend_collector.requests, "get", lambda *a, **kw: fake_resp)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
trend_collector, "classify_keyword",
|
||||||
|
lambda kw: ("economy" if "금리" in kw else
|
||||||
|
"celebrity" if "BTS" in kw else
|
||||||
|
"psychology" if "스트레스" in kw else "uncategorized"),
|
||||||
|
)
|
||||||
|
|
||||||
|
trends = trend_collector.fetch_youtube_trending()
|
||||||
|
keywords = [t["keyword"] for t in trends]
|
||||||
|
assert "기준금리 인상 단행" in keywords # 대괄호·이모지 제거
|
||||||
|
assert "BTS 컴백 무대" in keywords # 괄호 제거
|
||||||
|
assert "스트레스 관리 5가지 방법" in keywords # 그대로
|
||||||
|
assert len(trends) == 3 # 중복 제거됨
|
||||||
|
assert all(t["source"] == "youtube_trending" for t in trends)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_youtube_trending_no_api_key_returns_empty(monkeypatch):
|
||||||
|
monkeypatch.setattr(trend_collector, "YOUTUBE_DATA_API_KEY", "")
|
||||||
|
out = trend_collector.fetch_youtube_trending()
|
||||||
|
assert out == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_youtube_trending_graceful_on_api_failure(monkeypatch):
|
||||||
|
monkeypatch.setattr(trend_collector, "YOUTUBE_DATA_API_KEY", "fake_key")
|
||||||
|
fake_resp = MagicMock()
|
||||||
|
fake_resp.raise_for_status.side_effect = RuntimeError("quota exceeded")
|
||||||
|
monkeypatch.setattr(trend_collector.requests, "get", lambda *a, **kw: fake_resp)
|
||||||
|
out = trend_collector.fetch_youtube_trending()
|
||||||
|
assert out == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_all_invokes_both_sources(tmp_db, monkeypatch):
|
||||||
|
monkeypatch.setattr(trend_collector, "collect_naver_popular_for",
|
||||||
|
lambda cats: 5)
|
||||||
|
monkeypatch.setattr(trend_collector, "collect_youtube_trending",
|
||||||
|
lambda: 3)
|
||||||
|
out = trend_collector.collect_all(["economy"])
|
||||||
|
assert out == {"naver_popular": 5, "youtube_trending": 3}
|
||||||
|
|
||||||
|
|
||||||
|
def test_seeds_for_filters_placeholder(tmp_db, monkeypatch):
|
||||||
|
"""category_seeds 템플릿에 placeholder '...'가 들어가도 DEFAULT 폴백."""
|
||||||
|
from app import db as db_module
|
||||||
|
db_module.upsert_prompt_template(
|
||||||
|
"category_seeds",
|
||||||
|
'{"economy": ["...", "…", "a", "real_keyword"]}',
|
||||||
|
"test",
|
||||||
|
)
|
||||||
|
out = trend_collector._seeds_for("economy")
|
||||||
|
# '...', '…', 'a'(2자 미만)는 필터링되고 'real_keyword'만 남음
|
||||||
|
assert out == ["real_keyword"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seeds_for_falls_back_when_all_invalid(tmp_db, monkeypatch):
|
||||||
|
"""모든 시드가 invalid면 DEFAULT_CATEGORY_SEEDS 폴백."""
|
||||||
|
from app import db as db_module
|
||||||
|
db_module.upsert_prompt_template(
|
||||||
|
"category_seeds",
|
||||||
|
'{"economy": ["...", "TBD", ""]}',
|
||||||
|
"test",
|
||||||
|
)
|
||||||
|
out = trend_collector._seeds_for("economy")
|
||||||
|
# DEFAULT_CATEGORY_SEEDS["economy"] 가 반환되어야 함
|
||||||
|
from app.config import DEFAULT_CATEGORY_SEEDS
|
||||||
|
assert out == list(DEFAULT_CATEGORY_SEEDS["economy"])
|
||||||
@@ -133,8 +133,12 @@ async def sign_link(
|
|||||||
|
|
||||||
# 경로 안전: PACK_HOST_DIR(NAS 호스트 절대경로) 하위인지 확인.
|
# 경로 안전: PACK_HOST_DIR(NAS 호스트 절대경로) 하위인지 확인.
|
||||||
# file_path는 upload 라우트가 Supabase에 저장한 호스트경로 그대로 전달되어 DSM API에 사용됨.
|
# file_path는 upload 라우트가 Supabase에 저장한 호스트경로 그대로 전달되어 DSM API에 사용됨.
|
||||||
|
# str.startswith는 '/foo/packs' 와 '/foo/packs_evil' 같은 sibling 경로를 통과시키므로
|
||||||
|
# Path.relative_to로 엄격하게 컴포넌트 단위 검증한다 (CODE_REVIEW F1).
|
||||||
abs_path = Path(payload.file_path).resolve()
|
abs_path = Path(payload.file_path).resolve()
|
||||||
if not str(abs_path).startswith(str(PACK_HOST_DIR)):
|
try:
|
||||||
|
abs_path.relative_to(PACK_HOST_DIR.resolve())
|
||||||
|
except ValueError:
|
||||||
raise HTTPException(status_code=400, detail="허용된 경로 외부")
|
raise HTTPException(status_code=400, detail="허용된 경로 외부")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -60,6 +60,29 @@ def test_sign_link_path_outside_base():
|
|||||||
assert r.status_code == 400
|
assert r.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_sign_link_rejects_sibling_path():
|
||||||
|
"""PACK_HOST_DIR='/foo/packs' 일 때 '/foo/packs_evil/x.mp4' 같이 prefix만
|
||||||
|
통과하는 sibling 경로는 거부해야 한다 (CODE_REVIEW F1, path traversal 변형).
|
||||||
|
|
||||||
|
기존 str.startswith 방식은 trailing slash가 없어 sibling 경로를 통과시킴.
|
||||||
|
relative_to 기반 검증으로 교체되어야 통과한다.
|
||||||
|
"""
|
||||||
|
import json as _json
|
||||||
|
from pathlib import Path
|
||||||
|
base_resolved = Path("/foo/packs").resolve()
|
||||||
|
# base의 자식이 아닌 sibling 경로 (예: /foo/packs_evil/...)
|
||||||
|
sibling_posix = (base_resolved.parent / f"{base_resolved.name}_evil" / "x.mp4").as_posix()
|
||||||
|
with patch("app.routes.PACK_HOST_DIR", base_resolved):
|
||||||
|
body = _json.dumps(
|
||||||
|
{"file_path": sibling_posix, "expires_in_seconds": 14400}
|
||||||
|
).encode()
|
||||||
|
r = client.post("/api/packs/sign-link", content=body, headers=_signed(body))
|
||||||
|
assert r.status_code == 400, (
|
||||||
|
f"sibling 경로 '{sibling_posix}'가 허용됨 (status={r.status_code}) "
|
||||||
|
f"— path traversal 가능성"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_upload_invalid_token():
|
def test_upload_invalid_token():
|
||||||
r = client.post(
|
r = client.post(
|
||||||
"/api/packs/upload",
|
"/api/packs/upload",
|
||||||
|
|||||||
@@ -1,6 +1,14 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
|
# ── docker / compose / buildkit timeout 늘리기 ──
|
||||||
|
# NAS Celeron J4025에서 pip install·chromium 다운로드 등 무거운 RUN step이
|
||||||
|
# 기본 timeout(2분)에 걸려 webhook 자동 배포가 "DeadlineExceeded"로 끝나는 일이
|
||||||
|
# 있어 10분으로 상향. 호스트 셸 + deployer 컨테이너 둘 다에 적용됨.
|
||||||
|
export COMPOSE_HTTP_TIMEOUT=600
|
||||||
|
export DOCKER_CLIENT_TIMEOUT=600
|
||||||
|
export BUILDKIT_STEP_LOG_MAX_SIZE=-1
|
||||||
|
|
||||||
# ── 동시 배포 방지 (flock) ──
|
# ── 동시 배포 방지 (flock) ──
|
||||||
exec 200>/tmp/deploy.lock
|
exec 200>/tmp/deploy.lock
|
||||||
flock -n 200 || { echo "Deploy already running, skipping"; exit 0; }
|
flock -n 200 || { echo "Deploy already running, skipping"; exit 0; }
|
||||||
|
|||||||
@@ -47,13 +47,30 @@ scheduler = BackgroundScheduler(timezone=os.getenv("TZ", "Asia/Seoul"))
|
|||||||
# Windows AI Server URL (NAS .env에서 설정)
|
# Windows AI Server URL (NAS .env에서 설정)
|
||||||
WINDOWS_AI_SERVER_URL = os.getenv("WINDOWS_AI_SERVER_URL", "http://192.168.0.5:8000")
|
WINDOWS_AI_SERVER_URL = os.getenv("WINDOWS_AI_SERVER_URL", "http://192.168.0.5:8000")
|
||||||
|
|
||||||
# Admin API Key 인증
|
# Admin API Key 인증 — /api/trade/* 보호 (CODE_REVIEW F2)
|
||||||
|
# 빈 키 + 명시적 dev flag 없으면 503으로 거부. 운영 .env에 ADMIN_API_KEY 누락 시
|
||||||
|
# 무인증 통과되던 버그 차단.
|
||||||
ADMIN_API_KEY = os.getenv("ADMIN_API_KEY", "")
|
ADMIN_API_KEY = os.getenv("ADMIN_API_KEY", "")
|
||||||
|
|
||||||
def verify_admin(x_admin_key: str = Header(None)):
|
def verify_admin(x_admin_key: str = Header(None)):
|
||||||
"""admin/trade 엔드포인트 보호용 API 키 검증"""
|
"""admin/trade 엔드포인트 보호용 API 키 검증.
|
||||||
|
|
||||||
|
- ADMIN_API_KEY 설정됨 + 키 일치 → 통과
|
||||||
|
- ADMIN_API_KEY 설정됨 + 키 불일치 → 401 Unauthorized
|
||||||
|
- ADMIN_API_KEY 미설정 + ALLOW_UNAUTHENTICATED_ADMIN=true → 통과 (개발 모드)
|
||||||
|
- ADMIN_API_KEY 미설정 + dev flag 없음 → 503 (보호 강화, 운영 .env 누락 차단)
|
||||||
|
"""
|
||||||
if not ADMIN_API_KEY:
|
if not ADMIN_API_KEY:
|
||||||
return # 키 미설정 시 인증 비활성화 (개발 환경)
|
if os.getenv("ALLOW_UNAUTHENTICATED_ADMIN", "false").lower() == "true":
|
||||||
|
return # 개발 환경 명시적 허용
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail=(
|
||||||
|
"admin endpoint protected — ADMIN_API_KEY not configured. "
|
||||||
|
"Set ADMIN_API_KEY in .env, or set ALLOW_UNAUTHENTICATED_ADMIN=true "
|
||||||
|
"for development only."
|
||||||
|
),
|
||||||
|
)
|
||||||
if x_admin_key != ADMIN_API_KEY:
|
if x_admin_key != ADMIN_API_KEY:
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
|
|||||||
3
stock/pytest.ini
Normal file
3
stock/pytest.ini
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
pythonpath = .
|
||||||
|
asyncio_mode = auto
|
||||||
43
stock/tests/test_admin_auth.py
Normal file
43
stock/tests/test_admin_auth.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""verify_admin 보안 강화 회귀 테스트 (CODE_REVIEW F2).
|
||||||
|
|
||||||
|
운영 .env에서 ADMIN_API_KEY가 누락되면 /api/trade/balance, /api/trade/order
|
||||||
|
인증이 무력화되는 버그를 막기 위한 가드.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app import main as stock_main
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_admin_rejects_when_key_missing_and_no_dev_flag(monkeypatch):
|
||||||
|
"""ADMIN_API_KEY 미설정 + ALLOW_UNAUTHENTICATED_ADMIN 미설정 → 503."""
|
||||||
|
monkeypatch.setattr(stock_main, "ADMIN_API_KEY", "")
|
||||||
|
monkeypatch.delenv("ALLOW_UNAUTHENTICATED_ADMIN", raising=False)
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
stock_main.verify_admin(x_admin_key=None)
|
||||||
|
assert exc_info.value.status_code == 503
|
||||||
|
assert "ADMIN_API_KEY" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_admin_allows_when_key_missing_with_dev_flag(monkeypatch):
|
||||||
|
"""ADMIN_API_KEY 미설정 + ALLOW_UNAUTHENTICATED_ADMIN=true → 통과 (개발 모드)."""
|
||||||
|
monkeypatch.setattr(stock_main, "ADMIN_API_KEY", "")
|
||||||
|
monkeypatch.setenv("ALLOW_UNAUTHENTICATED_ADMIN", "true")
|
||||||
|
stock_main.verify_admin(x_admin_key=None) # 예외 없으면 통과
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_admin_rejects_wrong_key(monkeypatch):
|
||||||
|
"""ADMIN_API_KEY 설정 + 잘못된 키 → 401 (regression)."""
|
||||||
|
monkeypatch.setattr(stock_main, "ADMIN_API_KEY", "secret123")
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
stock_main.verify_admin(x_admin_key="wrong")
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_admin_allows_correct_key(monkeypatch):
|
||||||
|
"""ADMIN_API_KEY 설정 + 올바른 키 → 통과 (regression)."""
|
||||||
|
monkeypatch.setattr(stock_main, "ADMIN_API_KEY", "secret123")
|
||||||
|
stock_main.verify_admin(x_admin_key="secret123") # 예외 없으면 통과
|
||||||
Reference in New Issue
Block a user