feat(saju-lab): interpret/pipeline.py — Claude 호출 + reroll 1회 (8 tests)

This commit is contained in:
2026-05-25 20:23:38 +09:00
parent f995f8739f
commit 6d752acbe1
2 changed files with 299 additions and 0 deletions

View File

@@ -0,0 +1,145 @@
"""사주 + 궁합 해석 파이프라인 — Claude Sonnet 호출 + reroll 1회.
tarot-lab/app/pipeline.py 패턴 재활용.
"""
import json
import logging
import time
from typing import Any, Dict
import httpx
from ..config import (
ANTHROPIC_API_KEY,
SAJU_MODEL,
SAJU_COST_INPUT_PER_M,
SAJU_COST_OUTPUT_PER_M,
SAJU_TIMEOUT_SEC,
)
logger = logging.getLogger("saju-lab.pipeline")
from .prompt import SAJU_SYSTEM_PROMPT, COMPAT_SYSTEM_PROMPT, build_saju_user_message, build_compat_user_message
from .schema import validate_saju_interpretation, validate_compat_interpretation
API_URL = "https://api.anthropic.com/v1/messages"
class SajuError(Exception):
pass
def calc_cost(tokens_in: int, tokens_out: int) -> float:
return (
tokens_in / 1_000_000 * SAJU_COST_INPUT_PER_M
+ tokens_out / 1_000_000 * SAJU_COST_OUTPUT_PER_M
)
def _strip_codeblock(text: str) -> str:
t = text.strip()
if t.startswith("```"):
t = t.strip("`")
if t.startswith("json"):
t = t[4:]
t = t.strip()
return t
def _extract_json(raw: str) -> dict:
cleaned = _strip_codeblock(raw)
try:
return json.loads(cleaned)
except json.JSONDecodeError:
start, end = cleaned.find("{"), cleaned.rfind("}")
if start >= 0 and end > start:
try:
return json.loads(cleaned[start : end + 1])
except json.JSONDecodeError:
pass
raise
async def _call_claude(system_prompt: str, user_text: str, feedback: str = "") -> tuple[dict, dict, str]:
if not ANTHROPIC_API_KEY:
raise SajuError("ANTHROPIC_API_KEY missing")
if feedback:
user_text = f"이전 응답이 다음 이유로 거절됨: {feedback}\n올바른 스키마(시스템 지침)로 다시 응답.\n\n{user_text}"
payload = {
"model": SAJU_MODEL,
"max_tokens": 2400,
"system": [{"type": "text", "text": system_prompt,
"cache_control": {"type": "ephemeral"}}],
"messages": [{"role": "user", "content": [{"type": "text", "text": user_text}]}],
}
headers = {
"x-api-key": ANTHROPIC_API_KEY,
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
"content-type": "application/json",
}
started = time.monotonic()
async with httpx.AsyncClient(timeout=SAJU_TIMEOUT_SEC) as client:
r = await client.post(API_URL, headers=headers, json=payload)
r.raise_for_status()
resp = r.json()
latency_ms = int((time.monotonic() - started) * 1000)
raw_text = "".join(
b.get("text", "") for b in resp.get("content", []) if b.get("type") == "text"
)
usage = resp.get("usage", {}) or {}
tokens_in = int(usage.get("input_tokens", 0) or 0)
tokens_out = int(usage.get("output_tokens", 0) or 0)
logger.info("saju claude call: latency=%dms, in=%d, out=%d", latency_ms, tokens_in, tokens_out)
parsed = _extract_json(raw_text)
meta = {"tokens_in": tokens_in, "tokens_out": tokens_out, "latency_ms": latency_ms}
return parsed, meta, raw_text
async def _interpret_loop(system_prompt: str, user_text: str, validate_fn) -> Dict[str, Any]:
total_in, total_out, total_latency = 0, 0, 0
last_error = ""
for attempt in range(2):
try:
parsed, meta, _raw = await _call_claude(system_prompt, user_text, feedback=last_error)
except httpx.HTTPError as e:
raise SajuError(f"Claude HTTP error: {e}") from e
except json.JSONDecodeError as e:
last_error = f"JSON 파싱 실패: {e}"
continue
total_in += meta["tokens_in"]
total_out += meta["tokens_out"]
total_latency += meta["latency_ms"]
ok, err = validate_fn(parsed)
if ok:
return {
"interpretation_json": parsed,
"model": SAJU_MODEL,
"tokens_in": total_in,
"tokens_out": total_out,
"cost_usd": calc_cost(total_in, total_out),
"latency_ms": total_latency,
"reroll_count": attempt,
}
last_error = err
raise SajuError(f"검증 실패 (reroll 2회): {last_error}")
async def interpret_saju(saju: dict, analysis: dict, daeun: list, current_year: int) -> Dict[str, Any]:
user_text = build_saju_user_message(saju, analysis, daeun, current_year)
return await _interpret_loop(SAJU_SYSTEM_PROMPT, user_text, validate_saju_interpretation)
async def interpret_compat(
saju_a: dict, saju_b: dict,
analysis_a: dict, analysis_b: dict,
score: int, breakdown: dict,
) -> Dict[str, Any]:
user_text = build_compat_user_message(
saju_a, saju_b, analysis_a, analysis_b, score, breakdown,
)
return await _interpret_loop(COMPAT_SYSTEM_PROMPT, user_text, validate_compat_interpretation)