feat(saju-lab): interpret/pipeline.py — Claude 호출 + reroll 1회 (8 tests)
This commit is contained in:
145
saju-lab/app/interpret/pipeline.py
Normal file
145
saju-lab/app/interpret/pipeline.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""사주 + 궁합 해석 파이프라인 — Claude Sonnet 호출 + reroll 1회.
|
||||||
|
|
||||||
|
tarot-lab/app/pipeline.py 패턴 재활용.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from ..config import (
|
||||||
|
ANTHROPIC_API_KEY,
|
||||||
|
SAJU_MODEL,
|
||||||
|
SAJU_COST_INPUT_PER_M,
|
||||||
|
SAJU_COST_OUTPUT_PER_M,
|
||||||
|
SAJU_TIMEOUT_SEC,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("saju-lab.pipeline")
|
||||||
|
|
||||||
|
from .prompt import SAJU_SYSTEM_PROMPT, COMPAT_SYSTEM_PROMPT, build_saju_user_message, build_compat_user_message
|
||||||
|
from .schema import validate_saju_interpretation, validate_compat_interpretation
|
||||||
|
|
||||||
|
|
||||||
|
API_URL = "https://api.anthropic.com/v1/messages"
|
||||||
|
|
||||||
|
|
||||||
|
class SajuError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def calc_cost(tokens_in: int, tokens_out: int) -> float:
|
||||||
|
return (
|
||||||
|
tokens_in / 1_000_000 * SAJU_COST_INPUT_PER_M
|
||||||
|
+ tokens_out / 1_000_000 * SAJU_COST_OUTPUT_PER_M
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_codeblock(text: str) -> str:
|
||||||
|
t = text.strip()
|
||||||
|
if t.startswith("```"):
|
||||||
|
t = t.strip("`")
|
||||||
|
if t.startswith("json"):
|
||||||
|
t = t[4:]
|
||||||
|
t = t.strip()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json(raw: str) -> dict:
|
||||||
|
cleaned = _strip_codeblock(raw)
|
||||||
|
try:
|
||||||
|
return json.loads(cleaned)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
start, end = cleaned.find("{"), cleaned.rfind("}")
|
||||||
|
if start >= 0 and end > start:
|
||||||
|
try:
|
||||||
|
return json.loads(cleaned[start : end + 1])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_claude(system_prompt: str, user_text: str, feedback: str = "") -> tuple[dict, dict, str]:
|
||||||
|
if not ANTHROPIC_API_KEY:
|
||||||
|
raise SajuError("ANTHROPIC_API_KEY missing")
|
||||||
|
if feedback:
|
||||||
|
user_text = f"이전 응답이 다음 이유로 거절됨: {feedback}\n올바른 스키마(시스템 지침)로 다시 응답.\n\n{user_text}"
|
||||||
|
payload = {
|
||||||
|
"model": SAJU_MODEL,
|
||||||
|
"max_tokens": 2400,
|
||||||
|
"system": [{"type": "text", "text": system_prompt,
|
||||||
|
"cache_control": {"type": "ephemeral"}}],
|
||||||
|
"messages": [{"role": "user", "content": [{"type": "text", "text": user_text}]}],
|
||||||
|
}
|
||||||
|
headers = {
|
||||||
|
"x-api-key": ANTHROPIC_API_KEY,
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||||
|
"content-type": "application/json",
|
||||||
|
}
|
||||||
|
started = time.monotonic()
|
||||||
|
async with httpx.AsyncClient(timeout=SAJU_TIMEOUT_SEC) as client:
|
||||||
|
r = await client.post(API_URL, headers=headers, json=payload)
|
||||||
|
r.raise_for_status()
|
||||||
|
resp = r.json()
|
||||||
|
latency_ms = int((time.monotonic() - started) * 1000)
|
||||||
|
raw_text = "".join(
|
||||||
|
b.get("text", "") for b in resp.get("content", []) if b.get("type") == "text"
|
||||||
|
)
|
||||||
|
usage = resp.get("usage", {}) or {}
|
||||||
|
tokens_in = int(usage.get("input_tokens", 0) or 0)
|
||||||
|
tokens_out = int(usage.get("output_tokens", 0) or 0)
|
||||||
|
logger.info("saju claude call: latency=%dms, in=%d, out=%d", latency_ms, tokens_in, tokens_out)
|
||||||
|
parsed = _extract_json(raw_text)
|
||||||
|
meta = {"tokens_in": tokens_in, "tokens_out": tokens_out, "latency_ms": latency_ms}
|
||||||
|
return parsed, meta, raw_text
|
||||||
|
|
||||||
|
|
||||||
|
async def _interpret_loop(system_prompt: str, user_text: str, validate_fn) -> Dict[str, Any]:
|
||||||
|
total_in, total_out, total_latency = 0, 0, 0
|
||||||
|
last_error = ""
|
||||||
|
for attempt in range(2):
|
||||||
|
try:
|
||||||
|
parsed, meta, _raw = await _call_claude(system_prompt, user_text, feedback=last_error)
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
raise SajuError(f"Claude HTTP error: {e}") from e
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
last_error = f"JSON 파싱 실패: {e}"
|
||||||
|
continue
|
||||||
|
total_in += meta["tokens_in"]
|
||||||
|
total_out += meta["tokens_out"]
|
||||||
|
total_latency += meta["latency_ms"]
|
||||||
|
|
||||||
|
ok, err = validate_fn(parsed)
|
||||||
|
if ok:
|
||||||
|
return {
|
||||||
|
"interpretation_json": parsed,
|
||||||
|
"model": SAJU_MODEL,
|
||||||
|
"tokens_in": total_in,
|
||||||
|
"tokens_out": total_out,
|
||||||
|
"cost_usd": calc_cost(total_in, total_out),
|
||||||
|
"latency_ms": total_latency,
|
||||||
|
"reroll_count": attempt,
|
||||||
|
}
|
||||||
|
last_error = err
|
||||||
|
|
||||||
|
raise SajuError(f"검증 실패 (reroll 2회): {last_error}")
|
||||||
|
|
||||||
|
|
||||||
|
async def interpret_saju(saju: dict, analysis: dict, daeun: list, current_year: int) -> Dict[str, Any]:
|
||||||
|
user_text = build_saju_user_message(saju, analysis, daeun, current_year)
|
||||||
|
return await _interpret_loop(SAJU_SYSTEM_PROMPT, user_text, validate_saju_interpretation)
|
||||||
|
|
||||||
|
|
||||||
|
async def interpret_compat(
|
||||||
|
saju_a: dict, saju_b: dict,
|
||||||
|
analysis_a: dict, analysis_b: dict,
|
||||||
|
score: int, breakdown: dict,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
user_text = build_compat_user_message(
|
||||||
|
saju_a, saju_b, analysis_a, analysis_b, score, breakdown,
|
||||||
|
)
|
||||||
|
return await _interpret_loop(COMPAT_SYSTEM_PROMPT, user_text, validate_compat_interpretation)
|
||||||
154
saju-lab/tests/test_pipeline.py
Normal file
154
saju-lab/tests/test_pipeline.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.interpret import pipeline
|
||||||
|
from app.interpret.pipeline import SajuError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _patch_key(monkeypatch):
|
||||||
|
monkeypatch.setattr(pipeline, "ANTHROPIC_API_KEY", "test-key")
|
||||||
|
|
||||||
|
|
||||||
|
SAJU_ITEM_KEYS = [
|
||||||
|
"기질", "오행밸런스", "지지상호작용", "신살영향",
|
||||||
|
"재물운", "직업적성", "애정운", "건강운",
|
||||||
|
"현재대운", "올해세운", "인생황금기", "종합조언",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_saju_response():
|
||||||
|
items = []
|
||||||
|
for k in SAJU_ITEM_KEYS:
|
||||||
|
items.append({
|
||||||
|
"key": k, "title": "...", "content": "...",
|
||||||
|
"evidence": {"saju_element": "...", "reasoning": "..."}
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
"items": items,
|
||||||
|
"summary": "...",
|
||||||
|
"advice": "...",
|
||||||
|
"warning": None,
|
||||||
|
"confidence": "medium",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_compat_response():
|
||||||
|
return {
|
||||||
|
"summary": "...",
|
||||||
|
"strengths": [
|
||||||
|
{"title": "오행 상생", "explanation": "...", "evidence": "..."},
|
||||||
|
{"title": "...", "explanation": "...", "evidence": "..."},
|
||||||
|
],
|
||||||
|
"challenges": [
|
||||||
|
{"title": "...", "explanation": "...", "evidence": "..."},
|
||||||
|
{"title": "...", "explanation": "...", "evidence": "..."},
|
||||||
|
],
|
||||||
|
"advice": "...",
|
||||||
|
"warning": None,
|
||||||
|
"confidence": "high",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _claude_envelope(text: str, in_tok=200, out_tok=400):
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": text}],
|
||||||
|
"usage": {"input_tokens": in_tok, "output_tokens": out_tok},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_saju_interpret_success():
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
return_value=httpx.Response(200, json=_claude_envelope(json.dumps(_valid_saju_response())))
|
||||||
|
)
|
||||||
|
result = await pipeline.interpret_saju(
|
||||||
|
saju={"day_stem": "辛"},
|
||||||
|
analysis={"element_balance": {"金": 3.0}},
|
||||||
|
daeun=[{"age": 10}],
|
||||||
|
current_year=2026,
|
||||||
|
)
|
||||||
|
assert result["reroll_count"] == 0
|
||||||
|
assert result["tokens_in"] == 200
|
||||||
|
assert result["cost_usd"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_saju_codeblock_stripped():
|
||||||
|
text = "```json\n" + json.dumps(_valid_saju_response()) + "\n```"
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
return_value=httpx.Response(200, json=_claude_envelope(text))
|
||||||
|
)
|
||||||
|
result = await pipeline.interpret_saju(saju={}, analysis={}, daeun=[], current_year=2026)
|
||||||
|
assert "interpretation_json" in result
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_saju_reroll_then_success():
|
||||||
|
valid = json.dumps(_valid_saju_response())
|
||||||
|
invalid = json.dumps({"items": [], "summary": "...", "advice": "", "confidence": "medium"})
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
side_effect=[
|
||||||
|
httpx.Response(200, json=_claude_envelope(invalid)),
|
||||||
|
httpx.Response(200, json=_claude_envelope(valid)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = await pipeline.interpret_saju(saju={}, analysis={}, daeun=[], current_year=2026)
|
||||||
|
assert result["reroll_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_saju_reroll_fail_raises():
|
||||||
|
invalid = json.dumps({"items": [], "summary": "...", "advice": "", "confidence": "medium"})
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
return_value=httpx.Response(200, json=_claude_envelope(invalid))
|
||||||
|
)
|
||||||
|
with pytest.raises(SajuError):
|
||||||
|
await pipeline.interpret_saju(saju={}, analysis={}, daeun=[], current_year=2026)
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_saju_http_error():
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
return_value=httpx.Response(500, text="boom")
|
||||||
|
)
|
||||||
|
with pytest.raises(SajuError):
|
||||||
|
await pipeline.interpret_saju(saju={}, analysis={}, daeun=[], current_year=2026)
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_compat_interpret_success():
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
return_value=httpx.Response(200, json=_claude_envelope(json.dumps(_valid_compat_response())))
|
||||||
|
)
|
||||||
|
result = await pipeline.interpret_compat(
|
||||||
|
saju_a={"day_stem": "辛"}, saju_b={"day_stem": "丁"},
|
||||||
|
analysis_a={}, analysis_b={},
|
||||||
|
score=85, breakdown={"day_master_element": {"score": 25}},
|
||||||
|
)
|
||||||
|
assert result["reroll_count"] == 0
|
||||||
|
assert "interpretation_json" in result
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_compat_reroll_then_success():
|
||||||
|
valid = json.dumps(_valid_compat_response())
|
||||||
|
invalid = json.dumps({"summary": "...", "strengths": [], "challenges": [], "advice": "", "confidence": "high"})
|
||||||
|
respx.post("https://api.anthropic.com/v1/messages").mock(
|
||||||
|
side_effect=[
|
||||||
|
httpx.Response(200, json=_claude_envelope(invalid)),
|
||||||
|
httpx.Response(200, json=_claude_envelope(valid)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = await pipeline.interpret_compat(
|
||||||
|
saju_a={}, saju_b={}, analysis_a={}, analysis_b={},
|
||||||
|
score=50, breakdown={},
|
||||||
|
)
|
||||||
|
assert result["reroll_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_cost():
|
||||||
|
cost = pipeline.calc_cost(1_000_000, 1_000_000)
|
||||||
|
assert cost == pipeline.SAJU_COST_INPUT_PER_M + pipeline.SAJU_COST_OUTPUT_PER_M
|
||||||
Reference in New Issue
Block a user