Initial public release
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
"""Tests for validate_audiodb_image_url SSRF protection."""
|
||||
import pytest
|
||||
from infrastructure.validators import validate_audiodb_image_url
|
||||
|
||||
|
||||
class TestValidateAudiodbImageUrl:
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"https://www.theaudiodb.com/images/media/thumb.jpg",
|
||||
"https://theaudiodb.com/images/media/artist/thumb/coldplay.jpg",
|
||||
"https://r2.theaudiodb.com/images/album/thumb/parachutes.jpg",
|
||||
"https://r2.theaudiodb.com/images/artist/fanart/coldplay1.jpg",
|
||||
])
|
||||
def test_valid_audiodb_urls(self, url: str) -> None:
|
||||
assert validate_audiodb_image_url(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://www.theaudiodb.com/images/media/thumb.jpg",
|
||||
"http://r2.theaudiodb.com/images/album/thumb.jpg",
|
||||
])
|
||||
def test_rejects_http_scheme(self, url: str) -> None:
|
||||
assert validate_audiodb_image_url(url) is False
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"ftp://r2.theaudiodb.com/file.jpg",
|
||||
"file:///etc/passwd",
|
||||
"data:text/html,<script>alert(1)</script>",
|
||||
"javascript:alert(1)",
|
||||
])
|
||||
def test_rejects_non_https_schemes(self, url: str) -> None:
|
||||
assert validate_audiodb_image_url(url) is False
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"https://evil.com/images/media/thumb.jpg",
|
||||
"https://theaudiodb.com.evil.com/exploit.jpg",
|
||||
"https://attacker.theaudiodb.com/images/thumb.jpg",
|
||||
"https://notaudiodb.com/images/thumb.jpg",
|
||||
"https://example.com/redirect?url=https://r2.theaudiodb.com/img.jpg",
|
||||
])
|
||||
def test_rejects_unknown_hosts(self, url: str) -> None:
|
||||
assert validate_audiodb_image_url(url) is False
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"https://127.0.0.1/images/thumb.jpg",
|
||||
"https://10.0.0.1/images/thumb.jpg",
|
||||
"https://192.168.1.1/images/thumb.jpg",
|
||||
"https://[::1]/images/thumb.jpg",
|
||||
"https://169.254.169.254/latest/meta-data/",
|
||||
])
|
||||
def test_rejects_private_and_loopback_ips(self, url: str) -> None:
|
||||
assert validate_audiodb_image_url(url) is False
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"",
|
||||
None,
|
||||
" ",
|
||||
"not-a-url",
|
||||
"://missing-scheme.com",
|
||||
])
|
||||
def test_rejects_invalid_inputs(self, url) -> None:
|
||||
assert validate_audiodb_image_url(url) is False
|
||||
|
||||
def test_rejects_url_without_host(self) -> None:
|
||||
assert validate_audiodb_image_url("https:///path/only") is False
|
||||
@@ -0,0 +1,303 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.cache.disk_cache import DiskMetadataCache
|
||||
from infrastructure.persistence.genre_index import GenreIndex
|
||||
from infrastructure.persistence.library_db import LibraryDB
|
||||
from infrastructure.persistence.youtube_store import YouTubeStore
|
||||
|
||||
|
||||
def _make_stores(db_path):
|
||||
lock = threading.Lock()
|
||||
lib = LibraryDB(db_path=db_path, write_lock=lock)
|
||||
genre = GenreIndex(db_path=db_path, write_lock=lock)
|
||||
yt = YouTubeStore(db_path=db_path, write_lock=lock)
|
||||
# All stores must be initialized so cross-domain DELETEs in save_library/clear succeed
|
||||
from infrastructure.persistence.mbid_store import MBIDStore
|
||||
from infrastructure.persistence.sync_state_store import SyncStateStore
|
||||
|
||||
SyncStateStore(db_path=db_path, write_lock=lock)
|
||||
MBIDStore(db_path=db_path, write_lock=lock)
|
||||
return lib, genre, yt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_cache_genre_queries_use_normalized_lookup(tmp_path):
|
||||
lib, genre, _ = _make_stores(tmp_path / "library.db")
|
||||
|
||||
await lib.save_library(
|
||||
artists=[
|
||||
{"mbid": "artist-1", "name": "Artist One", "album_count": 1, "date_added": 10},
|
||||
{"mbid": "artist-2", "name": "Artist Two", "album_count": 1, "date_added": 20},
|
||||
],
|
||||
albums=[
|
||||
{
|
||||
"mbid": "album-1",
|
||||
"artist_mbid": "artist-1",
|
||||
"artist_name": "Artist One",
|
||||
"title": "First Album",
|
||||
"date_added": 100,
|
||||
"monitored": True,
|
||||
},
|
||||
{
|
||||
"mbid": "album-2",
|
||||
"artist_mbid": "artist-2",
|
||||
"artist_name": "Artist Two",
|
||||
"title": "Second Album",
|
||||
"date_added": 200,
|
||||
"monitored": True,
|
||||
},
|
||||
],
|
||||
)
|
||||
await genre.save_artist_genres(
|
||||
{
|
||||
"artist-1": [" Rock ", "Alternative", "rock"],
|
||||
"artist-2": ["Jazz", "rock"],
|
||||
}
|
||||
)
|
||||
|
||||
artists = await genre.get_artists_by_genre("ROCK", limit=1)
|
||||
albums = await genre.get_albums_by_genre(" rock ", limit=2)
|
||||
|
||||
assert [artist["mbid"] for artist in artists] == ["artist-2"]
|
||||
assert [album["mbid"] for album in albums] == ["album-2", "album-1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_cache_backfills_genre_lookup_from_existing_json_rows(tmp_path):
|
||||
db_path = tmp_path / "library.db"
|
||||
lib, genre, _ = _make_stores(db_path)
|
||||
await lib.save_library(
|
||||
artists=[{"mbid": "artist-1", "name": "Artist One", "album_count": 1, "date_added": 10}],
|
||||
albums=[],
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
try:
|
||||
conn.execute("DELETE FROM artist_genre_lookup")
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO artist_genres (artist_mbid_lower, artist_mbid, genres_json) VALUES (?, ?, ?)",
|
||||
("artist-1", "artist-1", json.dumps(["post-rock"])),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
_, genre2, _ = _make_stores(db_path)
|
||||
artists = await genre2.get_artists_by_genre("POST-ROCK")
|
||||
|
||||
assert [artist["mbid"] for artist in artists] == ["artist-1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_covers_removes_expired_cover_payload(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
cover_dir = tmp_path / "recent" / "covers"
|
||||
cover_file = cover_dir / "cover.bin"
|
||||
meta_file = cover_dir / "cover.meta.json"
|
||||
|
||||
cover_file.write_bytes(b"image-bytes")
|
||||
meta_file.write_text(json.dumps({"expires_at": time.time() - 60, "last_accessed": 1}))
|
||||
|
||||
removed = await cache.cleanup_expired_covers()
|
||||
|
||||
assert removed == 1
|
||||
assert not cover_file.exists()
|
||||
assert not meta_file.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_cover_size_limits_evicts_oldest_recent_cover(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
cache.recent_covers_max_size_bytes = 6
|
||||
|
||||
cover_dir = tmp_path / "recent" / "covers"
|
||||
old_cover = cover_dir / "old.bin"
|
||||
new_cover = cover_dir / "new.bin"
|
||||
old_meta = cover_dir / "old.meta.json"
|
||||
new_meta = cover_dir / "new.meta.json"
|
||||
|
||||
old_cover.write_bytes(b"1234")
|
||||
new_cover.write_bytes(b"5678")
|
||||
old_meta.write_text(json.dumps({"last_accessed": 1}))
|
||||
new_meta.write_text(json.dumps({"last_accessed": 2}))
|
||||
|
||||
freed = await cache.enforce_cover_size_limits()
|
||||
|
||||
assert freed == 4
|
||||
assert not old_cover.exists()
|
||||
assert not old_meta.exists()
|
||||
assert new_cover.exists()
|
||||
assert new_meta.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_cache_keeps_youtube_track_links_distinct_per_disc(tmp_path):
|
||||
_, _, yt = _make_stores(tmp_path / "library.db")
|
||||
|
||||
await yt.save_youtube_track_links_batch(
|
||||
"album-1",
|
||||
[
|
||||
{
|
||||
"track_number": 1,
|
||||
"disc_number": 1,
|
||||
"album_name": "Album",
|
||||
"track_name": "Disc One Track One",
|
||||
"video_id": "video-1",
|
||||
"artist_name": "Artist",
|
||||
"embed_url": "https://example.com/1",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
"track_number": 1,
|
||||
"disc_number": 2,
|
||||
"album_name": "Album",
|
||||
"track_name": "Disc Two Track One",
|
||||
"video_id": "video-2",
|
||||
"artist_name": "Artist",
|
||||
"embed_url": "https://example.com/2",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
links = await yt.get_youtube_track_links("album-1")
|
||||
assert [(link["disc_number"], link["track_number"], link["video_id"]) for link in links] == [
|
||||
(1, 1, "video-1"),
|
||||
(2, 1, "video-2"),
|
||||
]
|
||||
|
||||
await yt.delete_youtube_track_link("album-1", 2, 1)
|
||||
|
||||
remaining = await yt.get_youtube_track_links("album-1")
|
||||
assert [(link["disc_number"], link["track_number"], link["video_id"]) for link in remaining] == [
|
||||
(1, 1, "video-1")
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_cache_migrates_legacy_youtube_track_links_with_default_disc_number(tmp_path):
|
||||
db_path = tmp_path / "library.db"
|
||||
conn = sqlite3.connect(db_path)
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE youtube_track_links (
|
||||
album_id TEXT NOT NULL,
|
||||
track_number INTEGER NOT NULL,
|
||||
album_name TEXT NOT NULL,
|
||||
track_name TEXT NOT NULL,
|
||||
video_id TEXT NOT NULL,
|
||||
artist_name TEXT NOT NULL,
|
||||
embed_url TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
PRIMARY KEY (album_id, track_number)
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO youtube_track_links (
|
||||
album_id, track_number, album_name, track_name,
|
||||
video_id, artist_name, embed_url, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"album-legacy",
|
||||
5,
|
||||
"Legacy Album",
|
||||
"Legacy Track",
|
||||
"legacy-video",
|
||||
"Artist",
|
||||
"https://example.com/legacy",
|
||||
"2024-01-01T00:00:00Z",
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
_, _, yt = _make_stores(db_path)
|
||||
links = await yt.get_youtube_track_links("album-legacy")
|
||||
|
||||
assert len(links) == 1
|
||||
assert links[0]["disc_number"] == 1
|
||||
assert links[0]["track_number"] == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_cache_youtube_track_links_uniqueness_allows_same_track_different_disc(
|
||||
tmp_path,
|
||||
):
|
||||
"""Verify the new (album_id, disc_number, track_number) PK allows same track number across discs."""
|
||||
_, _, yt = _make_stores(tmp_path / "library.db")
|
||||
|
||||
await yt.save_youtube_track_link(
|
||||
album_id="album-uniq",
|
||||
track_number=1,
|
||||
disc_number=1,
|
||||
album_name="Test Album",
|
||||
track_name="Track One Disc One",
|
||||
video_id="vid-d1t1",
|
||||
artist_name="Artist",
|
||||
embed_url="https://example.com/d1t1",
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
)
|
||||
await yt.save_youtube_track_link(
|
||||
album_id="album-uniq",
|
||||
track_number=1,
|
||||
disc_number=2,
|
||||
album_name="Test Album",
|
||||
track_name="Track One Disc Two",
|
||||
video_id="vid-d2t1",
|
||||
artist_name="Artist",
|
||||
embed_url="https://example.com/d2t1",
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
)
|
||||
await yt.save_youtube_track_link(
|
||||
album_id="album-uniq",
|
||||
track_number=1,
|
||||
disc_number=1,
|
||||
album_name="Test Album",
|
||||
track_name="Updated Track One",
|
||||
video_id="vid-d1t1-updated",
|
||||
artist_name="Artist",
|
||||
embed_url="https://example.com/d1t1-v2",
|
||||
created_at="2024-01-02T00:00:00Z",
|
||||
)
|
||||
|
||||
links = await yt.get_youtube_track_links("album-uniq")
|
||||
assert len(links) == 2
|
||||
d1 = next(link for link in links if link["disc_number"] == 1)
|
||||
d2 = next(link for link in links if link["disc_number"] == 2)
|
||||
assert d1["video_id"] == "vid-d1t1-updated"
|
||||
assert d1["track_name"] == "Updated Track One"
|
||||
assert d2["video_id"] == "vid-d2t1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_cache_save_single_youtube_track_link_with_disc_number(tmp_path):
|
||||
"""Verify save_youtube_track_link (single-row path) correctly stores disc_number."""
|
||||
_, _, yt = _make_stores(tmp_path / "library.db")
|
||||
|
||||
await yt.save_youtube_track_link(
|
||||
album_id="album-single",
|
||||
track_number=3,
|
||||
disc_number=2,
|
||||
album_name="Single Test",
|
||||
track_name="Track Three Disc Two",
|
||||
video_id="vid-single",
|
||||
artist_name="Artist",
|
||||
embed_url="https://example.com/single",
|
||||
created_at="2024-06-15T00:00:00Z",
|
||||
)
|
||||
|
||||
links = await yt.get_youtube_track_links("album-single")
|
||||
assert len(links) == 1
|
||||
assert links[0]["disc_number"] == 2
|
||||
assert links[0]["track_number"] == 3
|
||||
assert links[0]["video_id"] == "vid-single"
|
||||
@@ -0,0 +1,118 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.resilience.retry import CircuitBreaker, CircuitState
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_arecord_failure_does_not_overcount():
|
||||
cb = CircuitBreaker(failure_threshold=5, name="test-overcount")
|
||||
|
||||
async def fail_once():
|
||||
await cb.arecord_failure()
|
||||
|
||||
await asyncio.gather(*[fail_once() for _ in range(10)])
|
||||
|
||||
assert cb.failure_count <= 10
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_arecord_success_transitions_half_open_to_closed():
|
||||
cb = CircuitBreaker(failure_threshold=3, success_threshold=2, name="test-success-transition")
|
||||
|
||||
for _ in range(3):
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
cb.state = CircuitState.HALF_OPEN
|
||||
cb.success_count = 0
|
||||
|
||||
async def succeed_once():
|
||||
await cb.arecord_success()
|
||||
|
||||
await asyncio.gather(*[succeed_once() for _ in range(10)])
|
||||
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.failure_count == 0
|
||||
assert cb.success_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atry_transition_open_to_half_open():
|
||||
cb = CircuitBreaker(failure_threshold=3, timeout=0.0, name="test-transition")
|
||||
|
||||
for _ in range(3):
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
await cb.atry_transition()
|
||||
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
assert cb.success_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atry_transition_noop_when_not_open():
|
||||
cb = CircuitBreaker(name="test-noop")
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
await cb.atry_transition()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atry_transition_noop_when_timeout_not_elapsed():
|
||||
cb = CircuitBreaker(failure_threshold=3, timeout=60.0, name="test-timeout-not-elapsed")
|
||||
|
||||
for _ in range(3):
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
await cb.atry_transition()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_methods_still_work_without_await():
|
||||
cb = CircuitBreaker(failure_threshold=3, success_threshold=1, name="test-sync-compat")
|
||||
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.failure_count == 2
|
||||
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
cb.reset()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.failure_count == 0
|
||||
|
||||
cb.record_success()
|
||||
assert cb.failure_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_atry_transition_only_transitions_once():
|
||||
cb = CircuitBreaker(failure_threshold=3, timeout=0.0, name="test-double-transition")
|
||||
|
||||
for _ in range(3):
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
transition_states = []
|
||||
|
||||
async def try_transition():
|
||||
await cb.atry_transition()
|
||||
transition_states.append(cb.state)
|
||||
|
||||
await asyncio.gather(*[try_transition() for _ in range(5)])
|
||||
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
assert all(s == CircuitState.HALF_OPEN for s in transition_states)
|
||||
@@ -0,0 +1,126 @@
|
||||
"""Tests for DegradationContext and contextvar lifecycle."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.degradation import (
|
||||
DegradationContext,
|
||||
clear_degradation_context,
|
||||
get_degradation_context,
|
||||
init_degradation_context,
|
||||
try_get_degradation_context,
|
||||
)
|
||||
from infrastructure.integration_result import IntegrationResult
|
||||
|
||||
|
||||
|
||||
|
||||
class TestDegradationContext:
|
||||
def test_empty_context(self):
|
||||
ctx = DegradationContext()
|
||||
assert ctx.summary() == {}
|
||||
assert ctx.has_degradation() is False
|
||||
assert ctx.degraded_summary() == {}
|
||||
|
||||
def test_record_ok(self):
|
||||
ctx = DegradationContext()
|
||||
ctx.record(IntegrationResult.ok(data=[1], source="musicbrainz"))
|
||||
assert ctx.summary() == {"musicbrainz": "ok"}
|
||||
assert ctx.has_degradation() is False
|
||||
|
||||
def test_record_error(self):
|
||||
ctx = DegradationContext()
|
||||
ctx.record(IntegrationResult.error(source="jellyfin", msg="timeout"))
|
||||
assert ctx.summary() == {"jellyfin": "error"}
|
||||
assert ctx.has_degradation() is True
|
||||
assert ctx.degraded_summary() == {"jellyfin": "error"}
|
||||
|
||||
def test_record_degraded(self):
|
||||
ctx = DegradationContext()
|
||||
ctx.record(
|
||||
IntegrationResult.degraded(data=[], source="audiodb", msg="rate limit")
|
||||
)
|
||||
assert ctx.summary() == {"audiodb": "degraded"}
|
||||
assert ctx.has_degradation() is True
|
||||
|
||||
def test_worst_status_wins(self):
|
||||
ctx = DegradationContext()
|
||||
ctx.record(IntegrationResult.ok(data=[], source="musicbrainz"))
|
||||
ctx.record(IntegrationResult.degraded(data=[], source="musicbrainz", msg="slow"))
|
||||
assert ctx.summary() == {"musicbrainz": "degraded"}
|
||||
|
||||
def test_error_beats_degraded(self):
|
||||
ctx = DegradationContext()
|
||||
ctx.record(
|
||||
IntegrationResult.degraded(data=[], source="musicbrainz", msg="slow")
|
||||
)
|
||||
ctx.record(IntegrationResult.error(source="musicbrainz", msg="503"))
|
||||
assert ctx.summary() == {"musicbrainz": "error"}
|
||||
|
||||
def test_cannot_downgrade(self):
|
||||
ctx = DegradationContext()
|
||||
ctx.record(IntegrationResult.error(source="jellyfin", msg="down"))
|
||||
ctx.record(IntegrationResult.ok(data=[1], source="jellyfin"))
|
||||
assert ctx.summary() == {"jellyfin": "error"}
|
||||
|
||||
def test_multiple_sources(self):
|
||||
ctx = DegradationContext()
|
||||
ctx.record(IntegrationResult.ok(data=[], source="musicbrainz"))
|
||||
ctx.record(IntegrationResult.error(source="jellyfin", msg="down"))
|
||||
ctx.record(
|
||||
IntegrationResult.degraded(data={}, source="audiodb", msg="slow")
|
||||
)
|
||||
assert ctx.summary() == {
|
||||
"musicbrainz": "ok",
|
||||
"jellyfin": "error",
|
||||
"audiodb": "degraded",
|
||||
}
|
||||
assert ctx.degraded_summary() == {
|
||||
"jellyfin": "error",
|
||||
"audiodb": "degraded",
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
class TestContextVarLifecycle:
|
||||
def test_no_context_raises(self):
|
||||
clear_degradation_context()
|
||||
with pytest.raises(RuntimeError, match="outside a request scope"):
|
||||
get_degradation_context()
|
||||
|
||||
def test_try_get_returns_none_outside(self):
|
||||
clear_degradation_context()
|
||||
assert try_get_degradation_context() is None
|
||||
|
||||
def test_init_and_get(self):
|
||||
ctx = init_degradation_context()
|
||||
assert get_degradation_context() is ctx
|
||||
clear_degradation_context()
|
||||
|
||||
def test_clear_removes_context(self):
|
||||
init_degradation_context()
|
||||
clear_degradation_context()
|
||||
assert try_get_degradation_context() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_isolated_across_tasks(self):
|
||||
"""Context in one asyncio task must not leak into another."""
|
||||
results: dict[str, bool] = {}
|
||||
|
||||
async def task_a():
|
||||
init_degradation_context()
|
||||
ctx = get_degradation_context()
|
||||
ctx.record(IntegrationResult.error(source="a", msg="fail"))
|
||||
await asyncio.sleep(0.01)
|
||||
results["a_has_degradation"] = get_degradation_context().has_degradation()
|
||||
clear_degradation_context()
|
||||
|
||||
async def task_b():
|
||||
await asyncio.sleep(0.005)
|
||||
results["b_is_none"] = try_get_degradation_context() is None
|
||||
|
||||
await asyncio.gather(task_a(), task_b())
|
||||
assert results["a_has_degradation"] is True
|
||||
assert results["b_is_none"] is True
|
||||
@@ -0,0 +1,62 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
from core.exceptions import ClientDisconnectedError
|
||||
from infrastructure.http.disconnect import check_disconnected
|
||||
from infrastructure.http.deduplication import RequestDeduplicator
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_disconnected_raises_when_disconnected():
|
||||
is_disconnected = AsyncMock(return_value=True)
|
||||
with pytest.raises(ClientDisconnectedError):
|
||||
await check_disconnected(is_disconnected)
|
||||
assert is_disconnected.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_disconnected_noop_when_connected():
|
||||
is_disconnected = AsyncMock(return_value=False)
|
||||
await check_disconnected(is_disconnected)
|
||||
assert is_disconnected.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_check_disconnected_noop_when_none():
|
||||
await check_disconnected(None)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dedup_leader_disconnect_follower_retries_as_leader():
|
||||
dedup = RequestDeduplicator()
|
||||
follower_registered = asyncio.Event()
|
||||
leader_error = None
|
||||
expected_result = ("image-bytes", "image/png", "source")
|
||||
|
||||
async def leader_coro():
|
||||
await follower_registered.wait()
|
||||
raise ClientDisconnectedError("leader disconnected")
|
||||
|
||||
async def run_leader():
|
||||
nonlocal leader_error
|
||||
try:
|
||||
await dedup.dedupe("key1", leader_coro)
|
||||
except ClientDisconnectedError as e:
|
||||
leader_error = e
|
||||
|
||||
async def follower_coro():
|
||||
return expected_result
|
||||
|
||||
async def run_follower():
|
||||
await asyncio.sleep(0)
|
||||
follower_registered.set()
|
||||
return await dedup.dedupe("key1", follower_coro)
|
||||
|
||||
leader_task = asyncio.create_task(run_leader())
|
||||
await asyncio.sleep(0)
|
||||
follower_task = asyncio.create_task(run_follower())
|
||||
|
||||
await asyncio.gather(leader_task, follower_task)
|
||||
|
||||
assert isinstance(leader_error, ClientDisconnectedError)
|
||||
assert follower_task.result() == expected_result
|
||||
@@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tasks import cleanup_disk_cache_periodically
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_periodic_cleanup_calls_both_caches():
|
||||
disk_cache = AsyncMock()
|
||||
cover_disk_cache = AsyncMock()
|
||||
|
||||
iteration_count = 0
|
||||
|
||||
original_cleanup = cleanup_disk_cache_periodically
|
||||
|
||||
async def run_one_iteration():
|
||||
nonlocal iteration_count
|
||||
task = asyncio.create_task(
|
||||
original_cleanup(disk_cache, interval=0, cover_disk_cache=cover_disk_cache)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
disk_cache.cleanup_expired_recent.assert_called()
|
||||
disk_cache.enforce_recent_size_limits.assert_called()
|
||||
disk_cache.cleanup_expired_covers.assert_called()
|
||||
disk_cache.enforce_cover_size_limits.assert_called()
|
||||
cover_disk_cache.enforce_size_limit.assert_called_with(force=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_periodic_cleanup_works_without_cover_cache():
|
||||
disk_cache = AsyncMock()
|
||||
|
||||
task = asyncio.create_task(
|
||||
cleanup_disk_cache_periodically(disk_cache, interval=0, cover_disk_cache=None)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
disk_cache.cleanup_expired_recent.assert_called()
|
||||
disk_cache.enforce_recent_size_limits.assert_called()
|
||||
disk_cache.cleanup_expired_covers.assert_called()
|
||||
disk_cache.enforce_cover_size_limits.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_periodic_cleanup_continues_on_cover_cache_error():
|
||||
disk_cache = AsyncMock()
|
||||
cover_disk_cache = AsyncMock()
|
||||
cover_disk_cache.enforce_size_limit.side_effect = [RuntimeError("disk full"), None]
|
||||
|
||||
task = asyncio.create_task(
|
||||
cleanup_disk_cache_periodically(disk_cache, interval=0, cover_disk_cache=cover_disk_cache)
|
||||
)
|
||||
await asyncio.sleep(0.1)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert cover_disk_cache.enforce_size_limit.call_count >= 1
|
||||
assert disk_cache.cleanup_expired_recent.call_count >= 1
|
||||
@@ -0,0 +1,172 @@
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from api.v1.schemas.album import AlbumInfo
|
||||
from infrastructure.cache.disk_cache import DiskMetadataCache
|
||||
from repositories.audiodb_models import AudioDBArtistImages, AudioDBAlbumImages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_album_serializes_msgspec_struct_as_mapping(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
mbid = "4549a80c-efe6-4386-b3a2-4b4a918eb31f"
|
||||
album_info = AlbumInfo(
|
||||
title="The Moon Song",
|
||||
musicbrainz_id=mbid,
|
||||
artist_name="beabadoobee",
|
||||
artist_id="88d17133-abbc-42db-9526-4e2c1db60336",
|
||||
in_library=True,
|
||||
)
|
||||
|
||||
await cache.set_album(mbid, album_info, is_monitored=True)
|
||||
|
||||
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
|
||||
cache_file = tmp_path / "persistent" / "albums" / f"{cache_hash}.json"
|
||||
payload = json.loads(cache_file.read_text())
|
||||
|
||||
assert isinstance(payload, dict)
|
||||
assert payload["musicbrainz_id"] == mbid
|
||||
|
||||
cached = await cache.get_album(mbid)
|
||||
assert isinstance(cached, dict)
|
||||
assert cached["title"] == "The Moon Song"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_album_deletes_corrupt_string_payload(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
mbid = "8e1e9e51-38dc-4df3-8027-a0ada37d4674"
|
||||
|
||||
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
|
||||
cache_file = tmp_path / "persistent" / "albums" / f"{cache_hash}.json"
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cache_file.write_text(json.dumps("AlbumInfo(title='Corrupt')"))
|
||||
|
||||
cached = await cache.get_album(mbid)
|
||||
|
||||
assert cached is None
|
||||
assert not cache_file.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audiodb_artist_entity_routing(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
mbid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
images = AudioDBArtistImages(
|
||||
thumb_url="https://example.com/thumb.jpg",
|
||||
fanart_url="https://example.com/fanart.jpg",
|
||||
lookup_source="mbid",
|
||||
matched_mbid=mbid,
|
||||
)
|
||||
|
||||
await cache._set_entity("audiodb_artist", mbid, images, is_monitored=False, ttl_seconds=None)
|
||||
|
||||
result = await cache._get_entity("audiodb_artist", mbid)
|
||||
assert result is not None
|
||||
assert result["thumb_url"] == "https://example.com/thumb.jpg"
|
||||
assert result["fanart_url"] == "https://example.com/fanart.jpg"
|
||||
assert result["lookup_source"] == "mbid"
|
||||
|
||||
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
|
||||
data_file = tmp_path / "recent" / "audiodb_artists" / f"{cache_hash}.json"
|
||||
assert data_file.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audiodb_album_entity_routing(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
mbid = "b2c3d4e5-f6a7-8901-bcde-f12345678901"
|
||||
images = AudioDBAlbumImages(
|
||||
album_thumb_url="https://example.com/album_thumb.jpg",
|
||||
album_back_url="https://example.com/album_back.jpg",
|
||||
lookup_source="name",
|
||||
matched_mbid=mbid,
|
||||
)
|
||||
|
||||
await cache._set_entity("audiodb_album", mbid, images, is_monitored=True, ttl_seconds=None)
|
||||
|
||||
result = await cache._get_entity("audiodb_album", mbid)
|
||||
assert result is not None
|
||||
assert result["album_thumb_url"] == "https://example.com/album_thumb.jpg"
|
||||
assert result["album_back_url"] == "https://example.com/album_back.jpg"
|
||||
assert result["lookup_source"] == "name"
|
||||
|
||||
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
|
||||
persistent_file = tmp_path / "persistent" / "audiodb_albums" / f"{cache_hash}.json"
|
||||
assert persistent_file.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats_counts_audiodb_entries(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
|
||||
artist_images = AudioDBArtistImages(thumb_url="https://example.com/a.jpg")
|
||||
album_images = AudioDBAlbumImages(album_thumb_url="https://example.com/b.jpg")
|
||||
|
||||
await cache._set_entity("audiodb_artist", "artist-1", artist_images, is_monitored=False, ttl_seconds=None)
|
||||
await cache._set_entity("audiodb_artist", "artist-2", artist_images, is_monitored=True, ttl_seconds=None)
|
||||
await cache._set_entity("audiodb_album", "album-1", album_images, is_monitored=False, ttl_seconds=None)
|
||||
|
||||
stats = cache.get_stats()
|
||||
assert stats["audiodb_artist_count"] == 2
|
||||
assert stats["audiodb_album_count"] == 1
|
||||
assert stats["album_count"] == 0
|
||||
assert stats["artist_count"] == 0
|
||||
assert stats["total_count"] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_audiodb_isolates_from_other_entities(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
album_mbid = "c3d4e5f6-a7b8-9012-cdef-123456789012"
|
||||
album_info = AlbumInfo(
|
||||
title="Regular Album",
|
||||
musicbrainz_id=album_mbid,
|
||||
artist_name="Test Artist",
|
||||
artist_id="d4e5f6a7-b8c9-0123-defa-234567890123",
|
||||
in_library=False,
|
||||
)
|
||||
await cache.set_album(album_mbid, album_info, is_monitored=False)
|
||||
|
||||
artist_images = AudioDBArtistImages(thumb_url="https://example.com/thumb.jpg")
|
||||
album_images = AudioDBAlbumImages(album_thumb_url="https://example.com/album.jpg")
|
||||
await cache._set_entity("audiodb_artist", "adb-artist-1", artist_images, is_monitored=False, ttl_seconds=None)
|
||||
await cache._set_entity("audiodb_album", "adb-album-1", album_images, is_monitored=True, ttl_seconds=None)
|
||||
|
||||
stats_before = cache.get_stats()
|
||||
assert stats_before["audiodb_artist_count"] == 1
|
||||
assert stats_before["audiodb_album_count"] == 1
|
||||
assert stats_before["album_count"] == 1
|
||||
|
||||
await cache.clear_audiodb()
|
||||
|
||||
stats_after = cache.get_stats()
|
||||
assert stats_after["audiodb_artist_count"] == 0
|
||||
assert stats_after["audiodb_album_count"] == 0
|
||||
assert stats_after["album_count"] == 1
|
||||
|
||||
regular_album = await cache.get_album(album_mbid)
|
||||
assert regular_album is not None
|
||||
assert regular_album["title"] == "Regular Album"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audiodb_monitored_persistent_vs_recent(tmp_path):
|
||||
cache = DiskMetadataCache(base_path=tmp_path)
|
||||
mbid = "e5f6a7b8-c9d0-1234-efab-567890123456"
|
||||
images = AudioDBArtistImages(thumb_url="https://example.com/t.jpg")
|
||||
|
||||
await cache._set_entity("audiodb_artist", mbid, images, is_monitored=True, ttl_seconds=None)
|
||||
|
||||
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
|
||||
persistent_file = tmp_path / "persistent" / "audiodb_artists" / f"{cache_hash}.json"
|
||||
recent_file = tmp_path / "recent" / "audiodb_artists" / f"{cache_hash}.json"
|
||||
assert persistent_file.exists()
|
||||
assert not recent_file.exists()
|
||||
|
||||
await cache._set_entity("audiodb_artist", mbid, images, is_monitored=False, ttl_seconds=None)
|
||||
|
||||
assert not persistent_file.exists()
|
||||
assert recent_file.exists()
|
||||
@@ -0,0 +1,114 @@
|
||||
"""Tests for IntegrationResult and aggregate_status."""
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.integration_result import (
|
||||
IntegrationResult,
|
||||
aggregate_status,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
class TestIntegrationResultOk:
|
||||
def test_ok_carries_data(self):
|
||||
r = IntegrationResult.ok(data=["a", "b"], source="musicbrainz")
|
||||
assert r.data == ["a", "b"]
|
||||
assert r.source == "musicbrainz"
|
||||
assert r.status == "ok"
|
||||
assert r.error_message is None
|
||||
|
||||
def test_is_ok_true(self):
|
||||
r = IntegrationResult.ok(data=42, source="jellyfin")
|
||||
assert r.is_ok is True
|
||||
assert r.is_degraded is False
|
||||
assert r.is_error is False
|
||||
|
||||
|
||||
class TestIntegrationResultDegraded:
|
||||
def test_degraded_carries_partial_data(self):
|
||||
r = IntegrationResult.degraded(
|
||||
data={"stale": True}, source="audiodb", msg="rate limited"
|
||||
)
|
||||
assert r.data == {"stale": True}
|
||||
assert r.source == "audiodb"
|
||||
assert r.status == "degraded"
|
||||
assert r.error_message == "rate limited"
|
||||
|
||||
def test_is_degraded_true(self):
|
||||
r = IntegrationResult.degraded(data=[], source="lastfm", msg="timeout")
|
||||
assert r.is_degraded is True
|
||||
assert r.is_ok is False
|
||||
assert r.is_error is False
|
||||
|
||||
|
||||
class TestIntegrationResultError:
|
||||
def test_error_has_no_data(self):
|
||||
r = IntegrationResult.error(source="musicbrainz", msg="503 Service Unavailable")
|
||||
assert r.data is None
|
||||
assert r.source == "musicbrainz"
|
||||
assert r.status == "error"
|
||||
assert r.error_message == "503 Service Unavailable"
|
||||
|
||||
def test_is_error_true(self):
|
||||
r = IntegrationResult.error(source="wikidata", msg="boom")
|
||||
assert r.is_error is True
|
||||
assert r.is_ok is False
|
||||
assert r.is_degraded is False
|
||||
|
||||
|
||||
|
||||
|
||||
class TestDataOr:
|
||||
def test_returns_data_when_present(self):
|
||||
r = IntegrationResult.ok(data=[1, 2, 3], source="mb")
|
||||
assert r.data_or([]) == [1, 2, 3]
|
||||
|
||||
def test_returns_default_when_none(self):
|
||||
r = IntegrationResult.error(source="mb", msg="down")
|
||||
assert r.data_or([]) == []
|
||||
|
||||
def test_returns_data_for_degraded(self):
|
||||
r = IntegrationResult.degraded(data={"partial": True}, source="mb", msg="slow")
|
||||
assert r.data_or({}) == {"partial": True}
|
||||
|
||||
|
||||
|
||||
|
||||
class TestImmutability:
|
||||
def test_frozen(self):
|
||||
r = IntegrationResult.ok(data="hello", source="test")
|
||||
with pytest.raises(AttributeError):
|
||||
r.data = "goodbye" # type: ignore[misc]
|
||||
|
||||
|
||||
|
||||
|
||||
class TestAggregateStatus:
|
||||
def test_all_ok(self):
|
||||
assert aggregate_status(
|
||||
IntegrationResult.ok(1, "a"),
|
||||
IntegrationResult.ok(2, "b"),
|
||||
) == "ok"
|
||||
|
||||
def test_one_degraded(self):
|
||||
assert aggregate_status(
|
||||
IntegrationResult.ok(1, "a"),
|
||||
IntegrationResult.degraded(2, "b", "slow"),
|
||||
) == "degraded"
|
||||
|
||||
def test_one_error(self):
|
||||
assert aggregate_status(
|
||||
IntegrationResult.ok(1, "a"),
|
||||
IntegrationResult.degraded(2, "b", "slow"),
|
||||
IntegrationResult.error("c", "down"),
|
||||
) == "error"
|
||||
|
||||
def test_empty(self):
|
||||
assert aggregate_status() == "ok"
|
||||
|
||||
def test_error_short_circuits(self):
|
||||
assert aggregate_status(
|
||||
IntegrationResult.error("a", "x"),
|
||||
IntegrationResult.ok(1, "b"),
|
||||
) == "error"
|
||||
@@ -0,0 +1,304 @@
|
||||
"""Tests for LibraryDB paginated query methods."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.persistence.library_db import LibraryDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(tmp_path: Path) -> LibraryDB:
|
||||
return LibraryDB(db_path=tmp_path / "test.db", write_lock=threading.Lock())
|
||||
|
||||
|
||||
def _make_albums(count: int, *, start: int = 1) -> list[dict]:
|
||||
"""Generate album dicts with predictable, sortable data."""
|
||||
albums = []
|
||||
for i in range(start, start + count):
|
||||
albums.append(
|
||||
{
|
||||
"mbid": f"album-{i:04d}",
|
||||
"artist_mbid": f"artist-{(i % 5) + 1:04d}",
|
||||
"artist_name": f"Artist {chr(65 + (i % 26))}",
|
||||
"title": f"Album {chr(65 + ((i + 13) % 26))} {i:04d}",
|
||||
"year": 2000 + (i % 24),
|
||||
"cover_url": None,
|
||||
"monitored": True,
|
||||
"date_added": 1700000000 + i * 100,
|
||||
}
|
||||
)
|
||||
return albums
|
||||
|
||||
|
||||
def _make_artists(count: int) -> list[dict]:
|
||||
"""Generate artist dicts with predictable data."""
|
||||
artists = []
|
||||
for i in range(1, count + 1):
|
||||
artists.append(
|
||||
{
|
||||
"mbid": f"artist-{i:04d}",
|
||||
"name": f"Artist {chr(65 + (i % 26))}",
|
||||
"album_count": i,
|
||||
"date_added": 1700000000 + i * 100,
|
||||
}
|
||||
)
|
||||
return artists
|
||||
|
||||
|
||||
async def _seed(db: LibraryDB, n_albums: int = 100, n_artists: int = 20) -> None:
|
||||
await db.save_library(_make_artists(n_artists), _make_albums(n_albums))
|
||||
|
||||
|
||||
# --- Album pagination ---
|
||||
|
||||
|
||||
def test_albums_basic_pagination(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=10, offset=0)
|
||||
)
|
||||
assert total == 100
|
||||
assert len(items) == 10
|
||||
|
||||
|
||||
def test_albums_offset_beyond_total(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=10, offset=200)
|
||||
)
|
||||
assert total == 100
|
||||
assert len(items) == 0
|
||||
|
||||
|
||||
def test_albums_last_partial_page(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=30, offset=90)
|
||||
)
|
||||
assert total == 100
|
||||
assert len(items) == 10
|
||||
|
||||
|
||||
def test_albums_sort_by_title_asc(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, _ = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, sort_by="title", sort_order="asc")
|
||||
)
|
||||
titles = [i.get("title", "") for i in items]
|
||||
assert titles == sorted(titles, key=str.casefold)
|
||||
|
||||
|
||||
def test_albums_sort_by_title_desc(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, _ = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, sort_by="title", sort_order="desc")
|
||||
)
|
||||
titles = [i.get("title", "") for i in items]
|
||||
assert titles == sorted(titles, key=str.casefold, reverse=True)
|
||||
|
||||
|
||||
def test_albums_sort_by_year(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, _ = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, sort_by="year", sort_order="desc")
|
||||
)
|
||||
years = [i.get("year", 0) or 0 for i in items]
|
||||
assert years == sorted(years, reverse=True)
|
||||
|
||||
|
||||
def test_albums_sort_by_date_added(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, _ = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, sort_by="date_added", sort_order="desc")
|
||||
)
|
||||
dates = [i.get("date_added", 0) or 0 for i in items]
|
||||
assert dates == sorted(dates, reverse=True)
|
||||
|
||||
|
||||
def test_albums_search_by_title(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, search="Album A")
|
||||
)
|
||||
assert total > 0
|
||||
assert all("Album A" in i.get("title", "") for i in items)
|
||||
|
||||
|
||||
def test_albums_search_by_artist(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, search="Artist A")
|
||||
)
|
||||
assert total > 0
|
||||
assert all(
|
||||
"Artist A" in i.get("artist_name", "") or "Artist A" in i.get("title", "")
|
||||
for i in items
|
||||
)
|
||||
|
||||
|
||||
def test_albums_search_no_results(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=10, offset=0, search="zzz_no_match_zzz")
|
||||
)
|
||||
assert total == 0
|
||||
assert len(items) == 0
|
||||
|
||||
|
||||
def test_albums_search_case_insensitive(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items_upper, total_upper = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, search="ALBUM A")
|
||||
)
|
||||
items_lower, total_lower = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, search="album a")
|
||||
)
|
||||
assert total_upper == total_lower
|
||||
assert len(items_upper) == len(items_lower)
|
||||
|
||||
|
||||
def test_albums_search_escapes_like_metacharacters(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items_pct, total_pct = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, search="100%")
|
||||
)
|
||||
assert total_pct == 0
|
||||
assert len(items_pct) == 0
|
||||
items_under, total_under = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=100, offset=0, search="Album_A")
|
||||
)
|
||||
assert total_under == 0
|
||||
|
||||
|
||||
def test_artists_search_escapes_like_metacharacters(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(limit=100, offset=0, search="Artist%B")
|
||||
)
|
||||
assert total == 0
|
||||
|
||||
|
||||
def test_albums_invalid_sort_falls_back(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=10, offset=0, sort_by="nonexistent")
|
||||
)
|
||||
assert total == 100
|
||||
assert len(items) == 10
|
||||
|
||||
|
||||
def test_albums_empty_library(db: LibraryDB):
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(limit=10, offset=0)
|
||||
)
|
||||
assert total == 0
|
||||
assert len(items) == 0
|
||||
|
||||
|
||||
# --- Artist pagination ---
|
||||
|
||||
|
||||
def test_artists_basic_pagination(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(limit=5, offset=0)
|
||||
)
|
||||
assert total == 20
|
||||
assert len(items) == 5
|
||||
|
||||
|
||||
def test_artists_offset_beyond_total(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(limit=10, offset=50)
|
||||
)
|
||||
assert total == 20
|
||||
assert len(items) == 0
|
||||
|
||||
|
||||
def test_artists_sort_by_name_asc(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, _ = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(limit=20, offset=0, sort_by="name", sort_order="asc")
|
||||
)
|
||||
names = [i.get("name", "") for i in items]
|
||||
assert names == sorted(names, key=str.casefold)
|
||||
|
||||
|
||||
def test_artists_sort_by_album_count_desc(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, _ = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(limit=20, offset=0, sort_by="album_count", sort_order="desc")
|
||||
)
|
||||
counts = [i.get("album_count", 0) for i in items]
|
||||
assert counts == sorted(counts, reverse=True)
|
||||
|
||||
|
||||
def test_artists_search(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(limit=20, offset=0, search="Artist B")
|
||||
)
|
||||
assert total > 0
|
||||
assert all("Artist B" in i.get("name", "") for i in items)
|
||||
|
||||
|
||||
def test_artists_search_no_results(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db))
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(limit=10, offset=0, search="zzz_no_match_zzz")
|
||||
)
|
||||
assert total == 0
|
||||
assert len(items) == 0
|
||||
|
||||
|
||||
def test_artists_empty_library(db: LibraryDB):
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(limit=10, offset=0)
|
||||
)
|
||||
assert total == 0
|
||||
assert len(items) == 0
|
||||
|
||||
|
||||
# --- Pagination consistency (no duplicates/missing across pages) ---
|
||||
|
||||
|
||||
def test_albums_pagination_no_duplicates(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db, n_albums=50))
|
||||
all_mbids: list[str] = []
|
||||
offset = 0
|
||||
page_size = 10
|
||||
while True:
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_albums_paginated(
|
||||
limit=page_size, offset=offset, sort_by="title", sort_order="asc"
|
||||
)
|
||||
)
|
||||
if not items:
|
||||
break
|
||||
all_mbids.extend(i.get("mbid", "") for i in items)
|
||||
offset += page_size
|
||||
assert len(all_mbids) == 50
|
||||
assert len(set(all_mbids)) == 50
|
||||
|
||||
|
||||
def test_artists_pagination_no_duplicates(db: LibraryDB):
|
||||
asyncio.get_event_loop().run_until_complete(_seed(db, n_albums=10, n_artists=30))
|
||||
all_mbids: list[str] = []
|
||||
offset = 0
|
||||
page_size = 7
|
||||
while True:
|
||||
items, total = asyncio.get_event_loop().run_until_complete(
|
||||
db.get_artists_paginated(
|
||||
limit=page_size, offset=offset, sort_by="name", sort_order="asc"
|
||||
)
|
||||
)
|
||||
if not items:
|
||||
break
|
||||
all_mbids.extend(i.get("mbid", "") for i in items)
|
||||
offset += page_size
|
||||
assert len(all_mbids) == 30
|
||||
assert len(set(all_mbids)) == 30
|
||||
@@ -0,0 +1,117 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from infrastructure.msgspec_fastapi import (
|
||||
AppStruct,
|
||||
MsgSpecBody,
|
||||
MsgSpecJSONRequest,
|
||||
MsgSpecJSONResponse,
|
||||
MsgSpecRoute,
|
||||
_contains_msgspec_struct,
|
||||
_merge_response_schema,
|
||||
)
|
||||
|
||||
|
||||
class SamplePayload(AppStruct):
|
||||
value: int
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_msgspec_json_request_caches_decoded_body():
|
||||
body = b'{"value": 7}'
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
request = MsgSpecJSONRequest(
|
||||
{
|
||||
"type": "http",
|
||||
"http_version": "1.1",
|
||||
"method": "POST",
|
||||
"path": "/",
|
||||
"raw_path": b"/",
|
||||
"scheme": "http",
|
||||
"headers": [],
|
||||
"query_string": b"",
|
||||
"client": ("testclient", 123),
|
||||
"server": ("testserver", 80),
|
||||
},
|
||||
receive,
|
||||
)
|
||||
|
||||
first = await request.json()
|
||||
second = await request.json()
|
||||
|
||||
assert first == {"value": 7}
|
||||
assert second == first
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_msgspec_json_request_raises_json_decode_error():
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b"{", "more_body": False}
|
||||
|
||||
request = MsgSpecJSONRequest(
|
||||
{
|
||||
"type": "http",
|
||||
"http_version": "1.1",
|
||||
"method": "POST",
|
||||
"path": "/",
|
||||
"raw_path": b"/",
|
||||
"scheme": "http",
|
||||
"headers": [],
|
||||
"query_string": b"",
|
||||
"client": ("testclient", 123),
|
||||
"server": ("testserver", 80),
|
||||
},
|
||||
receive,
|
||||
)
|
||||
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
await request.json()
|
||||
|
||||
|
||||
def test_app_struct_iteration_and_json_response_render():
|
||||
payload = SamplePayload(value=5)
|
||||
|
||||
assert dict(payload) == {"value": 5}
|
||||
|
||||
response = MsgSpecJSONResponse(content=payload)
|
||||
assert response.body == b'{"value":5}'
|
||||
|
||||
|
||||
def test_msgspec_body_and_route_work_with_fastapi():
|
||||
app = FastAPI()
|
||||
router = APIRouter(route_class=MsgSpecRoute)
|
||||
|
||||
@router.post("/items", response_model=SamplePayload)
|
||||
async def create_item(body: SamplePayload = MsgSpecBody(SamplePayload)):
|
||||
return body
|
||||
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
ok = client.post("/items", json={"value": 11})
|
||||
assert ok.status_code == 200
|
||||
assert ok.json() == {"value": 11}
|
||||
|
||||
bad = client.post("/items", json={"value": "nope"})
|
||||
assert bad.status_code == 422
|
||||
|
||||
|
||||
def test_contains_msgspec_struct_and_merge_response_schema():
|
||||
assert _contains_msgspec_struct(SamplePayload) is True
|
||||
assert _contains_msgspec_struct(list[SamplePayload]) is True
|
||||
assert _contains_msgspec_struct(str | None) is False
|
||||
|
||||
merged = _merge_response_schema(
|
||||
{"responses": {"200": {"description": "ok"}}},
|
||||
{"type": "object", "properties": {"value": {"type": "integer"}}},
|
||||
)
|
||||
|
||||
schema = merged["responses"]["200"]["content"]["application/json"]["schema"]
|
||||
assert schema["type"] == "object"
|
||||
assert "value" in schema["properties"]
|
||||
@@ -0,0 +1,110 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from infrastructure.queue.queue_store import QueueStore
|
||||
from infrastructure.queue.request_queue import RequestQueue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(tmp_path: Path) -> QueueStore:
|
||||
return QueueStore(db_path=tmp_path / "test_queue.db")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jobs_survive_restart(store: QueueStore):
|
||||
processed = []
|
||||
|
||||
async def slow_processor(mbid: str) -> dict:
|
||||
await asyncio.sleep(100)
|
||||
processed.append(mbid)
|
||||
return {"status": "ok"}
|
||||
|
||||
q1 = RequestQueue(processor=slow_processor, store=store)
|
||||
await q1.start()
|
||||
|
||||
store.enqueue("job-1", "mbid-abc")
|
||||
store.mark_processing("job-1")
|
||||
|
||||
q1._processor_task.cancel()
|
||||
try:
|
||||
await q1._processor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
fast_processed = []
|
||||
|
||||
async def fast_processor(mbid: str) -> dict:
|
||||
fast_processed.append(mbid)
|
||||
return {"status": "ok"}
|
||||
|
||||
q2 = RequestQueue(processor=fast_processor, store=store)
|
||||
await q2.start()
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
assert "mbid-abc" in fast_processed
|
||||
await q2.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_job_lands_in_dead_letter(store: QueueStore):
|
||||
async def failing_processor(mbid: str) -> dict:
|
||||
raise ValueError("Lidarr is down")
|
||||
|
||||
q = RequestQueue(processor=failing_processor, store=store)
|
||||
await q.start()
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(q.add("mbid-fail"), timeout=2.0)
|
||||
except (ValueError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
assert store.get_dead_letter_count() >= 1
|
||||
await q.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dead_letter_retry_on_restart(store: QueueStore):
|
||||
store.add_dead_letter("dlj-1", "mbid-retry", "old error", retry_count=1, max_retries=3)
|
||||
|
||||
processed = []
|
||||
|
||||
async def processor(mbid: str) -> dict:
|
||||
processed.append(mbid)
|
||||
return {"status": "ok"}
|
||||
|
||||
q = RequestQueue(processor=processor, store=store)
|
||||
await q.start()
|
||||
await asyncio.sleep(0.5)
|
||||
assert "mbid-retry" in processed
|
||||
await q.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_job_removed_from_store(store: QueueStore):
|
||||
async def ok_processor(mbid: str) -> dict:
|
||||
return {"status": "ok"}
|
||||
|
||||
q = RequestQueue(processor=ok_processor, store=store)
|
||||
await q.start()
|
||||
|
||||
await asyncio.wait_for(q.add("mbid-ok"), timeout=2.0)
|
||||
assert len(store.get_all()) == 0
|
||||
await q.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exhausted_dead_letter_not_retried(store: QueueStore):
|
||||
store.add_dead_letter("dlj-ex", "mbid-exhausted", "fatal", retry_count=3, max_retries=3)
|
||||
|
||||
processed = []
|
||||
|
||||
async def processor(mbid: str) -> dict:
|
||||
processed.append(mbid)
|
||||
return {"status": "ok"}
|
||||
|
||||
q = RequestQueue(processor=processor, store=store)
|
||||
await q.start()
|
||||
await asyncio.sleep(0.3)
|
||||
assert "mbid-exhausted" not in processed
|
||||
await q.stop()
|
||||
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from infrastructure.queue.queue_store import QueueStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(tmp_path: Path) -> QueueStore:
|
||||
return QueueStore(db_path=tmp_path / "test_queue.db")
|
||||
|
||||
|
||||
def test_enqueue_and_get_pending(store: QueueStore):
|
||||
store.enqueue("j1", "mbid-1")
|
||||
store.enqueue("j2", "mbid-2")
|
||||
store.enqueue("j3", "mbid-3")
|
||||
assert len(store.get_pending()) == 3
|
||||
|
||||
|
||||
def test_dequeue_removes_job(store: QueueStore):
|
||||
store.enqueue("j1", "mbid-1")
|
||||
store.dequeue("j1")
|
||||
assert len(store.get_pending()) == 0
|
||||
|
||||
|
||||
def test_duplicate_enqueue_ignored(store: QueueStore):
|
||||
assert store.enqueue("j1", "mbid-1") is True
|
||||
assert store.enqueue("j2", "mbid-1") is False
|
||||
assert len(store.get_pending()) == 1
|
||||
|
||||
|
||||
def test_mark_processing(store: QueueStore):
|
||||
store.enqueue("j1", "mbid-1")
|
||||
store.mark_processing("j1")
|
||||
assert len(store.get_pending()) == 0
|
||||
assert len(store.get_all()) == 1
|
||||
|
||||
|
||||
def test_reset_processing(store: QueueStore):
|
||||
store.enqueue("j1", "mbid-1")
|
||||
store.mark_processing("j1")
|
||||
store.reset_processing()
|
||||
assert len(store.get_pending()) == 1
|
||||
|
||||
|
||||
def test_add_dead_letter_retryable(store: QueueStore):
|
||||
store.add_dead_letter("j1", "mbid-1", "error", retry_count=1, max_retries=3)
|
||||
retryable = store.get_retryable_dead_letters()
|
||||
assert len(retryable) == 1
|
||||
assert retryable[0]["album_mbid"] == "mbid-1"
|
||||
|
||||
|
||||
def test_add_dead_letter_exhausted(store: QueueStore):
|
||||
store.add_dead_letter("j1", "mbid-1", "error", retry_count=3, max_retries=3)
|
||||
assert len(store.get_retryable_dead_letters()) == 0
|
||||
|
||||
|
||||
def test_remove_dead_letter(store: QueueStore):
|
||||
store.add_dead_letter("j1", "mbid-1", "error", retry_count=1, max_retries=3)
|
||||
store.remove_dead_letter("j1")
|
||||
assert len(store.get_retryable_dead_letters()) == 0
|
||||
|
||||
|
||||
def test_update_dead_letter_attempt(store: QueueStore):
|
||||
store.add_dead_letter("j1", "mbid-1", "error1", retry_count=1, max_retries=3)
|
||||
store.update_dead_letter_attempt("j1", "error2", retry_count=3)
|
||||
assert len(store.get_retryable_dead_letters()) == 0
|
||||
assert store.get_dead_letter_count() == 1
|
||||
|
||||
|
||||
def test_get_dead_letter_count(store: QueueStore):
|
||||
store.add_dead_letter("j1", "mbid-1", "e1", 1, 3)
|
||||
store.add_dead_letter("j2", "mbid-2", "e2", 1, 3)
|
||||
store.add_dead_letter("j3", "mbid-3", "e3", 1, 3)
|
||||
assert store.get_dead_letter_count() == 3
|
||||
|
||||
|
||||
def test_has_pending_mbid(store: QueueStore):
|
||||
assert store.has_pending_mbid("mbid-1") is False
|
||||
store.enqueue("j1", "mbid-1")
|
||||
assert store.has_pending_mbid("mbid-1") is True
|
||||
store.mark_processing("j1")
|
||||
assert store.has_pending_mbid("mbid-1") is False
|
||||
store.dequeue("j1")
|
||||
assert store.has_pending_mbid("mbid-1") is False
|
||||
|
||||
|
||||
def test_enqueue_returns_bool(store: QueueStore):
|
||||
assert store.enqueue("j1", "mbid-1") is True
|
||||
assert store.enqueue("j1", "mbid-1") is False
|
||||
@@ -0,0 +1,199 @@
|
||||
"""Tests that non_breaking_exceptions bypass circuit breaker failure recording."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.resilience.retry import (
|
||||
CircuitBreaker,
|
||||
CircuitOpenError,
|
||||
CircuitState,
|
||||
with_retry,
|
||||
)
|
||||
|
||||
|
||||
class _RateLimited(Exception):
|
||||
def __init__(self, retry_after: float = 1.0):
|
||||
super().__init__("rate limited")
|
||||
self.retry_after_seconds = retry_after
|
||||
|
||||
|
||||
class _ServiceDown(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_breaking_exception_does_not_trip_circuit():
|
||||
cb = CircuitBreaker(failure_threshold=3, name="test-non-breaking")
|
||||
call_count = 0
|
||||
|
||||
@with_retry(
|
||||
max_attempts=4,
|
||||
base_delay=0.01,
|
||||
max_delay=0.05,
|
||||
circuit_breaker=cb,
|
||||
retriable_exceptions=(_RateLimited, _ServiceDown),
|
||||
non_breaking_exceptions=(_RateLimited,),
|
||||
)
|
||||
async def flaky():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 4:
|
||||
raise _RateLimited(retry_after=0.01)
|
||||
return "ok"
|
||||
|
||||
result = await flaky()
|
||||
|
||||
assert result == "ok"
|
||||
assert call_count == 4
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
assert cb.failure_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_breaking_exception_still_trips_circuit():
|
||||
cb = CircuitBreaker(failure_threshold=2, name="test-breaking")
|
||||
call_count = 0
|
||||
|
||||
@with_retry(
|
||||
max_attempts=3,
|
||||
base_delay=0.01,
|
||||
max_delay=0.05,
|
||||
circuit_breaker=cb,
|
||||
retriable_exceptions=(_RateLimited, _ServiceDown),
|
||||
non_breaking_exceptions=(_RateLimited,),
|
||||
)
|
||||
async def fail():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise _ServiceDown("down")
|
||||
|
||||
with pytest.raises(_ServiceDown):
|
||||
await fail()
|
||||
|
||||
assert call_count == 3
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_breaking_uses_retry_after_for_delay():
|
||||
cb = CircuitBreaker(failure_threshold=5, name="test-retry-after")
|
||||
call_count = 0
|
||||
|
||||
@with_retry(
|
||||
max_attempts=2,
|
||||
base_delay=100.0,
|
||||
max_delay=100.0,
|
||||
circuit_breaker=cb,
|
||||
retriable_exceptions=(_RateLimited,),
|
||||
non_breaking_exceptions=(_RateLimited,),
|
||||
)
|
||||
async def rate_limited_then_ok():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise _RateLimited(retry_after=0.01)
|
||||
return "ok"
|
||||
|
||||
result = await rate_limited_then_ok()
|
||||
|
||||
assert result == "ok"
|
||||
assert call_count == 2
|
||||
assert cb.failure_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_still_opens_for_real_errors_amid_rate_limits():
|
||||
cb = CircuitBreaker(failure_threshold=2, name="test-mixed")
|
||||
|
||||
@with_retry(
|
||||
max_attempts=1,
|
||||
base_delay=0.01,
|
||||
max_delay=0.05,
|
||||
circuit_breaker=cb,
|
||||
retriable_exceptions=(_RateLimited, _ServiceDown),
|
||||
non_breaking_exceptions=(_RateLimited,),
|
||||
)
|
||||
async def real_failure():
|
||||
raise _ServiceDown("down")
|
||||
|
||||
for _ in range(2):
|
||||
with pytest.raises(_ServiceDown):
|
||||
await real_failure()
|
||||
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
@with_retry(
|
||||
max_attempts=1,
|
||||
base_delay=0.01,
|
||||
max_delay=0.05,
|
||||
circuit_breaker=cb,
|
||||
retriable_exceptions=(_RateLimited, _ServiceDown),
|
||||
non_breaking_exceptions=(_RateLimited,),
|
||||
)
|
||||
async def subsequent_call():
|
||||
return "should not reach"
|
||||
|
||||
with pytest.raises(CircuitOpenError):
|
||||
await subsequent_call()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_breaking_in_half_open_reopens_circuit():
|
||||
"""Non-breaking exceptions in HALF_OPEN must still reopen the circuit."""
|
||||
cb = CircuitBreaker(failure_threshold=2, success_threshold=2, timeout=0.01, name="test-half-open")
|
||||
|
||||
for _ in range(2):
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
await asyncio.sleep(0.02)
|
||||
await cb.atry_transition()
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
|
||||
@with_retry(
|
||||
max_attempts=1,
|
||||
base_delay=0.01,
|
||||
max_delay=0.05,
|
||||
circuit_breaker=cb,
|
||||
retriable_exceptions=(_RateLimited,),
|
||||
non_breaking_exceptions=(_RateLimited,),
|
||||
)
|
||||
async def rate_limited_in_half_open():
|
||||
raise _RateLimited(retry_after=0.01)
|
||||
|
||||
with pytest.raises(_RateLimited):
|
||||
await rate_limited_in_half_open()
|
||||
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_after_not_clamped_by_max_delay():
|
||||
"""Server-provided Retry-After should not be clamped by max_delay."""
|
||||
cb = CircuitBreaker(failure_threshold=10, name="test-retry-after-clamp")
|
||||
call_count = 0
|
||||
observed_gap = 0.0
|
||||
|
||||
@with_retry(
|
||||
max_attempts=2,
|
||||
base_delay=0.01,
|
||||
max_delay=0.05,
|
||||
circuit_breaker=cb,
|
||||
retriable_exceptions=(_RateLimited,),
|
||||
non_breaking_exceptions=(_RateLimited,),
|
||||
)
|
||||
async def rate_limited_then_ok():
|
||||
nonlocal call_count, observed_gap
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise _RateLimited(retry_after=0.3)
|
||||
return "ok"
|
||||
|
||||
import time
|
||||
start = time.monotonic()
|
||||
result = await rate_limited_then_ok()
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert result == "ok"
|
||||
assert elapsed >= 0.25, f"Expected >=0.25s delay from retry_after=0.3, got {elapsed:.3f}s"
|
||||
@@ -0,0 +1,41 @@
|
||||
import msgspec
|
||||
import pytest
|
||||
|
||||
from infrastructure.serialization import clone_with_updates, to_jsonable
|
||||
|
||||
|
||||
class SampleStruct(msgspec.Struct):
|
||||
value: int
|
||||
name: str = "x"
|
||||
|
||||
|
||||
def test_to_jsonable_struct_and_dict():
|
||||
struct_value = SampleStruct(value=3, name="abc")
|
||||
|
||||
assert to_jsonable(struct_value) == {"value": 3, "name": "abc"}
|
||||
assert to_jsonable({"x": 1}) == {"x": 1}
|
||||
|
||||
|
||||
def test_clone_with_updates_struct():
|
||||
original = SampleStruct(value=1, name="before")
|
||||
|
||||
updated = clone_with_updates(original, {"name": "after"})
|
||||
|
||||
assert isinstance(updated, SampleStruct)
|
||||
assert updated.value == 1
|
||||
assert updated.name == "after"
|
||||
assert original.name == "before"
|
||||
|
||||
|
||||
def test_clone_with_updates_dict():
|
||||
original = {"value": 1, "name": "before"}
|
||||
|
||||
updated = clone_with_updates(original, {"name": "after", "extra": True})
|
||||
|
||||
assert updated == {"value": 1, "name": "after", "extra": True}
|
||||
assert original == {"value": 1, "name": "before"}
|
||||
|
||||
|
||||
def test_clone_with_updates_unsupported_type_raises():
|
||||
with pytest.raises(TypeError):
|
||||
clone_with_updates([1, 2, 3], {"value": 9})
|
||||
Reference in New Issue
Block a user