Initial public release

This commit is contained in:
Harvey
2026-04-03 15:53:00 +01:00
commit a99c738164
679 changed files with 108326 additions and 0 deletions
@@ -0,0 +1,64 @@
"""Tests for validate_audiodb_image_url SSRF protection."""
import pytest
from infrastructure.validators import validate_audiodb_image_url
class TestValidateAudiodbImageUrl:
@pytest.mark.parametrize("url", [
"https://www.theaudiodb.com/images/media/thumb.jpg",
"https://theaudiodb.com/images/media/artist/thumb/coldplay.jpg",
"https://r2.theaudiodb.com/images/album/thumb/parachutes.jpg",
"https://r2.theaudiodb.com/images/artist/fanart/coldplay1.jpg",
])
def test_valid_audiodb_urls(self, url: str) -> None:
assert validate_audiodb_image_url(url) is True
@pytest.mark.parametrize("url", [
"http://www.theaudiodb.com/images/media/thumb.jpg",
"http://r2.theaudiodb.com/images/album/thumb.jpg",
])
def test_rejects_http_scheme(self, url: str) -> None:
assert validate_audiodb_image_url(url) is False
@pytest.mark.parametrize("url", [
"ftp://r2.theaudiodb.com/file.jpg",
"file:///etc/passwd",
"data:text/html,<script>alert(1)</script>",
"javascript:alert(1)",
])
def test_rejects_non_https_schemes(self, url: str) -> None:
assert validate_audiodb_image_url(url) is False
@pytest.mark.parametrize("url", [
"https://evil.com/images/media/thumb.jpg",
"https://theaudiodb.com.evil.com/exploit.jpg",
"https://attacker.theaudiodb.com/images/thumb.jpg",
"https://notaudiodb.com/images/thumb.jpg",
"https://example.com/redirect?url=https://r2.theaudiodb.com/img.jpg",
])
def test_rejects_unknown_hosts(self, url: str) -> None:
assert validate_audiodb_image_url(url) is False
@pytest.mark.parametrize("url", [
"https://127.0.0.1/images/thumb.jpg",
"https://10.0.0.1/images/thumb.jpg",
"https://192.168.1.1/images/thumb.jpg",
"https://[::1]/images/thumb.jpg",
"https://169.254.169.254/latest/meta-data/",
])
def test_rejects_private_and_loopback_ips(self, url: str) -> None:
assert validate_audiodb_image_url(url) is False
@pytest.mark.parametrize("url", [
"",
None,
" ",
"not-a-url",
"://missing-scheme.com",
])
def test_rejects_invalid_inputs(self, url) -> None:
assert validate_audiodb_image_url(url) is False
def test_rejects_url_without_host(self) -> None:
assert validate_audiodb_image_url("https:///path/only") is False
@@ -0,0 +1,303 @@
import json
import sqlite3
import threading
import time
import pytest
from infrastructure.cache.disk_cache import DiskMetadataCache
from infrastructure.persistence.genre_index import GenreIndex
from infrastructure.persistence.library_db import LibraryDB
from infrastructure.persistence.youtube_store import YouTubeStore
def _make_stores(db_path):
lock = threading.Lock()
lib = LibraryDB(db_path=db_path, write_lock=lock)
genre = GenreIndex(db_path=db_path, write_lock=lock)
yt = YouTubeStore(db_path=db_path, write_lock=lock)
# All stores must be initialized so cross-domain DELETEs in save_library/clear succeed
from infrastructure.persistence.mbid_store import MBIDStore
from infrastructure.persistence.sync_state_store import SyncStateStore
SyncStateStore(db_path=db_path, write_lock=lock)
MBIDStore(db_path=db_path, write_lock=lock)
return lib, genre, yt
@pytest.mark.asyncio
async def test_library_cache_genre_queries_use_normalized_lookup(tmp_path):
lib, genre, _ = _make_stores(tmp_path / "library.db")
await lib.save_library(
artists=[
{"mbid": "artist-1", "name": "Artist One", "album_count": 1, "date_added": 10},
{"mbid": "artist-2", "name": "Artist Two", "album_count": 1, "date_added": 20},
],
albums=[
{
"mbid": "album-1",
"artist_mbid": "artist-1",
"artist_name": "Artist One",
"title": "First Album",
"date_added": 100,
"monitored": True,
},
{
"mbid": "album-2",
"artist_mbid": "artist-2",
"artist_name": "Artist Two",
"title": "Second Album",
"date_added": 200,
"monitored": True,
},
],
)
await genre.save_artist_genres(
{
"artist-1": [" Rock ", "Alternative", "rock"],
"artist-2": ["Jazz", "rock"],
}
)
artists = await genre.get_artists_by_genre("ROCK", limit=1)
albums = await genre.get_albums_by_genre(" rock ", limit=2)
assert [artist["mbid"] for artist in artists] == ["artist-2"]
assert [album["mbid"] for album in albums] == ["album-2", "album-1"]
@pytest.mark.asyncio
async def test_library_cache_backfills_genre_lookup_from_existing_json_rows(tmp_path):
db_path = tmp_path / "library.db"
lib, genre, _ = _make_stores(db_path)
await lib.save_library(
artists=[{"mbid": "artist-1", "name": "Artist One", "album_count": 1, "date_added": 10}],
albums=[],
)
conn = sqlite3.connect(db_path)
try:
conn.execute("DELETE FROM artist_genre_lookup")
conn.execute(
"INSERT OR REPLACE INTO artist_genres (artist_mbid_lower, artist_mbid, genres_json) VALUES (?, ?, ?)",
("artist-1", "artist-1", json.dumps(["post-rock"])),
)
conn.commit()
finally:
conn.close()
_, genre2, _ = _make_stores(db_path)
artists = await genre2.get_artists_by_genre("POST-ROCK")
assert [artist["mbid"] for artist in artists] == ["artist-1"]
@pytest.mark.asyncio
async def test_cleanup_expired_covers_removes_expired_cover_payload(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
cover_dir = tmp_path / "recent" / "covers"
cover_file = cover_dir / "cover.bin"
meta_file = cover_dir / "cover.meta.json"
cover_file.write_bytes(b"image-bytes")
meta_file.write_text(json.dumps({"expires_at": time.time() - 60, "last_accessed": 1}))
removed = await cache.cleanup_expired_covers()
assert removed == 1
assert not cover_file.exists()
assert not meta_file.exists()
@pytest.mark.asyncio
async def test_enforce_cover_size_limits_evicts_oldest_recent_cover(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
cache.recent_covers_max_size_bytes = 6
cover_dir = tmp_path / "recent" / "covers"
old_cover = cover_dir / "old.bin"
new_cover = cover_dir / "new.bin"
old_meta = cover_dir / "old.meta.json"
new_meta = cover_dir / "new.meta.json"
old_cover.write_bytes(b"1234")
new_cover.write_bytes(b"5678")
old_meta.write_text(json.dumps({"last_accessed": 1}))
new_meta.write_text(json.dumps({"last_accessed": 2}))
freed = await cache.enforce_cover_size_limits()
assert freed == 4
assert not old_cover.exists()
assert not old_meta.exists()
assert new_cover.exists()
assert new_meta.exists()
@pytest.mark.asyncio
async def test_library_cache_keeps_youtube_track_links_distinct_per_disc(tmp_path):
_, _, yt = _make_stores(tmp_path / "library.db")
await yt.save_youtube_track_links_batch(
"album-1",
[
{
"track_number": 1,
"disc_number": 1,
"album_name": "Album",
"track_name": "Disc One Track One",
"video_id": "video-1",
"artist_name": "Artist",
"embed_url": "https://example.com/1",
"created_at": "2024-01-01T00:00:00Z",
},
{
"track_number": 1,
"disc_number": 2,
"album_name": "Album",
"track_name": "Disc Two Track One",
"video_id": "video-2",
"artist_name": "Artist",
"embed_url": "https://example.com/2",
"created_at": "2024-01-01T00:00:00Z",
},
],
)
links = await yt.get_youtube_track_links("album-1")
assert [(link["disc_number"], link["track_number"], link["video_id"]) for link in links] == [
(1, 1, "video-1"),
(2, 1, "video-2"),
]
await yt.delete_youtube_track_link("album-1", 2, 1)
remaining = await yt.get_youtube_track_links("album-1")
assert [(link["disc_number"], link["track_number"], link["video_id"]) for link in remaining] == [
(1, 1, "video-1")
]
@pytest.mark.asyncio
async def test_library_cache_migrates_legacy_youtube_track_links_with_default_disc_number(tmp_path):
db_path = tmp_path / "library.db"
conn = sqlite3.connect(db_path)
try:
conn.execute(
"""
CREATE TABLE youtube_track_links (
album_id TEXT NOT NULL,
track_number INTEGER NOT NULL,
album_name TEXT NOT NULL,
track_name TEXT NOT NULL,
video_id TEXT NOT NULL,
artist_name TEXT NOT NULL,
embed_url TEXT NOT NULL,
created_at TEXT NOT NULL,
PRIMARY KEY (album_id, track_number)
)
"""
)
conn.execute(
"""
INSERT INTO youtube_track_links (
album_id, track_number, album_name, track_name,
video_id, artist_name, embed_url, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
"album-legacy",
5,
"Legacy Album",
"Legacy Track",
"legacy-video",
"Artist",
"https://example.com/legacy",
"2024-01-01T00:00:00Z",
),
)
conn.commit()
finally:
conn.close()
_, _, yt = _make_stores(db_path)
links = await yt.get_youtube_track_links("album-legacy")
assert len(links) == 1
assert links[0]["disc_number"] == 1
assert links[0]["track_number"] == 5
@pytest.mark.asyncio
async def test_library_cache_youtube_track_links_uniqueness_allows_same_track_different_disc(
tmp_path,
):
"""Verify the new (album_id, disc_number, track_number) PK allows same track number across discs."""
_, _, yt = _make_stores(tmp_path / "library.db")
await yt.save_youtube_track_link(
album_id="album-uniq",
track_number=1,
disc_number=1,
album_name="Test Album",
track_name="Track One Disc One",
video_id="vid-d1t1",
artist_name="Artist",
embed_url="https://example.com/d1t1",
created_at="2024-01-01T00:00:00Z",
)
await yt.save_youtube_track_link(
album_id="album-uniq",
track_number=1,
disc_number=2,
album_name="Test Album",
track_name="Track One Disc Two",
video_id="vid-d2t1",
artist_name="Artist",
embed_url="https://example.com/d2t1",
created_at="2024-01-01T00:00:00Z",
)
await yt.save_youtube_track_link(
album_id="album-uniq",
track_number=1,
disc_number=1,
album_name="Test Album",
track_name="Updated Track One",
video_id="vid-d1t1-updated",
artist_name="Artist",
embed_url="https://example.com/d1t1-v2",
created_at="2024-01-02T00:00:00Z",
)
links = await yt.get_youtube_track_links("album-uniq")
assert len(links) == 2
d1 = next(link for link in links if link["disc_number"] == 1)
d2 = next(link for link in links if link["disc_number"] == 2)
assert d1["video_id"] == "vid-d1t1-updated"
assert d1["track_name"] == "Updated Track One"
assert d2["video_id"] == "vid-d2t1"
@pytest.mark.asyncio
async def test_library_cache_save_single_youtube_track_link_with_disc_number(tmp_path):
"""Verify save_youtube_track_link (single-row path) correctly stores disc_number."""
_, _, yt = _make_stores(tmp_path / "library.db")
await yt.save_youtube_track_link(
album_id="album-single",
track_number=3,
disc_number=2,
album_name="Single Test",
track_name="Track Three Disc Two",
video_id="vid-single",
artist_name="Artist",
embed_url="https://example.com/single",
created_at="2024-06-15T00:00:00Z",
)
links = await yt.get_youtube_track_links("album-single")
assert len(links) == 1
assert links[0]["disc_number"] == 2
assert links[0]["track_number"] == 3
assert links[0]["video_id"] == "vid-single"
@@ -0,0 +1,118 @@
import asyncio
import time
import pytest
from infrastructure.resilience.retry import CircuitBreaker, CircuitState
@pytest.mark.asyncio
async def test_concurrent_arecord_failure_does_not_overcount():
cb = CircuitBreaker(failure_threshold=5, name="test-overcount")
async def fail_once():
await cb.arecord_failure()
await asyncio.gather(*[fail_once() for _ in range(10)])
assert cb.failure_count <= 10
assert cb.state == CircuitState.OPEN
@pytest.mark.asyncio
async def test_concurrent_arecord_success_transitions_half_open_to_closed():
cb = CircuitBreaker(failure_threshold=3, success_threshold=2, name="test-success-transition")
for _ in range(3):
cb.record_failure()
assert cb.state == CircuitState.OPEN
cb.state = CircuitState.HALF_OPEN
cb.success_count = 0
async def succeed_once():
await cb.arecord_success()
await asyncio.gather(*[succeed_once() for _ in range(10)])
assert cb.state == CircuitState.CLOSED
assert cb.failure_count == 0
assert cb.success_count == 0
@pytest.mark.asyncio
async def test_atry_transition_open_to_half_open():
cb = CircuitBreaker(failure_threshold=3, timeout=0.0, name="test-transition")
for _ in range(3):
cb.record_failure()
assert cb.state == CircuitState.OPEN
await asyncio.sleep(0.01)
await cb.atry_transition()
assert cb.state == CircuitState.HALF_OPEN
assert cb.success_count == 0
@pytest.mark.asyncio
async def test_atry_transition_noop_when_not_open():
cb = CircuitBreaker(name="test-noop")
assert cb.state == CircuitState.CLOSED
await cb.atry_transition()
assert cb.state == CircuitState.CLOSED
@pytest.mark.asyncio
async def test_atry_transition_noop_when_timeout_not_elapsed():
cb = CircuitBreaker(failure_threshold=3, timeout=60.0, name="test-timeout-not-elapsed")
for _ in range(3):
cb.record_failure()
assert cb.state == CircuitState.OPEN
await cb.atry_transition()
assert cb.state == CircuitState.OPEN
@pytest.mark.asyncio
async def test_sync_methods_still_work_without_await():
cb = CircuitBreaker(failure_threshold=3, success_threshold=1, name="test-sync-compat")
cb.record_failure()
cb.record_failure()
assert cb.state == CircuitState.CLOSED
assert cb.failure_count == 2
cb.record_failure()
assert cb.state == CircuitState.OPEN
cb.reset()
assert cb.state == CircuitState.CLOSED
assert cb.failure_count == 0
cb.record_success()
assert cb.failure_count == 0
@pytest.mark.asyncio
async def test_concurrent_atry_transition_only_transitions_once():
cb = CircuitBreaker(failure_threshold=3, timeout=0.0, name="test-double-transition")
for _ in range(3):
cb.record_failure()
assert cb.state == CircuitState.OPEN
await asyncio.sleep(0.01)
transition_states = []
async def try_transition():
await cb.atry_transition()
transition_states.append(cb.state)
await asyncio.gather(*[try_transition() for _ in range(5)])
assert cb.state == CircuitState.HALF_OPEN
assert all(s == CircuitState.HALF_OPEN for s in transition_states)
@@ -0,0 +1,126 @@
"""Tests for DegradationContext and contextvar lifecycle."""
import asyncio
import pytest
from infrastructure.degradation import (
DegradationContext,
clear_degradation_context,
get_degradation_context,
init_degradation_context,
try_get_degradation_context,
)
from infrastructure.integration_result import IntegrationResult
class TestDegradationContext:
def test_empty_context(self):
ctx = DegradationContext()
assert ctx.summary() == {}
assert ctx.has_degradation() is False
assert ctx.degraded_summary() == {}
def test_record_ok(self):
ctx = DegradationContext()
ctx.record(IntegrationResult.ok(data=[1], source="musicbrainz"))
assert ctx.summary() == {"musicbrainz": "ok"}
assert ctx.has_degradation() is False
def test_record_error(self):
ctx = DegradationContext()
ctx.record(IntegrationResult.error(source="jellyfin", msg="timeout"))
assert ctx.summary() == {"jellyfin": "error"}
assert ctx.has_degradation() is True
assert ctx.degraded_summary() == {"jellyfin": "error"}
def test_record_degraded(self):
ctx = DegradationContext()
ctx.record(
IntegrationResult.degraded(data=[], source="audiodb", msg="rate limit")
)
assert ctx.summary() == {"audiodb": "degraded"}
assert ctx.has_degradation() is True
def test_worst_status_wins(self):
ctx = DegradationContext()
ctx.record(IntegrationResult.ok(data=[], source="musicbrainz"))
ctx.record(IntegrationResult.degraded(data=[], source="musicbrainz", msg="slow"))
assert ctx.summary() == {"musicbrainz": "degraded"}
def test_error_beats_degraded(self):
ctx = DegradationContext()
ctx.record(
IntegrationResult.degraded(data=[], source="musicbrainz", msg="slow")
)
ctx.record(IntegrationResult.error(source="musicbrainz", msg="503"))
assert ctx.summary() == {"musicbrainz": "error"}
def test_cannot_downgrade(self):
ctx = DegradationContext()
ctx.record(IntegrationResult.error(source="jellyfin", msg="down"))
ctx.record(IntegrationResult.ok(data=[1], source="jellyfin"))
assert ctx.summary() == {"jellyfin": "error"}
def test_multiple_sources(self):
ctx = DegradationContext()
ctx.record(IntegrationResult.ok(data=[], source="musicbrainz"))
ctx.record(IntegrationResult.error(source="jellyfin", msg="down"))
ctx.record(
IntegrationResult.degraded(data={}, source="audiodb", msg="slow")
)
assert ctx.summary() == {
"musicbrainz": "ok",
"jellyfin": "error",
"audiodb": "degraded",
}
assert ctx.degraded_summary() == {
"jellyfin": "error",
"audiodb": "degraded",
}
class TestContextVarLifecycle:
def test_no_context_raises(self):
clear_degradation_context()
with pytest.raises(RuntimeError, match="outside a request scope"):
get_degradation_context()
def test_try_get_returns_none_outside(self):
clear_degradation_context()
assert try_get_degradation_context() is None
def test_init_and_get(self):
ctx = init_degradation_context()
assert get_degradation_context() is ctx
clear_degradation_context()
def test_clear_removes_context(self):
init_degradation_context()
clear_degradation_context()
assert try_get_degradation_context() is None
@pytest.mark.asyncio
async def test_isolated_across_tasks(self):
"""Context in one asyncio task must not leak into another."""
results: dict[str, bool] = {}
async def task_a():
init_degradation_context()
ctx = get_degradation_context()
ctx.record(IntegrationResult.error(source="a", msg="fail"))
await asyncio.sleep(0.01)
results["a_has_degradation"] = get_degradation_context().has_degradation()
clear_degradation_context()
async def task_b():
await asyncio.sleep(0.005)
results["b_is_none"] = try_get_degradation_context() is None
await asyncio.gather(task_a(), task_b())
assert results["a_has_degradation"] is True
assert results["b_is_none"] is True
@@ -0,0 +1,62 @@
import asyncio
import pytest
from unittest.mock import AsyncMock
from core.exceptions import ClientDisconnectedError
from infrastructure.http.disconnect import check_disconnected
from infrastructure.http.deduplication import RequestDeduplicator
@pytest.mark.anyio
async def test_check_disconnected_raises_when_disconnected():
is_disconnected = AsyncMock(return_value=True)
with pytest.raises(ClientDisconnectedError):
await check_disconnected(is_disconnected)
assert is_disconnected.await_count == 1
@pytest.mark.anyio
async def test_check_disconnected_noop_when_connected():
is_disconnected = AsyncMock(return_value=False)
await check_disconnected(is_disconnected)
assert is_disconnected.await_count == 1
@pytest.mark.anyio
async def test_check_disconnected_noop_when_none():
await check_disconnected(None)
@pytest.mark.anyio
async def test_dedup_leader_disconnect_follower_retries_as_leader():
dedup = RequestDeduplicator()
follower_registered = asyncio.Event()
leader_error = None
expected_result = ("image-bytes", "image/png", "source")
async def leader_coro():
await follower_registered.wait()
raise ClientDisconnectedError("leader disconnected")
async def run_leader():
nonlocal leader_error
try:
await dedup.dedupe("key1", leader_coro)
except ClientDisconnectedError as e:
leader_error = e
async def follower_coro():
return expected_result
async def run_follower():
await asyncio.sleep(0)
follower_registered.set()
return await dedup.dedupe("key1", follower_coro)
leader_task = asyncio.create_task(run_leader())
await asyncio.sleep(0)
follower_task = asyncio.create_task(run_follower())
await asyncio.gather(leader_task, follower_task)
assert isinstance(leader_error, ClientDisconnectedError)
assert follower_task.result() == expected_result
@@ -0,0 +1,76 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from core.tasks import cleanup_disk_cache_periodically
@pytest.mark.asyncio
async def test_periodic_cleanup_calls_both_caches():
disk_cache = AsyncMock()
cover_disk_cache = AsyncMock()
iteration_count = 0
original_cleanup = cleanup_disk_cache_periodically
async def run_one_iteration():
nonlocal iteration_count
task = asyncio.create_task(
original_cleanup(disk_cache, interval=0, cover_disk_cache=cover_disk_cache)
)
await asyncio.sleep(0.05)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
await run_one_iteration()
disk_cache.cleanup_expired_recent.assert_called()
disk_cache.enforce_recent_size_limits.assert_called()
disk_cache.cleanup_expired_covers.assert_called()
disk_cache.enforce_cover_size_limits.assert_called()
cover_disk_cache.enforce_size_limit.assert_called_with(force=True)
@pytest.mark.asyncio
async def test_periodic_cleanup_works_without_cover_cache():
disk_cache = AsyncMock()
task = asyncio.create_task(
cleanup_disk_cache_periodically(disk_cache, interval=0, cover_disk_cache=None)
)
await asyncio.sleep(0.05)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
disk_cache.cleanup_expired_recent.assert_called()
disk_cache.enforce_recent_size_limits.assert_called()
disk_cache.cleanup_expired_covers.assert_called()
disk_cache.enforce_cover_size_limits.assert_called()
@pytest.mark.asyncio
async def test_periodic_cleanup_continues_on_cover_cache_error():
disk_cache = AsyncMock()
cover_disk_cache = AsyncMock()
cover_disk_cache.enforce_size_limit.side_effect = [RuntimeError("disk full"), None]
task = asyncio.create_task(
cleanup_disk_cache_periodically(disk_cache, interval=0, cover_disk_cache=cover_disk_cache)
)
await asyncio.sleep(0.1)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert cover_disk_cache.enforce_size_limit.call_count >= 1
assert disk_cache.cleanup_expired_recent.call_count >= 1
@@ -0,0 +1,172 @@
import hashlib
import json
import pytest
from api.v1.schemas.album import AlbumInfo
from infrastructure.cache.disk_cache import DiskMetadataCache
from repositories.audiodb_models import AudioDBArtistImages, AudioDBAlbumImages
@pytest.mark.asyncio
async def test_set_album_serializes_msgspec_struct_as_mapping(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
mbid = "4549a80c-efe6-4386-b3a2-4b4a918eb31f"
album_info = AlbumInfo(
title="The Moon Song",
musicbrainz_id=mbid,
artist_name="beabadoobee",
artist_id="88d17133-abbc-42db-9526-4e2c1db60336",
in_library=True,
)
await cache.set_album(mbid, album_info, is_monitored=True)
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
cache_file = tmp_path / "persistent" / "albums" / f"{cache_hash}.json"
payload = json.loads(cache_file.read_text())
assert isinstance(payload, dict)
assert payload["musicbrainz_id"] == mbid
cached = await cache.get_album(mbid)
assert isinstance(cached, dict)
assert cached["title"] == "The Moon Song"
@pytest.mark.asyncio
async def test_get_album_deletes_corrupt_string_payload(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
mbid = "8e1e9e51-38dc-4df3-8027-a0ada37d4674"
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
cache_file = tmp_path / "persistent" / "albums" / f"{cache_hash}.json"
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps("AlbumInfo(title='Corrupt')"))
cached = await cache.get_album(mbid)
assert cached is None
assert not cache_file.exists()
@pytest.mark.asyncio
async def test_audiodb_artist_entity_routing(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
mbid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
images = AudioDBArtistImages(
thumb_url="https://example.com/thumb.jpg",
fanart_url="https://example.com/fanart.jpg",
lookup_source="mbid",
matched_mbid=mbid,
)
await cache._set_entity("audiodb_artist", mbid, images, is_monitored=False, ttl_seconds=None)
result = await cache._get_entity("audiodb_artist", mbid)
assert result is not None
assert result["thumb_url"] == "https://example.com/thumb.jpg"
assert result["fanart_url"] == "https://example.com/fanart.jpg"
assert result["lookup_source"] == "mbid"
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
data_file = tmp_path / "recent" / "audiodb_artists" / f"{cache_hash}.json"
assert data_file.exists()
@pytest.mark.asyncio
async def test_audiodb_album_entity_routing(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
mbid = "b2c3d4e5-f6a7-8901-bcde-f12345678901"
images = AudioDBAlbumImages(
album_thumb_url="https://example.com/album_thumb.jpg",
album_back_url="https://example.com/album_back.jpg",
lookup_source="name",
matched_mbid=mbid,
)
await cache._set_entity("audiodb_album", mbid, images, is_monitored=True, ttl_seconds=None)
result = await cache._get_entity("audiodb_album", mbid)
assert result is not None
assert result["album_thumb_url"] == "https://example.com/album_thumb.jpg"
assert result["album_back_url"] == "https://example.com/album_back.jpg"
assert result["lookup_source"] == "name"
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
persistent_file = tmp_path / "persistent" / "audiodb_albums" / f"{cache_hash}.json"
assert persistent_file.exists()
@pytest.mark.asyncio
async def test_get_stats_counts_audiodb_entries(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
artist_images = AudioDBArtistImages(thumb_url="https://example.com/a.jpg")
album_images = AudioDBAlbumImages(album_thumb_url="https://example.com/b.jpg")
await cache._set_entity("audiodb_artist", "artist-1", artist_images, is_monitored=False, ttl_seconds=None)
await cache._set_entity("audiodb_artist", "artist-2", artist_images, is_monitored=True, ttl_seconds=None)
await cache._set_entity("audiodb_album", "album-1", album_images, is_monitored=False, ttl_seconds=None)
stats = cache.get_stats()
assert stats["audiodb_artist_count"] == 2
assert stats["audiodb_album_count"] == 1
assert stats["album_count"] == 0
assert stats["artist_count"] == 0
assert stats["total_count"] == 3
@pytest.mark.asyncio
async def test_clear_audiodb_isolates_from_other_entities(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
album_mbid = "c3d4e5f6-a7b8-9012-cdef-123456789012"
album_info = AlbumInfo(
title="Regular Album",
musicbrainz_id=album_mbid,
artist_name="Test Artist",
artist_id="d4e5f6a7-b8c9-0123-defa-234567890123",
in_library=False,
)
await cache.set_album(album_mbid, album_info, is_monitored=False)
artist_images = AudioDBArtistImages(thumb_url="https://example.com/thumb.jpg")
album_images = AudioDBAlbumImages(album_thumb_url="https://example.com/album.jpg")
await cache._set_entity("audiodb_artist", "adb-artist-1", artist_images, is_monitored=False, ttl_seconds=None)
await cache._set_entity("audiodb_album", "adb-album-1", album_images, is_monitored=True, ttl_seconds=None)
stats_before = cache.get_stats()
assert stats_before["audiodb_artist_count"] == 1
assert stats_before["audiodb_album_count"] == 1
assert stats_before["album_count"] == 1
await cache.clear_audiodb()
stats_after = cache.get_stats()
assert stats_after["audiodb_artist_count"] == 0
assert stats_after["audiodb_album_count"] == 0
assert stats_after["album_count"] == 1
regular_album = await cache.get_album(album_mbid)
assert regular_album is not None
assert regular_album["title"] == "Regular Album"
@pytest.mark.asyncio
async def test_audiodb_monitored_persistent_vs_recent(tmp_path):
cache = DiskMetadataCache(base_path=tmp_path)
mbid = "e5f6a7b8-c9d0-1234-efab-567890123456"
images = AudioDBArtistImages(thumb_url="https://example.com/t.jpg")
await cache._set_entity("audiodb_artist", mbid, images, is_monitored=True, ttl_seconds=None)
cache_hash = hashlib.sha1(mbid.encode()).hexdigest()
persistent_file = tmp_path / "persistent" / "audiodb_artists" / f"{cache_hash}.json"
recent_file = tmp_path / "recent" / "audiodb_artists" / f"{cache_hash}.json"
assert persistent_file.exists()
assert not recent_file.exists()
await cache._set_entity("audiodb_artist", mbid, images, is_monitored=False, ttl_seconds=None)
assert not persistent_file.exists()
assert recent_file.exists()
@@ -0,0 +1,114 @@
"""Tests for IntegrationResult and aggregate_status."""
import pytest
from infrastructure.integration_result import (
IntegrationResult,
aggregate_status,
)
class TestIntegrationResultOk:
def test_ok_carries_data(self):
r = IntegrationResult.ok(data=["a", "b"], source="musicbrainz")
assert r.data == ["a", "b"]
assert r.source == "musicbrainz"
assert r.status == "ok"
assert r.error_message is None
def test_is_ok_true(self):
r = IntegrationResult.ok(data=42, source="jellyfin")
assert r.is_ok is True
assert r.is_degraded is False
assert r.is_error is False
class TestIntegrationResultDegraded:
def test_degraded_carries_partial_data(self):
r = IntegrationResult.degraded(
data={"stale": True}, source="audiodb", msg="rate limited"
)
assert r.data == {"stale": True}
assert r.source == "audiodb"
assert r.status == "degraded"
assert r.error_message == "rate limited"
def test_is_degraded_true(self):
r = IntegrationResult.degraded(data=[], source="lastfm", msg="timeout")
assert r.is_degraded is True
assert r.is_ok is False
assert r.is_error is False
class TestIntegrationResultError:
def test_error_has_no_data(self):
r = IntegrationResult.error(source="musicbrainz", msg="503 Service Unavailable")
assert r.data is None
assert r.source == "musicbrainz"
assert r.status == "error"
assert r.error_message == "503 Service Unavailable"
def test_is_error_true(self):
r = IntegrationResult.error(source="wikidata", msg="boom")
assert r.is_error is True
assert r.is_ok is False
assert r.is_degraded is False
class TestDataOr:
def test_returns_data_when_present(self):
r = IntegrationResult.ok(data=[1, 2, 3], source="mb")
assert r.data_or([]) == [1, 2, 3]
def test_returns_default_when_none(self):
r = IntegrationResult.error(source="mb", msg="down")
assert r.data_or([]) == []
def test_returns_data_for_degraded(self):
r = IntegrationResult.degraded(data={"partial": True}, source="mb", msg="slow")
assert r.data_or({}) == {"partial": True}
class TestImmutability:
def test_frozen(self):
r = IntegrationResult.ok(data="hello", source="test")
with pytest.raises(AttributeError):
r.data = "goodbye" # type: ignore[misc]
class TestAggregateStatus:
def test_all_ok(self):
assert aggregate_status(
IntegrationResult.ok(1, "a"),
IntegrationResult.ok(2, "b"),
) == "ok"
def test_one_degraded(self):
assert aggregate_status(
IntegrationResult.ok(1, "a"),
IntegrationResult.degraded(2, "b", "slow"),
) == "degraded"
def test_one_error(self):
assert aggregate_status(
IntegrationResult.ok(1, "a"),
IntegrationResult.degraded(2, "b", "slow"),
IntegrationResult.error("c", "down"),
) == "error"
def test_empty(self):
assert aggregate_status() == "ok"
def test_error_short_circuits(self):
assert aggregate_status(
IntegrationResult.error("a", "x"),
IntegrationResult.ok(1, "b"),
) == "error"
@@ -0,0 +1,304 @@
"""Tests for LibraryDB paginated query methods."""
import asyncio
import threading
from pathlib import Path
import pytest
from infrastructure.persistence.library_db import LibraryDB
@pytest.fixture
def db(tmp_path: Path) -> LibraryDB:
return LibraryDB(db_path=tmp_path / "test.db", write_lock=threading.Lock())
def _make_albums(count: int, *, start: int = 1) -> list[dict]:
"""Generate album dicts with predictable, sortable data."""
albums = []
for i in range(start, start + count):
albums.append(
{
"mbid": f"album-{i:04d}",
"artist_mbid": f"artist-{(i % 5) + 1:04d}",
"artist_name": f"Artist {chr(65 + (i % 26))}",
"title": f"Album {chr(65 + ((i + 13) % 26))} {i:04d}",
"year": 2000 + (i % 24),
"cover_url": None,
"monitored": True,
"date_added": 1700000000 + i * 100,
}
)
return albums
def _make_artists(count: int) -> list[dict]:
"""Generate artist dicts with predictable data."""
artists = []
for i in range(1, count + 1):
artists.append(
{
"mbid": f"artist-{i:04d}",
"name": f"Artist {chr(65 + (i % 26))}",
"album_count": i,
"date_added": 1700000000 + i * 100,
}
)
return artists
async def _seed(db: LibraryDB, n_albums: int = 100, n_artists: int = 20) -> None:
await db.save_library(_make_artists(n_artists), _make_albums(n_albums))
# --- Album pagination ---
def test_albums_basic_pagination(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=10, offset=0)
)
assert total == 100
assert len(items) == 10
def test_albums_offset_beyond_total(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=10, offset=200)
)
assert total == 100
assert len(items) == 0
def test_albums_last_partial_page(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=30, offset=90)
)
assert total == 100
assert len(items) == 10
def test_albums_sort_by_title_asc(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, _ = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, sort_by="title", sort_order="asc")
)
titles = [i.get("title", "") for i in items]
assert titles == sorted(titles, key=str.casefold)
def test_albums_sort_by_title_desc(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, _ = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, sort_by="title", sort_order="desc")
)
titles = [i.get("title", "") for i in items]
assert titles == sorted(titles, key=str.casefold, reverse=True)
def test_albums_sort_by_year(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, _ = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, sort_by="year", sort_order="desc")
)
years = [i.get("year", 0) or 0 for i in items]
assert years == sorted(years, reverse=True)
def test_albums_sort_by_date_added(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, _ = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, sort_by="date_added", sort_order="desc")
)
dates = [i.get("date_added", 0) or 0 for i in items]
assert dates == sorted(dates, reverse=True)
def test_albums_search_by_title(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, search="Album A")
)
assert total > 0
assert all("Album A" in i.get("title", "") for i in items)
def test_albums_search_by_artist(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, search="Artist A")
)
assert total > 0
assert all(
"Artist A" in i.get("artist_name", "") or "Artist A" in i.get("title", "")
for i in items
)
def test_albums_search_no_results(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=10, offset=0, search="zzz_no_match_zzz")
)
assert total == 0
assert len(items) == 0
def test_albums_search_case_insensitive(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items_upper, total_upper = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, search="ALBUM A")
)
items_lower, total_lower = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, search="album a")
)
assert total_upper == total_lower
assert len(items_upper) == len(items_lower)
def test_albums_search_escapes_like_metacharacters(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items_pct, total_pct = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, search="100%")
)
assert total_pct == 0
assert len(items_pct) == 0
items_under, total_under = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=100, offset=0, search="Album_A")
)
assert total_under == 0
def test_artists_search_escapes_like_metacharacters(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(limit=100, offset=0, search="Artist%B")
)
assert total == 0
def test_albums_invalid_sort_falls_back(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=10, offset=0, sort_by="nonexistent")
)
assert total == 100
assert len(items) == 10
def test_albums_empty_library(db: LibraryDB):
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(limit=10, offset=0)
)
assert total == 0
assert len(items) == 0
# --- Artist pagination ---
def test_artists_basic_pagination(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(limit=5, offset=0)
)
assert total == 20
assert len(items) == 5
def test_artists_offset_beyond_total(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(limit=10, offset=50)
)
assert total == 20
assert len(items) == 0
def test_artists_sort_by_name_asc(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, _ = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(limit=20, offset=0, sort_by="name", sort_order="asc")
)
names = [i.get("name", "") for i in items]
assert names == sorted(names, key=str.casefold)
def test_artists_sort_by_album_count_desc(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, _ = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(limit=20, offset=0, sort_by="album_count", sort_order="desc")
)
counts = [i.get("album_count", 0) for i in items]
assert counts == sorted(counts, reverse=True)
def test_artists_search(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(limit=20, offset=0, search="Artist B")
)
assert total > 0
assert all("Artist B" in i.get("name", "") for i in items)
def test_artists_search_no_results(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db))
items, total = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(limit=10, offset=0, search="zzz_no_match_zzz")
)
assert total == 0
assert len(items) == 0
def test_artists_empty_library(db: LibraryDB):
items, total = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(limit=10, offset=0)
)
assert total == 0
assert len(items) == 0
# --- Pagination consistency (no duplicates/missing across pages) ---
def test_albums_pagination_no_duplicates(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db, n_albums=50))
all_mbids: list[str] = []
offset = 0
page_size = 10
while True:
items, total = asyncio.get_event_loop().run_until_complete(
db.get_albums_paginated(
limit=page_size, offset=offset, sort_by="title", sort_order="asc"
)
)
if not items:
break
all_mbids.extend(i.get("mbid", "") for i in items)
offset += page_size
assert len(all_mbids) == 50
assert len(set(all_mbids)) == 50
def test_artists_pagination_no_duplicates(db: LibraryDB):
asyncio.get_event_loop().run_until_complete(_seed(db, n_albums=10, n_artists=30))
all_mbids: list[str] = []
offset = 0
page_size = 7
while True:
items, total = asyncio.get_event_loop().run_until_complete(
db.get_artists_paginated(
limit=page_size, offset=offset, sort_by="name", sort_order="asc"
)
)
if not items:
break
all_mbids.extend(i.get("mbid", "") for i in items)
offset += page_size
assert len(all_mbids) == 30
assert len(set(all_mbids)) == 30
@@ -0,0 +1,117 @@
import json
import pytest
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient
from infrastructure.msgspec_fastapi import (
AppStruct,
MsgSpecBody,
MsgSpecJSONRequest,
MsgSpecJSONResponse,
MsgSpecRoute,
_contains_msgspec_struct,
_merge_response_schema,
)
class SamplePayload(AppStruct):
value: int
@pytest.mark.asyncio
async def test_msgspec_json_request_caches_decoded_body():
body = b'{"value": 7}'
async def receive():
return {"type": "http.request", "body": body, "more_body": False}
request = MsgSpecJSONRequest(
{
"type": "http",
"http_version": "1.1",
"method": "POST",
"path": "/",
"raw_path": b"/",
"scheme": "http",
"headers": [],
"query_string": b"",
"client": ("testclient", 123),
"server": ("testserver", 80),
},
receive,
)
first = await request.json()
second = await request.json()
assert first == {"value": 7}
assert second == first
@pytest.mark.asyncio
async def test_msgspec_json_request_raises_json_decode_error():
async def receive():
return {"type": "http.request", "body": b"{", "more_body": False}
request = MsgSpecJSONRequest(
{
"type": "http",
"http_version": "1.1",
"method": "POST",
"path": "/",
"raw_path": b"/",
"scheme": "http",
"headers": [],
"query_string": b"",
"client": ("testclient", 123),
"server": ("testserver", 80),
},
receive,
)
with pytest.raises(json.JSONDecodeError):
await request.json()
def test_app_struct_iteration_and_json_response_render():
payload = SamplePayload(value=5)
assert dict(payload) == {"value": 5}
response = MsgSpecJSONResponse(content=payload)
assert response.body == b'{"value":5}'
def test_msgspec_body_and_route_work_with_fastapi():
app = FastAPI()
router = APIRouter(route_class=MsgSpecRoute)
@router.post("/items", response_model=SamplePayload)
async def create_item(body: SamplePayload = MsgSpecBody(SamplePayload)):
return body
app.include_router(router)
client = TestClient(app)
ok = client.post("/items", json={"value": 11})
assert ok.status_code == 200
assert ok.json() == {"value": 11}
bad = client.post("/items", json={"value": "nope"})
assert bad.status_code == 422
def test_contains_msgspec_struct_and_merge_response_schema():
assert _contains_msgspec_struct(SamplePayload) is True
assert _contains_msgspec_struct(list[SamplePayload]) is True
assert _contains_msgspec_struct(str | None) is False
merged = _merge_response_schema(
{"responses": {"200": {"description": "ok"}}},
{"type": "object", "properties": {"value": {"type": "integer"}}},
)
schema = merged["responses"]["200"]["content"]["application/json"]["schema"]
assert schema["type"] == "object"
assert "value" in schema["properties"]
@@ -0,0 +1,110 @@
import asyncio
import pytest
from pathlib import Path
from infrastructure.queue.queue_store import QueueStore
from infrastructure.queue.request_queue import RequestQueue
@pytest.fixture
def store(tmp_path: Path) -> QueueStore:
return QueueStore(db_path=tmp_path / "test_queue.db")
@pytest.mark.asyncio
async def test_jobs_survive_restart(store: QueueStore):
processed = []
async def slow_processor(mbid: str) -> dict:
await asyncio.sleep(100)
processed.append(mbid)
return {"status": "ok"}
q1 = RequestQueue(processor=slow_processor, store=store)
await q1.start()
store.enqueue("job-1", "mbid-abc")
store.mark_processing("job-1")
q1._processor_task.cancel()
try:
await q1._processor_task
except asyncio.CancelledError:
pass
fast_processed = []
async def fast_processor(mbid: str) -> dict:
fast_processed.append(mbid)
return {"status": "ok"}
q2 = RequestQueue(processor=fast_processor, store=store)
await q2.start()
await asyncio.sleep(0.5)
assert "mbid-abc" in fast_processed
await q2.stop()
@pytest.mark.asyncio
async def test_failed_job_lands_in_dead_letter(store: QueueStore):
async def failing_processor(mbid: str) -> dict:
raise ValueError("Lidarr is down")
q = RequestQueue(processor=failing_processor, store=store)
await q.start()
try:
await asyncio.wait_for(q.add("mbid-fail"), timeout=2.0)
except (ValueError, asyncio.TimeoutError):
pass
await asyncio.sleep(0.1)
assert store.get_dead_letter_count() >= 1
await q.stop()
@pytest.mark.asyncio
async def test_dead_letter_retry_on_restart(store: QueueStore):
store.add_dead_letter("dlj-1", "mbid-retry", "old error", retry_count=1, max_retries=3)
processed = []
async def processor(mbid: str) -> dict:
processed.append(mbid)
return {"status": "ok"}
q = RequestQueue(processor=processor, store=store)
await q.start()
await asyncio.sleep(0.5)
assert "mbid-retry" in processed
await q.stop()
@pytest.mark.asyncio
async def test_successful_job_removed_from_store(store: QueueStore):
async def ok_processor(mbid: str) -> dict:
return {"status": "ok"}
q = RequestQueue(processor=ok_processor, store=store)
await q.start()
await asyncio.wait_for(q.add("mbid-ok"), timeout=2.0)
assert len(store.get_all()) == 0
await q.stop()
@pytest.mark.asyncio
async def test_exhausted_dead_letter_not_retried(store: QueueStore):
store.add_dead_letter("dlj-ex", "mbid-exhausted", "fatal", retry_count=3, max_retries=3)
processed = []
async def processor(mbid: str) -> dict:
processed.append(mbid)
return {"status": "ok"}
q = RequestQueue(processor=processor, store=store)
await q.start()
await asyncio.sleep(0.3)
assert "mbid-exhausted" not in processed
await q.stop()
@@ -0,0 +1,88 @@
import pytest
from pathlib import Path
from infrastructure.queue.queue_store import QueueStore
@pytest.fixture
def store(tmp_path: Path) -> QueueStore:
return QueueStore(db_path=tmp_path / "test_queue.db")
def test_enqueue_and_get_pending(store: QueueStore):
store.enqueue("j1", "mbid-1")
store.enqueue("j2", "mbid-2")
store.enqueue("j3", "mbid-3")
assert len(store.get_pending()) == 3
def test_dequeue_removes_job(store: QueueStore):
store.enqueue("j1", "mbid-1")
store.dequeue("j1")
assert len(store.get_pending()) == 0
def test_duplicate_enqueue_ignored(store: QueueStore):
assert store.enqueue("j1", "mbid-1") is True
assert store.enqueue("j2", "mbid-1") is False
assert len(store.get_pending()) == 1
def test_mark_processing(store: QueueStore):
store.enqueue("j1", "mbid-1")
store.mark_processing("j1")
assert len(store.get_pending()) == 0
assert len(store.get_all()) == 1
def test_reset_processing(store: QueueStore):
store.enqueue("j1", "mbid-1")
store.mark_processing("j1")
store.reset_processing()
assert len(store.get_pending()) == 1
def test_add_dead_letter_retryable(store: QueueStore):
store.add_dead_letter("j1", "mbid-1", "error", retry_count=1, max_retries=3)
retryable = store.get_retryable_dead_letters()
assert len(retryable) == 1
assert retryable[0]["album_mbid"] == "mbid-1"
def test_add_dead_letter_exhausted(store: QueueStore):
store.add_dead_letter("j1", "mbid-1", "error", retry_count=3, max_retries=3)
assert len(store.get_retryable_dead_letters()) == 0
def test_remove_dead_letter(store: QueueStore):
store.add_dead_letter("j1", "mbid-1", "error", retry_count=1, max_retries=3)
store.remove_dead_letter("j1")
assert len(store.get_retryable_dead_letters()) == 0
def test_update_dead_letter_attempt(store: QueueStore):
store.add_dead_letter("j1", "mbid-1", "error1", retry_count=1, max_retries=3)
store.update_dead_letter_attempt("j1", "error2", retry_count=3)
assert len(store.get_retryable_dead_letters()) == 0
assert store.get_dead_letter_count() == 1
def test_get_dead_letter_count(store: QueueStore):
store.add_dead_letter("j1", "mbid-1", "e1", 1, 3)
store.add_dead_letter("j2", "mbid-2", "e2", 1, 3)
store.add_dead_letter("j3", "mbid-3", "e3", 1, 3)
assert store.get_dead_letter_count() == 3
def test_has_pending_mbid(store: QueueStore):
assert store.has_pending_mbid("mbid-1") is False
store.enqueue("j1", "mbid-1")
assert store.has_pending_mbid("mbid-1") is True
store.mark_processing("j1")
assert store.has_pending_mbid("mbid-1") is False
store.dequeue("j1")
assert store.has_pending_mbid("mbid-1") is False
def test_enqueue_returns_bool(store: QueueStore):
assert store.enqueue("j1", "mbid-1") is True
assert store.enqueue("j1", "mbid-1") is False
@@ -0,0 +1,199 @@
"""Tests that non_breaking_exceptions bypass circuit breaker failure recording."""
import asyncio
import pytest
from infrastructure.resilience.retry import (
CircuitBreaker,
CircuitOpenError,
CircuitState,
with_retry,
)
class _RateLimited(Exception):
def __init__(self, retry_after: float = 1.0):
super().__init__("rate limited")
self.retry_after_seconds = retry_after
class _ServiceDown(Exception):
pass
@pytest.mark.asyncio
async def test_non_breaking_exception_does_not_trip_circuit():
cb = CircuitBreaker(failure_threshold=3, name="test-non-breaking")
call_count = 0
@with_retry(
max_attempts=4,
base_delay=0.01,
max_delay=0.05,
circuit_breaker=cb,
retriable_exceptions=(_RateLimited, _ServiceDown),
non_breaking_exceptions=(_RateLimited,),
)
async def flaky():
nonlocal call_count
call_count += 1
if call_count < 4:
raise _RateLimited(retry_after=0.01)
return "ok"
result = await flaky()
assert result == "ok"
assert call_count == 4
assert cb.state == CircuitState.CLOSED
assert cb.failure_count == 0
@pytest.mark.asyncio
async def test_breaking_exception_still_trips_circuit():
cb = CircuitBreaker(failure_threshold=2, name="test-breaking")
call_count = 0
@with_retry(
max_attempts=3,
base_delay=0.01,
max_delay=0.05,
circuit_breaker=cb,
retriable_exceptions=(_RateLimited, _ServiceDown),
non_breaking_exceptions=(_RateLimited,),
)
async def fail():
nonlocal call_count
call_count += 1
raise _ServiceDown("down")
with pytest.raises(_ServiceDown):
await fail()
assert call_count == 3
assert cb.state == CircuitState.OPEN
@pytest.mark.asyncio
async def test_non_breaking_uses_retry_after_for_delay():
cb = CircuitBreaker(failure_threshold=5, name="test-retry-after")
call_count = 0
@with_retry(
max_attempts=2,
base_delay=100.0,
max_delay=100.0,
circuit_breaker=cb,
retriable_exceptions=(_RateLimited,),
non_breaking_exceptions=(_RateLimited,),
)
async def rate_limited_then_ok():
nonlocal call_count
call_count += 1
if call_count == 1:
raise _RateLimited(retry_after=0.01)
return "ok"
result = await rate_limited_then_ok()
assert result == "ok"
assert call_count == 2
assert cb.failure_count == 0
@pytest.mark.asyncio
async def test_circuit_still_opens_for_real_errors_amid_rate_limits():
cb = CircuitBreaker(failure_threshold=2, name="test-mixed")
@with_retry(
max_attempts=1,
base_delay=0.01,
max_delay=0.05,
circuit_breaker=cb,
retriable_exceptions=(_RateLimited, _ServiceDown),
non_breaking_exceptions=(_RateLimited,),
)
async def real_failure():
raise _ServiceDown("down")
for _ in range(2):
with pytest.raises(_ServiceDown):
await real_failure()
assert cb.state == CircuitState.OPEN
@with_retry(
max_attempts=1,
base_delay=0.01,
max_delay=0.05,
circuit_breaker=cb,
retriable_exceptions=(_RateLimited, _ServiceDown),
non_breaking_exceptions=(_RateLimited,),
)
async def subsequent_call():
return "should not reach"
with pytest.raises(CircuitOpenError):
await subsequent_call()
@pytest.mark.asyncio
async def test_non_breaking_in_half_open_reopens_circuit():
"""Non-breaking exceptions in HALF_OPEN must still reopen the circuit."""
cb = CircuitBreaker(failure_threshold=2, success_threshold=2, timeout=0.01, name="test-half-open")
for _ in range(2):
cb.record_failure()
assert cb.state == CircuitState.OPEN
await asyncio.sleep(0.02)
await cb.atry_transition()
assert cb.state == CircuitState.HALF_OPEN
@with_retry(
max_attempts=1,
base_delay=0.01,
max_delay=0.05,
circuit_breaker=cb,
retriable_exceptions=(_RateLimited,),
non_breaking_exceptions=(_RateLimited,),
)
async def rate_limited_in_half_open():
raise _RateLimited(retry_after=0.01)
with pytest.raises(_RateLimited):
await rate_limited_in_half_open()
assert cb.state == CircuitState.OPEN
@pytest.mark.asyncio
async def test_retry_after_not_clamped_by_max_delay():
"""Server-provided Retry-After should not be clamped by max_delay."""
cb = CircuitBreaker(failure_threshold=10, name="test-retry-after-clamp")
call_count = 0
observed_gap = 0.0
@with_retry(
max_attempts=2,
base_delay=0.01,
max_delay=0.05,
circuit_breaker=cb,
retriable_exceptions=(_RateLimited,),
non_breaking_exceptions=(_RateLimited,),
)
async def rate_limited_then_ok():
nonlocal call_count, observed_gap
call_count += 1
if call_count == 1:
raise _RateLimited(retry_after=0.3)
return "ok"
import time
start = time.monotonic()
result = await rate_limited_then_ok()
elapsed = time.monotonic() - start
assert result == "ok"
assert elapsed >= 0.25, f"Expected >=0.25s delay from retry_after=0.3, got {elapsed:.3f}s"
@@ -0,0 +1,41 @@
import msgspec
import pytest
from infrastructure.serialization import clone_with_updates, to_jsonable
class SampleStruct(msgspec.Struct):
value: int
name: str = "x"
def test_to_jsonable_struct_and_dict():
struct_value = SampleStruct(value=3, name="abc")
assert to_jsonable(struct_value) == {"value": 3, "name": "abc"}
assert to_jsonable({"x": 1}) == {"x": 1}
def test_clone_with_updates_struct():
original = SampleStruct(value=1, name="before")
updated = clone_with_updates(original, {"name": "after"})
assert isinstance(updated, SampleStruct)
assert updated.value == 1
assert updated.name == "after"
assert original.name == "before"
def test_clone_with_updates_dict():
original = {"value": 1, "name": "before"}
updated = clone_with_updates(original, {"name": "after", "extra": True})
assert updated == {"value": 1, "name": "after", "extra": True}
assert original == {"value": 1, "name": "before"}
def test_clone_with_updates_unsupported_type_raises():
with pytest.raises(TypeError):
clone_with_updates([1, 2, 3], {"value": 9})