From 6d752acbe10cd13a154f3d3ce3a107307e9eba7a Mon Sep 17 00:00:00 2001 From: gahusb Date: Mon, 25 May 2026 20:23:38 +0900 Subject: [PATCH] =?UTF-8?q?feat(saju-lab):=20interpret/pipeline.py=20?= =?UTF-8?q?=E2=80=94=20Claude=20=ED=98=B8=EC=B6=9C=20+=20reroll=201?= =?UTF-8?q?=ED=9A=8C=20(8=20tests)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- saju-lab/app/interpret/pipeline.py | 145 +++++++++++++++++++++++++++ saju-lab/tests/test_pipeline.py | 154 +++++++++++++++++++++++++++++ 2 files changed, 299 insertions(+) create mode 100644 saju-lab/app/interpret/pipeline.py create mode 100644 saju-lab/tests/test_pipeline.py diff --git a/saju-lab/app/interpret/pipeline.py b/saju-lab/app/interpret/pipeline.py new file mode 100644 index 0000000..0211366 --- /dev/null +++ b/saju-lab/app/interpret/pipeline.py @@ -0,0 +1,145 @@ +"""사주 + 궁합 해석 파이프라인 — Claude Sonnet 호출 + reroll 1회. + +tarot-lab/app/pipeline.py 패턴 재활용. +""" +import json +import logging +import time +from typing import Any, Dict + +import httpx + +from ..config import ( + ANTHROPIC_API_KEY, + SAJU_MODEL, + SAJU_COST_INPUT_PER_M, + SAJU_COST_OUTPUT_PER_M, + SAJU_TIMEOUT_SEC, +) + + +logger = logging.getLogger("saju-lab.pipeline") + +from .prompt import SAJU_SYSTEM_PROMPT, COMPAT_SYSTEM_PROMPT, build_saju_user_message, build_compat_user_message +from .schema import validate_saju_interpretation, validate_compat_interpretation + + +API_URL = "https://api.anthropic.com/v1/messages" + + +class SajuError(Exception): + pass + + +def calc_cost(tokens_in: int, tokens_out: int) -> float: + return ( + tokens_in / 1_000_000 * SAJU_COST_INPUT_PER_M + + tokens_out / 1_000_000 * SAJU_COST_OUTPUT_PER_M + ) + + +def _strip_codeblock(text: str) -> str: + t = text.strip() + if t.startswith("```"): + t = t.strip("`") + if t.startswith("json"): + t = t[4:] + t = t.strip() + return t + + +def _extract_json(raw: str) -> dict: + cleaned = _strip_codeblock(raw) + try: + return json.loads(cleaned) + except json.JSONDecodeError: + start, end = cleaned.find("{"), cleaned.rfind("}") + if start >= 0 and end > start: + try: + return json.loads(cleaned[start : end + 1]) + except json.JSONDecodeError: + pass + raise + + +async def _call_claude(system_prompt: str, user_text: str, feedback: str = "") -> tuple[dict, dict, str]: + if not ANTHROPIC_API_KEY: + raise SajuError("ANTHROPIC_API_KEY missing") + if feedback: + user_text = f"이전 응답이 다음 이유로 거절됨: {feedback}\n올바른 스키마(시스템 지침)로 다시 응답.\n\n{user_text}" + payload = { + "model": SAJU_MODEL, + "max_tokens": 2400, + "system": [{"type": "text", "text": system_prompt, + "cache_control": {"type": "ephemeral"}}], + "messages": [{"role": "user", "content": [{"type": "text", "text": user_text}]}], + } + headers = { + "x-api-key": ANTHROPIC_API_KEY, + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + "content-type": "application/json", + } + started = time.monotonic() + async with httpx.AsyncClient(timeout=SAJU_TIMEOUT_SEC) as client: + r = await client.post(API_URL, headers=headers, json=payload) + r.raise_for_status() + resp = r.json() + latency_ms = int((time.monotonic() - started) * 1000) + raw_text = "".join( + b.get("text", "") for b in resp.get("content", []) if b.get("type") == "text" + ) + usage = resp.get("usage", {}) or {} + tokens_in = int(usage.get("input_tokens", 0) or 0) + tokens_out = int(usage.get("output_tokens", 0) or 0) + logger.info("saju claude call: latency=%dms, in=%d, out=%d", latency_ms, tokens_in, tokens_out) + parsed = _extract_json(raw_text) + meta = {"tokens_in": tokens_in, "tokens_out": tokens_out, "latency_ms": latency_ms} + return parsed, meta, raw_text + + +async def _interpret_loop(system_prompt: str, user_text: str, validate_fn) -> Dict[str, Any]: + total_in, total_out, total_latency = 0, 0, 0 + last_error = "" + for attempt in range(2): + try: + parsed, meta, _raw = await _call_claude(system_prompt, user_text, feedback=last_error) + except httpx.HTTPError as e: + raise SajuError(f"Claude HTTP error: {e}") from e + except json.JSONDecodeError as e: + last_error = f"JSON 파싱 실패: {e}" + continue + total_in += meta["tokens_in"] + total_out += meta["tokens_out"] + total_latency += meta["latency_ms"] + + ok, err = validate_fn(parsed) + if ok: + return { + "interpretation_json": parsed, + "model": SAJU_MODEL, + "tokens_in": total_in, + "tokens_out": total_out, + "cost_usd": calc_cost(total_in, total_out), + "latency_ms": total_latency, + "reroll_count": attempt, + } + last_error = err + + raise SajuError(f"검증 실패 (reroll 2회): {last_error}") + + +async def interpret_saju(saju: dict, analysis: dict, daeun: list, current_year: int) -> Dict[str, Any]: + user_text = build_saju_user_message(saju, analysis, daeun, current_year) + return await _interpret_loop(SAJU_SYSTEM_PROMPT, user_text, validate_saju_interpretation) + + +async def interpret_compat( + saju_a: dict, saju_b: dict, + analysis_a: dict, analysis_b: dict, + score: int, breakdown: dict, +) -> Dict[str, Any]: + user_text = build_compat_user_message( + saju_a, saju_b, analysis_a, analysis_b, score, breakdown, + ) + return await _interpret_loop(COMPAT_SYSTEM_PROMPT, user_text, validate_compat_interpretation) diff --git a/saju-lab/tests/test_pipeline.py b/saju-lab/tests/test_pipeline.py new file mode 100644 index 0000000..e45d573 --- /dev/null +++ b/saju-lab/tests/test_pipeline.py @@ -0,0 +1,154 @@ +import json +import pytest +import respx +import httpx + +from app.interpret import pipeline +from app.interpret.pipeline import SajuError + + +@pytest.fixture(autouse=True) +def _patch_key(monkeypatch): + monkeypatch.setattr(pipeline, "ANTHROPIC_API_KEY", "test-key") + + +SAJU_ITEM_KEYS = [ + "기질", "오행밸런스", "지지상호작용", "신살영향", + "재물운", "직업적성", "애정운", "건강운", + "현재대운", "올해세운", "인생황금기", "종합조언", +] + + +def _valid_saju_response(): + items = [] + for k in SAJU_ITEM_KEYS: + items.append({ + "key": k, "title": "...", "content": "...", + "evidence": {"saju_element": "...", "reasoning": "..."} + }) + return { + "items": items, + "summary": "...", + "advice": "...", + "warning": None, + "confidence": "medium", + } + + +def _valid_compat_response(): + return { + "summary": "...", + "strengths": [ + {"title": "오행 상생", "explanation": "...", "evidence": "..."}, + {"title": "...", "explanation": "...", "evidence": "..."}, + ], + "challenges": [ + {"title": "...", "explanation": "...", "evidence": "..."}, + {"title": "...", "explanation": "...", "evidence": "..."}, + ], + "advice": "...", + "warning": None, + "confidence": "high", + } + + +def _claude_envelope(text: str, in_tok=200, out_tok=400): + return { + "content": [{"type": "text", "text": text}], + "usage": {"input_tokens": in_tok, "output_tokens": out_tok}, + } + + +@respx.mock +async def test_saju_interpret_success(): + respx.post("https://api.anthropic.com/v1/messages").mock( + return_value=httpx.Response(200, json=_claude_envelope(json.dumps(_valid_saju_response()))) + ) + result = await pipeline.interpret_saju( + saju={"day_stem": "辛"}, + analysis={"element_balance": {"金": 3.0}}, + daeun=[{"age": 10}], + current_year=2026, + ) + assert result["reroll_count"] == 0 + assert result["tokens_in"] == 200 + assert result["cost_usd"] > 0 + + +@respx.mock +async def test_saju_codeblock_stripped(): + text = "```json\n" + json.dumps(_valid_saju_response()) + "\n```" + respx.post("https://api.anthropic.com/v1/messages").mock( + return_value=httpx.Response(200, json=_claude_envelope(text)) + ) + result = await pipeline.interpret_saju(saju={}, analysis={}, daeun=[], current_year=2026) + assert "interpretation_json" in result + + +@respx.mock +async def test_saju_reroll_then_success(): + valid = json.dumps(_valid_saju_response()) + invalid = json.dumps({"items": [], "summary": "...", "advice": "", "confidence": "medium"}) + respx.post("https://api.anthropic.com/v1/messages").mock( + side_effect=[ + httpx.Response(200, json=_claude_envelope(invalid)), + httpx.Response(200, json=_claude_envelope(valid)), + ] + ) + result = await pipeline.interpret_saju(saju={}, analysis={}, daeun=[], current_year=2026) + assert result["reroll_count"] == 1 + + +@respx.mock +async def test_saju_reroll_fail_raises(): + invalid = json.dumps({"items": [], "summary": "...", "advice": "", "confidence": "medium"}) + respx.post("https://api.anthropic.com/v1/messages").mock( + return_value=httpx.Response(200, json=_claude_envelope(invalid)) + ) + with pytest.raises(SajuError): + await pipeline.interpret_saju(saju={}, analysis={}, daeun=[], current_year=2026) + + +@respx.mock +async def test_saju_http_error(): + respx.post("https://api.anthropic.com/v1/messages").mock( + return_value=httpx.Response(500, text="boom") + ) + with pytest.raises(SajuError): + await pipeline.interpret_saju(saju={}, analysis={}, daeun=[], current_year=2026) + + +@respx.mock +async def test_compat_interpret_success(): + respx.post("https://api.anthropic.com/v1/messages").mock( + return_value=httpx.Response(200, json=_claude_envelope(json.dumps(_valid_compat_response()))) + ) + result = await pipeline.interpret_compat( + saju_a={"day_stem": "辛"}, saju_b={"day_stem": "丁"}, + analysis_a={}, analysis_b={}, + score=85, breakdown={"day_master_element": {"score": 25}}, + ) + assert result["reroll_count"] == 0 + assert "interpretation_json" in result + + +@respx.mock +async def test_compat_reroll_then_success(): + valid = json.dumps(_valid_compat_response()) + invalid = json.dumps({"summary": "...", "strengths": [], "challenges": [], "advice": "", "confidence": "high"}) + respx.post("https://api.anthropic.com/v1/messages").mock( + side_effect=[ + httpx.Response(200, json=_claude_envelope(invalid)), + httpx.Response(200, json=_claude_envelope(valid)), + ] + ) + result = await pipeline.interpret_compat( + saju_a={}, saju_b={}, analysis_a={}, analysis_b={}, + score=50, breakdown={}, + ) + assert result["reroll_count"] == 1 + + +def test_calc_cost(): + cost = pipeline.calc_cost(1_000_000, 1_000_000) + assert cost == pipeline.SAJU_COST_INPUT_PER_M + pipeline.SAJU_COST_OUTPUT_PER_M