"""SQLite persistence for news articles, portfolio holdings, broker cash,
daily asset snapshots and realized sell history.

Every public function opens its own short-lived connection, runs its work in a
single transaction (committed on success, rolled back on error) and closes the
connection before returning.
"""

import hashlib
import os
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional

from app.screener.schema import ensure_screener_schema

# Default database location; the STOCK_DB_PATH environment variable overrides it.
DB_PATH = os.environ.get("STOCK_DB_PATH", "/app/data/stock.db")


def _conn() -> sqlite3.Connection:
    """Open a new connection to the stock database.

    The path is re-read from ``STOCK_DB_PATH`` on every call (falling back to
    ``DB_PATH``), and its parent directory is created when missing. Rows come
    back as ``sqlite3.Row`` so they can be turned into dicts with ``dict(row)``.

    The caller owns the returned connection and must close it: using it as
    ``with conn:`` only commits/rolls back, it does NOT close the connection.
    Inside this module use ``_session()`` instead.
    """
    db_path = os.environ.get("STOCK_DB_PATH", DB_PATH)
    parent = os.path.dirname(db_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    return conn


@contextmanager
def _session() -> Iterator[sqlite3.Connection]:
    """Yield a connection inside one transaction, then always close it.

    Commits when the ``with`` body completes normally and rolls back when it
    raises (the ``with conn:`` semantics this module relied on), and closes
    the connection afterwards so it is not left to the garbage collector.
    """
    conn = _conn()
    try:
        with conn:
            yield conn
    finally:
        conn.close()


def _now_str() -> str:
    """Return the current local (naive) time as ``YYYY-MM-DD HH:MM:SS``."""
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def init_db() -> None:
    """Create all tables/indexes if missing and apply in-place column migrations.

    Safe to call repeatedly: tables and indexes use ``IF NOT EXISTS`` and each
    column migration first inspects ``PRAGMA table_info``. Finally bootstraps
    the screener schema through ``ensure_screener_schema``.
    """
    with _session() as conn:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS articles (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                hash TEXT UNIQUE NOT NULL,
                category TEXT DEFAULT 'domestic',
                title TEXT NOT NULL,
                link TEXT,
                summary TEXT,
                press TEXT,
                pub_date TEXT,
                crawled_at TEXT
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_articles_crawled ON articles(crawled_at DESC)")

        # Migration for older databases: add the `category` column if absent.
        cols = {r["name"] for r in conn.execute("PRAGMA table_info(articles)").fetchall()}
        if "category" not in cols:
            conn.execute("ALTER TABLE articles ADD COLUMN category TEXT DEFAULT 'domestic'")

        conn.execute("""
            CREATE TABLE IF NOT EXISTS portfolio (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                broker TEXT NOT NULL,
                ticker TEXT NOT NULL,
                name TEXT NOT NULL,
                quantity INTEGER NOT NULL,
                avg_price INTEGER NOT NULL,
                purchase_price INTEGER,
                created_at TEXT DEFAULT (datetime('now','localtime')),
                updated_at TEXT DEFAULT (datetime('now','localtime'))
            )
        """)

        # Migration: older databases lack `purchase_price`; add it and backfill
        # it from `avg_price`.
        pf_cols = {r["name"] for r in conn.execute("PRAGMA table_info(portfolio)").fetchall()}
        if "purchase_price" not in pf_cols:
            conn.execute("ALTER TABLE portfolio ADD COLUMN purchase_price INTEGER")
            conn.execute("UPDATE portfolio SET purchase_price = avg_price WHERE purchase_price IS NULL")

        conn.execute("""
            CREATE TABLE IF NOT EXISTS broker_cash (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                broker TEXT UNIQUE NOT NULL,
                cash INTEGER NOT NULL DEFAULT 0,
                updated_at TEXT DEFAULT (datetime('now','localtime'))
            )
        """)

        conn.execute("""
            CREATE TABLE IF NOT EXISTS asset_snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                date TEXT UNIQUE NOT NULL,
                total_eval INTEGER NOT NULL,
                total_cash INTEGER NOT NULL,
                total_assets INTEGER NOT NULL,
                created_at TEXT DEFAULT (datetime('now','localtime'))
            )
        """)

        conn.execute("""
            CREATE TABLE IF NOT EXISTS sell_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                broker TEXT NOT NULL,
                ticker TEXT NOT NULL,
                name TEXT NOT NULL,
                quantity INTEGER NOT NULL,
                avg_price REAL NOT NULL,
                sell_price REAL NOT NULL,
                commission REAL NOT NULL DEFAULT 0,
                buy_amount REAL NOT NULL,
                sell_amount REAL NOT NULL,
                realized_profit REAL NOT NULL,
                realized_rate REAL NOT NULL,
                sold_at TEXT NOT NULL
            )
        """)

        # sell_history migration: add the `commission` column if absent.
        sh_cols = {r["name"] for r in conn.execute("PRAGMA table_info(sell_history)").fetchall()}
        if "commission" not in sh_cols:
            conn.execute("ALTER TABLE sell_history ADD COLUMN commission REAL NOT NULL DEFAULT 0")

        # Screener schema bootstrap (original note: 7 tables + default settings seed).
        ensure_screener_schema(conn)


def save_articles(articles: List[Dict[str, str]]) -> int:
    """Insert crawled articles, skipping ones that are already stored.

    Duplicates are detected through an MD5 of ``"title|link"`` stored in the
    UNIQUE ``hash`` column. Each dict must provide ``title``, ``link``,
    ``summary``, ``press``, ``date`` and ``crawled_at``; ``category`` defaults
    to ``"domestic"``.

    Returns:
        The number of rows actually inserted.
    """
    count = 0
    with _session() as conn:
        for a in articles:
            # Dedup key (title + link); MD5 is used as a fingerprint, not for security.
            unique_str = f"{a['title']}|{a['link']}"
            h = hashlib.md5(unique_str.encode(), usedforsecurity=False).hexdigest()
            cat = a.get("category", "domestic")
            try:
                conn.execute("""
                    INSERT INTO articles (hash, category, title, link, summary, press, pub_date, crawled_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (h, cat, a['title'], a['link'], a['summary'], a['press'], a['date'], a['crawled_at']))
            except sqlite3.IntegrityError:
                # Normally the hash already exists; any other constraint
                # violation for this row is skipped the same way.
                continue
            count += 1
    return count


def get_latest_articles(limit: int = 20, category: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return up to ``limit`` articles, newest ``crawled_at`` first (ties by id).

    When ``category`` is truthy, only articles of that category are returned.
    """
    with _session() as conn:
        if category:
            rows = conn.execute(
                "SELECT * FROM articles WHERE category = ? ORDER BY crawled_at DESC, id DESC LIMIT ?",
                (category, limit),
            ).fetchall()
        else:
            rows = conn.execute(
                "SELECT * FROM articles ORDER BY crawled_at DESC, id DESC LIMIT ?",
                (limit,),
            ).fetchall()
        return [dict(r) for r in rows]


# --- Portfolio CRUD ---

def add_portfolio_item(
    broker: str,
    ticker: str,
    name: str,
    quantity: int,
    avg_price: int,
    purchase_price: Optional[int] = None,
) -> int:
    """Insert a holding and return its new row id.

    ``purchase_price`` defaults to ``avg_price`` when omitted (backward
    compatibility with callers that predate the column).
    """
    if purchase_price is None:
        purchase_price = avg_price
    with _session() as conn:
        cur = conn.execute(
            "INSERT INTO portfolio (broker, ticker, name, quantity, avg_price, purchase_price) VALUES (?, ?, ?, ?, ?, ?)",
            (broker, ticker, name, quantity, avg_price, purchase_price),
        )
        return cur.lastrowid


def get_all_portfolio() -> List[Dict[str, Any]]:
    """Return every portfolio row ordered by id."""
    with _session() as conn:
        rows = conn.execute("SELECT * FROM portfolio ORDER BY id").fetchall()
        return [dict(r) for r in rows]


def get_portfolio_item(item_id: int) -> Dict[str, Any] | None:
    """Return the portfolio row with ``item_id``, or ``None`` if it does not exist."""
    with _session() as conn:
        row = conn.execute("SELECT * FROM portfolio WHERE id = ?", (item_id,)).fetchone()
        return dict(row) if row else None


def update_portfolio_item(item_id: int, **kwargs) -> bool:
    """Update whitelisted columns of a portfolio row and refresh ``updated_at``.

    Only ``broker``, ``ticker``, ``name``, ``quantity``, ``avg_price`` and
    ``purchase_price`` are accepted; keys outside that set and ``None`` values
    are ignored.

    Returns:
        ``False`` if nothing was left to update or no row matched ``item_id``.
    """
    allowed = {"broker", "ticker", "name", "quantity", "avg_price", "purchase_price"}
    fields = {k: v for k, v in kwargs.items() if k in allowed and v is not None}
    if not fields:
        return False
    fields["updated_at"] = _now_str()
    # Column names come only from the whitelist above, so interpolating them is safe.
    set_clause = ", ".join(f"{k} = ?" for k in fields)
    values = list(fields.values()) + [item_id]
    with _session() as conn:
        cur = conn.execute(f"UPDATE portfolio SET {set_clause} WHERE id = ?", values)
        return cur.rowcount > 0


def delete_portfolio_item(item_id: int) -> bool:
    """Delete a portfolio row; return ``True`` if a row was removed."""
    with _session() as conn:
        cur = conn.execute("DELETE FROM portfolio WHERE id = ?", (item_id,))
        return cur.rowcount > 0


# --- Broker Cash CRUD ---

def upsert_broker_cash(broker: str, cash: int) -> None:
    """Insert or overwrite the cash balance for ``broker`` (stamped with local time)."""
    now = _now_str()
    with _session() as conn:
        conn.execute("""
            INSERT INTO broker_cash (broker, cash, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(broker) DO UPDATE SET cash = excluded.cash, updated_at = excluded.updated_at
        """, (broker, cash, now))


def get_all_broker_cash() -> List[Dict[str, Any]]:
    """Return every broker cash row ordered by broker name."""
    with _session() as conn:
        rows = conn.execute("SELECT * FROM broker_cash ORDER BY broker").fetchall()
        return [dict(r) for r in rows]


def delete_broker_cash(broker: str) -> bool:
    """Delete the cash row for ``broker``; return ``True`` if a row was removed."""
    with _session() as conn:
        cur = conn.execute("DELETE FROM broker_cash WHERE broker = ?", (broker,))
        return cur.rowcount > 0


# --- Asset Snapshot CRUD ---

def upsert_asset_snapshot(date: str, total_eval: int, total_cash: int, total_assets: int) -> None:
    """Insert or overwrite the asset snapshot for ``date``; ``created_at`` is refreshed on overwrite."""
    now = _now_str()
    with _session() as conn:
        conn.execute("""
            INSERT INTO asset_snapshots (date, total_eval, total_cash, total_assets, created_at)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(date) DO UPDATE SET
                total_eval = excluded.total_eval,
                total_cash = excluded.total_cash,
                total_assets = excluded.total_assets,
                created_at = excluded.created_at
        """, (date, total_eval, total_cash, total_assets, now))


# --- Sell History CRUD ---

def add_sell_history(data: Dict[str, Any]) -> Dict[str, Any]:
    """Insert a sell record and return the stored row as a dict.

    ``data`` must contain every sell_history column except ``id``;
    ``commission`` is optional and defaults to 0.
    """
    with _session() as conn:
        cur = conn.execute("""
            INSERT INTO sell_history
                (broker, ticker, name, quantity, avg_price, sell_price, commission,
                 buy_amount, sell_amount, realized_profit, realized_rate, sold_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            data["broker"], data["ticker"], data["name"], data["quantity"],
            data["avg_price"], data["sell_price"], data.get("commission", 0),
            data["buy_amount"], data["sell_amount"], data["realized_profit"],
            data["realized_rate"], data["sold_at"],
        ))
        row = conn.execute("SELECT * FROM sell_history WHERE id = ?", (cur.lastrowid,)).fetchone()
        return dict(row)


def get_sell_history(broker: Optional[str] = None, days: Optional[int] = None) -> List[Dict[str, Any]]:
    """Return sell records, newest ``sold_at`` first.

    Args:
        broker: If truthy, only records for this broker.
        days: If truthy, only records whose ``sold_at`` is within the last
            ``days`` days (compared as text against SQLite ``datetime('now', ...)``).
    """
    conditions = []
    params = []
    if broker:
        conditions.append("broker = ?")
        params.append(broker)
    if days:
        # NOTE(review): SQLite datetime('now') is UTC while the rest of this file
        # stores local time; confirm which zone callers use for `sold_at`.
        conditions.append("sold_at >= datetime('now', ? || ' days')")
        params.append(f"-{days}")
    where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
    with _session() as conn:
        rows = conn.execute(
            f"SELECT * FROM sell_history {where} ORDER BY sold_at DESC",
            params,
        ).fetchall()
        return [dict(r) for r in rows]


def update_sell_history(record_id: int, data: Dict[str, Any]) -> Dict[str, Any] | None:
    """Overwrite every column of a sell record and return the updated row.

    All columns are required in ``data`` except ``commission`` (defaults to 0).

    Returns:
        The updated row as a dict, or ``None`` if ``record_id`` does not exist.
    """
    fields = ["broker", "ticker", "name", "quantity", "avg_price", "sell_price", "commission",
              "buy_amount", "sell_amount", "realized_profit", "realized_rate", "sold_at"]
    set_clause = ", ".join(f"{f} = ?" for f in fields)
    values = [data.get(f, 0) if f == "commission" else data[f] for f in fields] + [record_id]
    with _session() as conn:
        cur = conn.execute(f"UPDATE sell_history SET {set_clause} WHERE id = ?", values)
        if cur.rowcount == 0:
            return None
        row = conn.execute("SELECT * FROM sell_history WHERE id = ?", (record_id,)).fetchone()
        return dict(row)


def delete_sell_history(record_id: int) -> bool:
    """Delete a sell record; return ``True`` if a row was removed."""
    with _session() as conn:
        cur = conn.execute("DELETE FROM sell_history WHERE id = ?", (record_id,))
        return cur.rowcount > 0


def get_asset_snapshots(days: int = 30) -> List[Dict[str, Any]]:
    """Return asset snapshots in ascending date order.

    ``days == 0`` returns every snapshot; otherwise the most recent ``days``
    snapshot rows (rows, not calendar days) are returned, oldest first.
    """
    with _session() as conn:
        if days == 0:
            rows = conn.execute(
                "SELECT date, total_eval, total_cash, total_assets FROM asset_snapshots ORDER BY date ASC"
            ).fetchall()
        else:
            rows = conn.execute(
                "SELECT date, total_eval, total_cash, total_assets FROM asset_snapshots ORDER BY date DESC LIMIT ?",
                (days,),
            ).fetchall()
            # Fetched newest-first to apply LIMIT; flip back to chronological order.
            rows = list(reversed(rows))
        return [dict(r) for r in rows]