"""Tests for the batch music-generation endpoints under /api/music/generate-batch."""

from unittest.mock import AsyncMock, patch

import pytest
from fastapi.testclient import TestClient

import app.main as main_module
from app import db


@pytest.fixture
def client(monkeypatch, tmp_path):
    """Return a TestClient for the app, wired to a fresh database in tmp_path.

    A dummy Suno API key ("test") is installed so requests get past the key check.
    """
    monkeypatch.setattr(db, "DB_PATH", str(tmp_path / "music.db"))
    db.init_db()
    monkeypatch.setenv("SUNO_API_KEY", "test")
    # main.py may also hold SUNO_API_KEY in a module-level variable, so update
    # that too; raising=False tolerates the variable not existing.
    monkeypatch.setattr(main_module, "SUNO_API_KEY", "test", raising=False)
    return TestClient(main_module.app)


@pytest.fixture
def stub_run_batch():
    """Replace ``main_module._run_batch`` with an AsyncMock for the test's duration.

    Applied to every test that POSTs a batch, so that if request validation
    regresses, a rejection test cannot invoke the real ``_run_batch``
    (presumably the Suno generation job -- confirm in app.main).
    Yields the mock.
    """
    with patch.object(main_module, "_run_batch", new=AsyncMock()) as runner:
        yield runner


@pytest.mark.usefixtures("stub_run_batch")
def test_create_batch_201(client):
    r = client.post("/api/music/generate-batch", json={"genre": "lo-fi", "count": 3})
    assert r.status_code == 201, r.text
    body = r.json()
    assert body["genre"] == "lo-fi"
    assert body["count"] == 3
    assert body["status"] == "queued"


@pytest.mark.usefixtures("stub_run_batch")
def test_create_batch_rejects_count_too_high(client):
    r = client.post("/api/music/generate-batch", json={"genre": "lo-fi", "count": 11})
    assert r.status_code == 400, r.text


@pytest.mark.usefixtures("stub_run_batch")
def test_create_batch_rejects_count_zero(client):
    r = client.post("/api/music/generate-batch", json={"genre": "lo-fi", "count": 0})
    assert r.status_code == 400, r.text


@pytest.mark.usefixtures("stub_run_batch")
def test_create_batch_rejects_no_genre(client):
    r = client.post("/api/music/generate-batch", json={"count": 3})
    # Either the endpoint's own check (400) or request-model validation (422).
    assert r.status_code in (400, 422), r.text


@pytest.mark.usefixtures("stub_run_batch")
def test_create_batch_rejects_invalid_duration(client):
    r = client.post(
        "/api/music/generate-batch",
        json={"genre": "lo-fi", "count": 3, "target_duration_sec": 30},
    )
    assert r.status_code == 400, r.text


@pytest.mark.usefixtures("stub_run_batch")
def test_create_batch_rejects_no_suno_key(client, monkeypatch):
    # Clear the key in both places the app might read it from -- the module
    # variable and the environment -- so the test does not depend on which
    # one main.py consults.
    monkeypatch.setattr(main_module, "SUNO_API_KEY", "", raising=False)
    monkeypatch.setenv("SUNO_API_KEY", "")
    r = client.post("/api/music/generate-batch", json={"genre": "lo-fi", "count": 3})
    assert r.status_code == 400, r.text


def test_get_batch_returns_tracks(client):
    bid = db.create_batch_job(genre="lo-fi", count=2)
    db.append_batch_track(bid, 999)  # phantom track id
    r = client.get(f"/api/music/generate-batch/{bid}")
    assert r.status_code == 200, r.text
    body = r.json()
    assert body["track_ids"] == [999]
    assert body["tracks"] == []  # 999 is not in music_library


def test_get_batch_404(client):
    r = client.get("/api/music/generate-batch/99999")
    assert r.status_code == 404, r.text


def test_list_batches(client):
    db.create_batch_job(genre="lo-fi", count=1)
    db.create_batch_job(genre="phonk", count=2)
    r = client.get("/api/music/generate-batch")
    assert r.status_code == 200, r.text
    assert len(r.json()["batches"]) == 2


def test_list_batches_active_filter(client):
    b1 = db.create_batch_job(genre="lo-fi", count=1)
    b2 = db.create_batch_job(genre="phonk", count=2)
    db.update_batch_job(b1, status="failed")
    r = client.get("/api/music/generate-batch?status=active")
    # Check the status first so an error response fails with a clear message
    # instead of a KeyError on "batches".
    assert r.status_code == 200, r.text
    ids = [j["id"] for j in r.json()["batches"]]
    assert b2 in ids
    assert b1 not in ids