204 lines
7.5 KiB
Python
204 lines
7.5 KiB
Python
import hashlib
import os
import sqlite3
from datetime import datetime
from typing import Any, Dict, List
|
|
|
|
# SQLite database file that _conn() opens for every query in this module.
DB_PATH = "/app/data/stock.db"
|
|
|
|
class _ClosingConnection(sqlite3.Connection):
    """sqlite3 connection whose ``with`` block also closes the handle.

    The stock ``sqlite3.Connection.__exit__`` only commits (on success) or
    rolls back (on error); it never closes the connection. Every helper in
    this module uses ``with _conn() as conn:``, so without this subclass each
    call leaves an open connection behind until garbage collection (Python
    3.13+ reports that as a ResourceWarning).
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            # Keep the standard commit/rollback behaviour...
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            # ...then release the database handle, even if commit failed.
            self.close()


def _conn() -> sqlite3.Connection:
    """Open a connection to ``DB_PATH`` that returns ``sqlite3.Row`` rows.

    Creates the database's parent directory if it does not exist. Intended
    usage is ``with _conn() as conn:``: leaving the block commits (or rolls
    back on an exception) and then closes the connection, so it cannot be
    reused after its ``with`` block ends.
    """
    db_dir = os.path.dirname(DB_PATH)
    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)
    conn = sqlite3.connect(DB_PATH, factory=_ClosingConnection)
    conn.row_factory = sqlite3.Row
    return conn
|
|
|
|
def init_db():
    """Create this module's tables and index if missing, and migrate old schemas.

    Safe to call repeatedly: every CREATE uses ``IF NOT EXISTS`` and the
    ``category`` column is only added to ``articles`` when it is absent
    (databases created before that column existed).
    """
    with _conn() as db:
        db.execute("""
            CREATE TABLE IF NOT EXISTS articles (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                hash TEXT UNIQUE NOT NULL,
                category TEXT DEFAULT 'domestic',
                title TEXT NOT NULL,
                link TEXT,
                summary TEXT,
                press TEXT,
                pub_date TEXT,
                crawled_at TEXT
            )
        """)
        db.execute("CREATE INDEX IF NOT EXISTS idx_articles_crawled ON articles(crawled_at DESC)")

        # Migration for pre-existing articles tables without a category column.
        existing_columns = {
            column_info["name"]
            for column_info in db.execute("PRAGMA table_info(articles)").fetchall()
        }
        if "category" not in existing_columns:
            db.execute("ALTER TABLE articles ADD COLUMN category TEXT DEFAULT 'domestic'")

        remaining_tables = (
            """
            CREATE TABLE IF NOT EXISTS portfolio (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                broker TEXT NOT NULL,
                ticker TEXT NOT NULL,
                name TEXT NOT NULL,
                quantity INTEGER NOT NULL,
                avg_price INTEGER NOT NULL,
                created_at TEXT DEFAULT (datetime('now','localtime')),
                updated_at TEXT DEFAULT (datetime('now','localtime'))
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS broker_cash (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                broker TEXT UNIQUE NOT NULL,
                cash INTEGER NOT NULL DEFAULT 0,
                updated_at TEXT DEFAULT (datetime('now','localtime'))
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS asset_snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                date TEXT UNIQUE NOT NULL,
                total_eval INTEGER NOT NULL,
                total_cash INTEGER NOT NULL,
                total_assets INTEGER NOT NULL,
                created_at TEXT DEFAULT (datetime('now','localtime'))
            )
            """,
        )
        for ddl in remaining_tables:
            db.execute(ddl)
|
|
|
|
def save_articles(articles: List[Dict[str, str]]) -> int:
    """Insert crawled articles, skipping ones that are already stored.

    Each article is keyed by the MD5 of ``"{title}|{link}"``. The ``hash``
    column is UNIQUE, so re-inserting a known article raises
    ``sqlite3.IntegrityError``, which is treated as "skip this one".

    Args:
        articles: dicts with ``title``, ``link``, ``summary``, ``press``,
            ``date`` and ``crawled_at`` keys, plus an optional ``category``
            (defaults to ``"domestic"``).

    Returns:
        Number of rows actually inserted.
    """
    if not articles:
        # Nothing to insert: don't open (or create) the database at all.
        return 0

    insert_sql = """
        INSERT INTO articles (hash, category, title, link, summary, press, pub_date, crawled_at)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    """
    count = 0
    with _conn() as conn:
        for a in articles:
            # Dedup key built from title + link. The digest is not a security
            # measure; usedforsecurity=False keeps md5 available on
            # FIPS-restricted builds and produces the same hex digest.
            unique_str = f"{a['title']}|{a['link']}"
            h = hashlib.md5(unique_str.encode(), usedforsecurity=False).hexdigest()
            cat = a.get("category", "domestic")
            try:
                conn.execute(
                    insert_sql,
                    (h, cat, a["title"], a["link"], a["summary"], a["press"], a["date"], a["crawled_at"]),
                )
            except sqlite3.IntegrityError:
                # Usually a duplicate hash (already stored). Any other
                # constraint violation (e.g. a NULL title) is skipped the same way.
                continue
            count += 1
    return count
|
|
|
|
def get_latest_articles(limit: int = 20, category: str | None = None) -> List[Dict[str, Any]]:
    """Return the most recently crawled articles, newest first.

    Args:
        limit: maximum number of rows to return.
        category: when truthy, only articles with this ``category`` are
            returned; ``None`` or ``""`` returns articles of every category.

    Returns:
        One dict per row (all ``articles`` columns), ordered by ``crawled_at``
        descending, then ``id`` descending as a tie-breaker.
        NOTE(review): ``crawled_at`` is TEXT, so "newest first" assumes it is
        stored in a lexically sortable format (e.g. ISO 8601) -- confirm.
    """
    sql = "SELECT * FROM articles"
    params = []
    if category:
        sql += " WHERE category = ?"
        params.append(category)
    sql += " ORDER BY crawled_at DESC, id DESC LIMIT ?"
    params.append(limit)

    with _conn() as conn:
        rows = conn.execute(sql, params).fetchall()
        return [dict(r) for r in rows]
|
|
|
|
|
|
# --- Portfolio CRUD ---
|
|
|
|
def add_portfolio_item(broker: str, ticker: str, name: str, quantity: int, avg_price: int) -> int:
    """Insert one holding into ``portfolio`` and return the new row's id.

    ``created_at`` and ``updated_at`` are left to the column defaults.
    """
    row_values = (broker, ticker, name, quantity, avg_price)
    with _conn() as conn:
        new_id = conn.execute(
            "INSERT INTO portfolio (broker, ticker, name, quantity, avg_price) VALUES (?, ?, ?, ?, ?)",
            row_values,
        ).lastrowid
    return new_id
|
|
|
|
|
|
def get_all_portfolio() -> List[Dict[str, Any]]:
    """Return every ``portfolio`` row as a dict, in ascending ``id`` order."""
    with _conn() as conn:
        cursor = conn.execute("SELECT * FROM portfolio ORDER BY id")
        return list(map(dict, cursor.fetchall()))
|
|
|
|
|
|
def get_portfolio_item(item_id: int) -> Dict[str, Any] | None:
    """Fetch one ``portfolio`` row by primary key.

    Returns the row as a dict, or ``None`` when no row has ``item_id``.
    """
    with _conn() as conn:
        found = conn.execute("SELECT * FROM portfolio WHERE id = ?", (item_id,)).fetchone()
    if not found:
        return None
    return dict(found)
|
|
|
|
|
|
def update_portfolio_item(item_id: int, **kwargs) -> bool:
    """Update selected columns of one ``portfolio`` row.

    Only ``broker``, ``ticker``, ``name``, ``quantity`` and ``avg_price`` are
    writable; any other keyword and any ``None`` value is ignored. When at
    least one writable field is supplied, ``updated_at`` is also set to the
    current local time.

    Returns:
        ``True`` if a row with ``item_id`` was updated; ``False`` if there was
        nothing to update or no such row exists.
    """
    allowed = {"broker", "ticker", "name", "quantity", "avg_price"}
    fields = {k: v for k, v in kwargs.items() if k in allowed and v is not None}
    if not fields:
        return False

    # Same "YYYY-MM-DD HH:MM:SS" local-time format as the column's
    # datetime('now','localtime') default.
    fields["updated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Column names come only from the fixed allow-list above (plus
    # updated_at), so the f-string SET clause cannot inject SQL; all values
    # are still bound as parameters.
    set_clause = ", ".join(f"{k} = ?" for k in fields)
    values = [*fields.values(), item_id]
    with _conn() as conn:
        cur = conn.execute(f"UPDATE portfolio SET {set_clause} WHERE id = ?", values)
        return cur.rowcount > 0
|
|
|
|
|
|
def delete_portfolio_item(item_id: int) -> bool:
    """Delete the ``portfolio`` row with ``item_id``; report whether one was removed."""
    with _conn() as conn:
        removed = conn.execute("DELETE FROM portfolio WHERE id = ?", (item_id,)).rowcount
    return removed > 0
|
|
|
|
|
|
# --- Broker Cash CRUD ---
|
|
|
|
def upsert_broker_cash(broker: str, cash: int) -> None:
    """Set the cash balance for ``broker``, inserting its row if needed.

    ``broker`` is UNIQUE in ``broker_cash``, so an existing row has its
    ``cash`` and ``updated_at`` overwritten instead of a duplicate being added.
    """
    # Same "YYYY-MM-DD HH:MM:SS" local-time format as the column default.
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with _conn() as conn:
        conn.execute("""
            INSERT INTO broker_cash (broker, cash, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(broker) DO UPDATE SET cash = excluded.cash, updated_at = excluded.updated_at
        """, (broker, cash, now))
|
|
|
|
|
|
def get_all_broker_cash() -> List[Dict[str, Any]]:
    """Return every ``broker_cash`` row as a dict, ordered by ``broker``."""
    with _conn() as conn:
        cursor = conn.execute("SELECT * FROM broker_cash ORDER BY broker")
        result = [dict(record) for record in cursor.fetchall()]
    return result
|
|
|
|
|
|
def get_broker_cash(broker: str) -> Dict[str, Any] | None:
    """Fetch the ``broker_cash`` row for ``broker``.

    Returns the row as a dict, or ``None`` when that broker has no row.
    """
    with _conn() as conn:
        record = conn.execute("SELECT * FROM broker_cash WHERE broker = ?", (broker,)).fetchone()
    if not record:
        return None
    return dict(record)
|
|
|
|
|
|
def delete_broker_cash(broker: str) -> bool:
    """Delete the ``broker_cash`` row for ``broker``; report whether one was removed."""
    with _conn() as conn:
        removed = conn.execute("DELETE FROM broker_cash WHERE broker = ?", (broker,)).rowcount
    return removed > 0
|
|
|
|
|
|
# --- Asset Snapshot CRUD ---
|
|
|
|
def upsert_asset_snapshot(date: str, total_eval: int, total_cash: int, total_assets: int) -> None:
    """Record the asset totals for ``date``, replacing any snapshot for that date.

    ``date`` is UNIQUE in ``asset_snapshots``. On conflict, all three totals
    are overwritten and ``created_at`` is reset to the current local time, so
    it reflects the latest write for that date rather than the first one.
    """
    # Same "YYYY-MM-DD HH:MM:SS" local-time format as the column default.
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with _conn() as conn:
        conn.execute("""
            INSERT INTO asset_snapshots (date, total_eval, total_cash, total_assets, created_at)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(date) DO UPDATE SET
                total_eval = excluded.total_eval,
                total_cash = excluded.total_cash,
                total_assets = excluded.total_assets,
                created_at = excluded.created_at
        """, (date, total_eval, total_cash, total_assets, now))
|
|
|
|
|
|
def get_asset_snapshots(days: int = 30) -> List[Dict[str, Any]]:
    """Return asset snapshots (date and the three totals) in ascending date order.

    Args:
        days: number of most recent snapshot rows to return; ``0`` returns
            all of them. This counts stored rows, not calendar days, so gaps
            in the history make the covered span longer than ``days``.
            NOTE(review): ``date`` is TEXT; the ordering assumes a lexically
            sortable format such as ``YYYY-MM-DD`` -- confirm.
    """
    with _conn() as conn:
        if days != 0:
            newest_first = conn.execute(
                "SELECT date, total_eval, total_cash, total_assets FROM asset_snapshots ORDER BY date DESC LIMIT ?",
                (days,),
            ).fetchall()
            # Take the newest N rows, then flip them back to oldest-first.
            ordered = newest_first[::-1]
        else:
            ordered = conn.execute(
                "SELECT date, total_eval, total_cash, total_assets FROM asset_snapshots ORDER BY date ASC"
            ).fetchall()
        return [dict(snapshot) for snapshot in ordered]
|