Files
web-page-backend/stock-lab/app/main.py
gahusb ff975defbd AI Coach 백엔드 프록시 추가 및 trade 엔드포인트 인증 적용
- POST /api/stock/ai-coach: Anthropic API 프록시 (API 키 서버 보관)
- trade/balance, trade/order: ADMIN_API_KEY 헤더 인증 추가
- print() → logging 모듈 전환 (stock-lab)
- .env.example: ADMIN_API_KEY, ANTHROPIC_API_KEY, CORS_ALLOW_ORIGINS 추가

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-03 01:12:31 +09:00

467 lines
15 KiB
Python

import hmac
import json
import logging
import os
from datetime import date as date_type
from typing import Optional

import requests
from apscheduler.schedulers.background import BackgroundScheduler
from fastapi import Depends, FastAPI, Header, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(levelname)s %(message)s")
logger = logging.getLogger("stock-lab")
from .db import (
init_db, save_articles, get_latest_articles,
add_portfolio_item, get_all_portfolio, get_portfolio_item,
update_portfolio_item, delete_portfolio_item,
upsert_broker_cash, get_all_broker_cash, get_broker_cash, delete_broker_cash,
upsert_asset_snapshot, get_asset_snapshots,
add_sell_history, get_sell_history, update_sell_history, delete_sell_history,
)
from .scraper import fetch_market_news, fetch_major_indices, fetch_overseas_news
from .price_fetcher import get_current_prices
app = FastAPI()

# CORS configuration (lets the frontend origins call this API).
_cors_origins = os.getenv("CORS_ALLOW_ORIGINS", "http://localhost:3007,http://localhost:8080").split(",")
app.add_middleware(
    CORSMiddleware,
    # Drop blank entries so an empty variable or a trailing comma does not
    # register "" as an allowed origin.
    allow_origins=[o.strip() for o in _cors_origins if o.strip()],
    allow_credentials=False,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    # X-Admin-Key carries the admin API key for the protected trade endpoints;
    # without it here a cross-origin browser request fails its CORS preflight.
    allow_headers=["Content-Type", "X-Admin-Key"],
)

scheduler = BackgroundScheduler(timezone=os.getenv("TZ", "Asia/Seoul"))

# Windows AI Server URL (configured through the NAS .env file).
WINDOWS_AI_SERVER_URL = os.getenv("WINDOWS_AI_SERVER_URL", "http://192.168.0.5:8000")

# Admin API key used to authenticate the protected endpoints.
ADMIN_API_KEY = os.getenv("ADMIN_API_KEY", "")
def verify_admin(x_admin_key: Optional[str] = Header(None)):
    """Validate the admin API key for admin/trade endpoints.

    Intended to be used as a FastAPI dependency. When ``ADMIN_API_KEY`` is
    empty the check is skipped entirely (authentication disabled, meant for
    development environments).

    Args:
        x_admin_key: Value of the ``X-Admin-Key`` request header, or ``None``
            when the header is absent.

    Raises:
        HTTPException: 401 when a key is configured and the supplied header
            does not match it.
    """
    if not ADMIN_API_KEY:
        return  # No key configured: authentication is disabled (development).

    # Compare as bytes in constant time so the response latency does not leak
    # how many leading characters of the key were guessed correctly. A missing
    # header is treated as an empty key, which can never match here.
    supplied = (x_admin_key or "").encode("utf-8")
    expected = ADMIN_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
# API key for the Anthropic proxy (kept on the server side only).
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")

# Load the holiday list that sits next to this module.
_HOLIDAYS_PATH = os.path.join(os.path.dirname(__file__), "holidays.json")
try:
    with open(_HOLIDAYS_PATH, "r", encoding="utf-8") as f:
        _HOLIDAYS: set = set(json.load(f))
except Exception as exc:
    # Best effort: the service still starts without the file, but then only
    # weekends are treated as closed days, so make the failure visible.
    logging.getLogger("stock-lab").warning(
        "Could not load holidays from %s (%s); no holidays will be applied",
        _HOLIDAYS_PATH,
        exc,
    )
    _HOLIDAYS = set()
def is_market_open(d: date_type) -> bool:
    """Return True when *d* is a weekday that is not listed in the holiday set.

    Args:
        d: Calendar date to check.

    Returns:
        False for Saturdays and Sundays, and for dates whose ``YYYY-MM-DD``
        form appears in ``_HOLIDAYS``; True otherwise.
    """
    # weekday(): Monday is 0 ... Sunday is 6, so 5 and 6 are the weekend.
    if d.weekday() >= 5:
        return False
    day_key = d.strftime("%Y-%m-%d")
    return day_key not in _HOLIDAYS
def save_daily_snapshot():
    """Store today's total-asset snapshot (scheduled job).

    Does nothing on days when ``is_market_open`` reports the market closed.
    Otherwise values every holding at its current price, adds the broker cash
    balances and upserts the result under today's date.
    """
    today = date_type.today()
    if not is_market_open(today):
        logger.info("[Snapshot] %s 휴장일 — 스킵", today)
        return

    today_str = today.strftime("%Y-%m-%d")
    items = get_all_portfolio()
    cash_rows = get_all_broker_cash()
    total_cash = sum(r["cash"] for r in cash_rows)

    total_eval = 0
    if items:
        tickers = list({item["ticker"] for item in items})
        prices = get_current_prices(tickers)
        for item in items:
            price = prices.get(item["ticker"])
            if price is None:
                # No usable quote (missing key or a None entry): value the
                # holding at its average purchase price instead of failing.
                price = item["avg_price"]
            total_eval += price * item["quantity"]

    total_assets = total_eval + total_cash
    upsert_asset_snapshot(today_str, total_eval, total_cash, total_assets)
    logger.info(
        "[Snapshot] %s 저장 완료: eval=%s, cash=%s, total=%s",
        today_str,
        total_eval,
        total_cash,
        total_assets,
    )
@app.on_event("startup")
def on_startup():
    """Initialise the database, register the scheduled jobs and start the scheduler."""
    init_db()

    # News scraping every morning at 08:00 (performed by the NAS itself).
    scheduler.add_job(run_scraping_job, "cron", hour="8", minute="0")

    # Total-asset snapshot on weekdays at 15:40.
    scheduler.add_job(save_daily_snapshot, "cron", day_of_week="mon-fri", hour=15, minute=40)

    # Run one scrape right away when no article has been stored yet.
    if not get_latest_articles(1):
        try:
            run_scraping_job()
        except Exception:
            # A failed initial scrape must not prevent the API from starting
            # or the scheduler (which will retry at 08:00) from running.
            logger.exception("[StockLab] Initial news scraping failed; continuing startup")

    scheduler.start()
def run_scraping_job():
    """Fetch market news and store it, logging how many articles were saved."""
    logger.info("[StockLab] Starting news scraping...")

    # 1. Domestic news
    articles_kr = fetch_market_news()
    count_kr = save_articles(articles_kr)

    # 2. Overseas news (temporarily disabled)
    # articles_world = fetch_overseas_news()
    # count_world = save_articles(articles_world)
    count_world = 0

    logger.info("[StockLab] Saved %s domestic, %s overseas articles.", count_kr, count_world)
@app.get("/health")
def health():
    """Liveness probe: always answers with ``{"ok": True}``."""
    status = {"ok": True}
    return status
@app.get("/api/stock/news")
def get_news(limit: int = 20, category: Optional[str] = None):
    """Return the latest stock news articles.

    Args:
        limit: Maximum number of articles, forwarded to ``get_latest_articles``.
        category: Optional filter ('domestic' | 'overseas'); ``None`` applies
            no category filter argument.

    Returns:
        Whatever ``get_latest_articles(limit, category)`` returns.
    """
    return get_latest_articles(limit, category)
@app.get("/api/stock/indices")
def get_indices():
    """Return the major market indices (KOSPI etc.) fetched on demand by the scraper."""
    indices = fetch_major_indices()
    return indices
@app.post("/api/stock/scrap")
def trigger_scrap():
    """Manually trigger one news-scraping run and acknowledge it."""
    # NOTE(review): this endpoint has no verify_admin dependency — confirm
    # that letting any caller start a scrape is intended.
    run_scraping_job()
    acknowledgement = {"ok": True}
    return acknowledgement
# --- Trading API (Windows proxy, authentication required) ---
@app.get("/api/trade/balance", dependencies=[Depends(verify_admin)])
def get_balance():
    """Fetch the account balance through the Windows AI Server proxy.

    Returns:
        The upstream JSON body on success. On failure a ``JSONResponse``:
        the upstream status and body when the upstream answered with a
        non-200 status, 502 when the upstream answered 200 with a body that
        is not JSON, and 500 when the upstream could not be reached.
    """
    logger.info("Requesting Balance from %s", WINDOWS_AI_SERVER_URL)
    try:
        resp = requests.get(f"{WINDOWS_AI_SERVER_URL}/trade/balance", timeout=5)
    except Exception as e:
        logger.error("Balance Connection Failed: %s", e)
        return JSONResponse(status_code=500, content={"error": "Connection Failed"})

    try:
        payload = resp.json()
    except ValueError:
        # JSON decode errors derive from ValueError. A 200 whose body is not
        # JSON is an upstream fault, so report 502 instead of echoing the 200
        # back alongside an error payload.
        status = resp.status_code if resp.status_code != 200 else 502
        logger.error("Balance Error: non-JSON upstream response (%s)", resp.status_code)
        return JSONResponse(status_code=status, content={"error": f"Upstream error {resp.status_code}"})

    if resp.status_code != 200:
        logger.error("Balance Error: %s", resp.status_code)
        return JSONResponse(status_code=resp.status_code, content=payload)
    return payload
class OrderRequest(BaseModel):
    """Request body for the trade-order endpoint; forwarded as JSON to the Windows AI Server."""

    # Instrument identifier, passed through unchanged.
    ticker: str
    # Order side, passed through unchanged.
    # NOTE(review): the accepted values are defined by the upstream server — confirm and constrain.
    action: str
    # Number of shares; an order for zero or fewer shares is rejected up front.
    quantity: int = Field(..., gt=0)
    # Order price; defaults to 0 and may not be negative.
    # NOTE(review): 0 presumably means a market order — confirm against the upstream server.
    price: int = Field(0, ge=0)
    # Free-text reason attached to the order.
    reason: Optional[str] = "Manual Order"
@app.post("/api/trade/order", dependencies=[Depends(verify_admin)])
def order_stock(req: OrderRequest):
    """Place a buy/sell order through the Windows AI Server proxy.

    Args:
        req: Order details, forwarded to the upstream server as JSON.

    Returns:
        The upstream JSON body on success. On failure a ``JSONResponse``:
        the upstream status and body when the upstream answered with a
        non-200 status, 502 when the upstream answered 200 with a body that
        is not JSON, and 500 when the upstream could not be reached.
    """
    logger.info("Order Request: %s %s x%s", req.action, req.ticker, req.quantity)
    try:
        # model_dump() matches the sell-history endpoints in this file.
        resp = requests.post(f"{WINDOWS_AI_SERVER_URL}/trade/order", json=req.model_dump(), timeout=10)
    except Exception as e:
        logger.error("Order Connection Failed: %s", e)
        return JSONResponse(status_code=500, content={"error": "Connection Failed"})

    try:
        payload = resp.json()
    except ValueError:
        # JSON decode errors derive from ValueError. A 200 whose body is not
        # JSON is an upstream fault, so report 502 instead of echoing the 200
        # back alongside an error payload.
        status = resp.status_code if resp.status_code != 200 else 502
        logger.error("Order Error: non-JSON upstream response (%s)", resp.status_code)
        return JSONResponse(status_code=status, content={"error": f"Upstream error {resp.status_code}"})

    if resp.status_code != 200:
        logger.error("Order Error: %s", resp.status_code)
        return JSONResponse(status_code=resp.status_code, content=payload)
    return payload
# --- AI Coach proxy (the API key stays on the server) ---
class AiCoachRequest(BaseModel):
    """Request body for the AI Coach endpoint."""

    # Requested model name; the endpoint substitutes its default for names it does not allow.
    model: str = "claude-haiku-4-5-20251001"
    # User prompt, sent as the single user message.
    prompt: str
    # Upper bound on generated tokens, forwarded as-is.
    # NOTE(review): no ceiling is enforced here — confirm whether callers should be capped.
    max_tokens: int = 1024
@app.post("/api/stock/ai-coach")
def ai_coach(req: AiCoachRequest):
    """AI portfolio coach: proxies the request to the Anthropic Messages API.

    The API key is read from the server configuration and never sent by the
    client.

    Args:
        req: Model name, prompt and token limit supplied by the client.

    Returns:
        The upstream JSON body on HTTP 200; otherwise a ``JSONResponse`` with
        the upstream status (non-200 answer), 504 (timeout) or 500 (any other
        failure).

    Raises:
        HTTPException: 503 when ``ANTHROPIC_API_KEY`` is not configured.
    """
    # NOTE(review): this endpoint has no verify_admin dependency, so any
    # caller can spend the server-side API key — confirm this is intended.
    if not ANTHROPIC_API_KEY:
        raise HTTPException(503, "AI Coach not configured (ANTHROPIC_API_KEY missing)")

    # Only allow-listed models are forwarded; anything else uses the default.
    default_model = "claude-haiku-4-5-20251001"
    allowed_models = {default_model, "claude-sonnet-4-6"}
    model = default_model
    if req.model in allowed_models:
        model = req.model

    request_headers = {
        "Content-Type": "application/json",
        "x-api-key": ANTHROPIC_API_KEY,
        "anthropic-version": "2023-06-01",
    }
    request_body = {
        "model": model,
        "max_tokens": req.max_tokens,
        "messages": [{"role": "user", "content": req.prompt}],
    }

    try:
        resp = requests.post(
            "https://api.anthropic.com/v1/messages",
            headers=request_headers,
            json=request_body,
            timeout=30,
        )
        if resp.status_code == 200:
            return resp.json()
        logger.error(f"Anthropic API error: {resp.status_code}")
        return JSONResponse(status_code=resp.status_code, content={"error": "AI API error"})
    except requests.Timeout:
        return JSONResponse(status_code=504, content={"error": "AI API timeout"})
    except Exception as e:
        logger.error(f"AI Coach error: {e}")
        return JSONResponse(status_code=500, content={"error": "AI Coach failed"})
@app.get("/api/version")
def version():
    """Report the application version from ``APP_VERSION`` ("dev" when unset)."""
    app_version = os.getenv("APP_VERSION", "dev")
    return {"version": app_version}
# --- Portfolio API ---
class PortfolioItemRequest(BaseModel):
    """Request body for adding a holding to the portfolio; every field is required."""

    broker: str
    ticker: str
    name: str
    # Number of shares held.
    quantity: int
    # Average purchase price per share.
    avg_price: int
class PortfolioUpdateRequest(BaseModel):
    """Request body for editing a holding; every field is optional and defaults to None."""

    broker: Optional[str] = None
    ticker: Optional[str] = None
    name: Optional[str] = None
    quantity: Optional[int] = None
    avg_price: Optional[int] = None
@app.get("/api/portfolio")
def get_portfolio():
    """Return the whole portfolio with current prices, profit/loss and cash.

    Returns:
        A dict with three keys:
        ``holdings`` — one entry per stored item; ``current_price``,
        ``eval_amount``, ``profit_amount`` and ``profit_rate`` are ``None``
        when no current price is available for the ticker;
        ``cash`` — the broker cash rows as stored;
        ``summary`` — totals over all holdings plus total cash.
    """
    items = get_all_portfolio()
    cash_rows = get_all_broker_cash()
    total_cash = sum(r["cash"] for r in cash_rows)

    # Skip the price lookup entirely when nothing is held.
    prices = get_current_prices(list({item["ticker"] for item in items})) if items else {}

    holdings = []
    total_buy = 0
    total_eval = 0
    for item in items:
        current_price = prices.get(item["ticker"])
        buy_amount = item["avg_price"] * item["quantity"]

        if current_price is None:
            eval_amount = None
            profit_amount = None
            profit_rate = None
            # Value an unpriced holding at cost in the totals. Leaving it out
            # of total_eval while it stays in total_buy would report the whole
            # position as a loss; the snapshot code uses the same fallback.
            total_eval += buy_amount
        else:
            eval_amount = current_price * item["quantity"]
            profit_amount = eval_amount - buy_amount
            # A zero cost basis has no meaningful rate.
            profit_rate = round((profit_amount / buy_amount) * 100, 2) if buy_amount else None
            total_eval += eval_amount

        total_buy += buy_amount
        holdings.append({
            "id": item["id"],
            "broker": item["broker"],
            "ticker": item["ticker"],
            "name": item["name"],
            "quantity": item["quantity"],
            "avg_price": item["avg_price"],
            "current_price": current_price,
            "eval_amount": eval_amount,
            "profit_amount": profit_amount,
            "profit_rate": profit_rate,
        })

    total_profit = total_eval - total_buy
    total_profit_rate = round((total_profit / total_buy) * 100, 2) if total_buy else 0.0

    return {
        "holdings": holdings,
        "cash": cash_rows,
        "summary": {
            "total_buy": total_buy,
            "total_eval": total_eval,
            "total_profit": total_profit,
            "total_profit_rate": total_profit_rate,
            "total_cash": total_cash,
            "total_assets": total_eval + total_cash,
        },
    }
@app.post("/api/portfolio", status_code=201)
def create_portfolio_item(req: PortfolioItemRequest):
    """Add a holding to the portfolio and return the id of the new row."""
    new_id = add_portfolio_item(
        req.broker,
        req.ticker,
        req.name,
        req.quantity,
        req.avg_price,
    )
    return {"id": new_id, "ok": True}
# --- Broker Cash API ---
# These routes must be defined before the /{item_id} routes; otherwise "cash"
# would be matched as an item_id.
class BrokerCashRequest(BaseModel):
    """Request body for registering or updating a broker's cash balance."""

    broker: str
    cash: int
@app.get("/api/portfolio/cash")
def list_broker_cash():
    """Return the cash balance rows of every broker."""
    cash_rows = get_all_broker_cash()
    return cash_rows
@app.put("/api/portfolio/cash")
def set_broker_cash(req: BrokerCashRequest):
    """Register or update a broker's cash balance (upsert) and echo the stored values."""
    broker, cash = req.broker, req.cash
    upsert_broker_cash(broker, cash)
    return {"ok": True, "broker": broker, "cash": cash}
@app.delete("/api/portfolio/cash/{broker}")
def remove_broker_cash(broker: str):
    """Delete a broker's cash balance; answers 404 when nothing was deleted."""
    deleted = delete_broker_cash(broker)
    if deleted:
        return {"ok": True}
    return JSONResponse(status_code=404, content={"error": "Broker not found"})
@app.put("/api/portfolio/{item_id}")
def update_portfolio(item_id: int, req: PortfolioUpdateRequest):
    """Update a portfolio holding.

    Args:
        item_id: Id of the holding to update.
        req: New field values; every field of the request model is passed to
            ``update_portfolio_item`` as a keyword argument, with ``None`` for
            fields the client left out.

    Returns:
        ``{"ok": True}`` on success, or a 404 ``JSONResponse`` when no holding
        with that id exists.
    """
    if get_portfolio_item(item_id) is None:
        return JSONResponse(status_code=404, content={"error": "Item not found"})
    # model_dump() matches the sell-history endpoints in this file and yields
    # the same field dict as the deprecated dict().
    update_portfolio_item(item_id, **req.model_dump())
    return {"ok": True}
@app.delete("/api/portfolio/{item_id}")
def delete_portfolio(item_id: int):
    """Delete a portfolio holding; answers 404 when nothing was deleted."""
    deleted = delete_portfolio_item(item_id)
    if deleted:
        return {"ok": True}
    return JSONResponse(status_code=404, content={"error": "Item not found"})
# --- Asset Snapshot API ---
@app.post("/api/portfolio/snapshot")
def create_snapshot():
    """Manually store the total-asset snapshot for today's date.

    Unlike the scheduled job, this does not check whether the market is open.

    Returns:
        ``{"ok": True, "snapshot": {...}}`` with the stored date, evaluation
        total, cash total and asset total.
    """
    today = date_type.today()
    today_str = today.strftime("%Y-%m-%d")
    items = get_all_portfolio()
    cash_rows = get_all_broker_cash()
    total_cash = sum(r["cash"] for r in cash_rows)

    total_eval = 0
    if items:
        tickers = list({item["ticker"] for item in items})
        prices = get_current_prices(tickers)
        for item in items:
            price = prices.get(item["ticker"])
            if price is None:
                # No usable quote (missing key or a None entry): value the
                # holding at its average purchase price instead of failing.
                price = item["avg_price"]
            total_eval += price * item["quantity"]

    total_assets = total_eval + total_cash
    upsert_asset_snapshot(today_str, total_eval, total_cash, total_assets)
    return {
        "ok": True,
        "snapshot": {
            "date": today_str,
            "total_eval": total_eval,
            "total_cash": total_cash,
            "total_assets": total_assets,
        },
    }
@app.get("/api/portfolio/snapshot/history")
def get_snapshot_history(days: int = Query(30, ge=0)):
    """Return the total-asset snapshot history (days=0: everything, days=N: the last N days)."""
    return {"snapshots": get_asset_snapshots(days)}
# --- Sell History API ---
class SellHistoryRequest(BaseModel):
    """Request body for creating or updating a sell-history record."""

    broker: str
    ticker: str
    name: str
    quantity: int
    avg_price: float
    sell_price: float
    # Optional; defaults to 0 when the client omits it.
    commission: float = 0
    buy_amount: float
    sell_amount: float
    realized_profit: float
    realized_rate: float
    # Sale timestamp as a string; NOTE(review): the expected format is not enforced here — confirm.
    sold_at: str
@app.get("/api/portfolio/sell-history")
def list_sell_history(broker: Optional[str] = None, days: Optional[int] = None):
    """Return sell-history records, optionally filtered by broker and/or days."""
    return {"records": get_sell_history(broker=broker, days=days)}
@app.post("/api/portfolio/sell-history")
def create_sell_history(req: SellHistoryRequest):
    """Store a sell record and return what the storage layer hands back."""
    payload = req.model_dump()
    return add_sell_history(payload)
@app.put("/api/portfolio/sell-history/{record_id}")
def modify_sell_history(record_id: int, req: SellHistoryRequest):
    """Update a sell record; answers 404 when the storage layer returns None."""
    updated = update_sell_history(record_id, req.model_dump())
    if updated is not None:
        return updated
    return JSONResponse(status_code=404, content={"error": "not found"})
@app.delete("/api/portfolio/sell-history/{record_id}")
def remove_sell_history(record_id: int):
    """Delete a sell record; answers 404 when nothing was deleted."""
    deleted = delete_sell_history(record_id)
    if deleted:
        return {"ok": True}
    return JSONResponse(status_code=404, content={"error": "not found"})