212 lines
6.2 KiB
Python
212 lines
6.2 KiB
Python
import asyncio
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, ANY
|
|
|
|
from api.v1.schemas.artist import ArtistInfo
|
|
from core.exceptions import ResourceNotFoundError
|
|
from services.artist_service import ArtistService
|
|
|
|
|
|
MBID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
|
|
|
|
|
def _fake_artist_info() -> ArtistInfo:
|
|
return ArtistInfo(
|
|
name="Test Artist",
|
|
musicbrainz_id=MBID,
|
|
)
|
|
|
|
|
|
def _make_service() -> ArtistService:
|
|
mb = AsyncMock()
|
|
lidarr = AsyncMock()
|
|
lidarr.is_configured.return_value = False
|
|
wikidata = AsyncMock()
|
|
prefs = MagicMock()
|
|
mem_cache = AsyncMock()
|
|
mem_cache.get = AsyncMock(return_value=None)
|
|
mem_cache.set = AsyncMock()
|
|
disk_cache = MagicMock()
|
|
disk_cache.get_artist = AsyncMock(return_value=None)
|
|
disk_cache.set_artist = AsyncMock()
|
|
audiodb_img = MagicMock()
|
|
audiodb_img.fetch_and_cache_artist_images = AsyncMock(return_value=None)
|
|
|
|
svc = ArtistService(
|
|
mb_repo=mb,
|
|
lidarr_repo=lidarr,
|
|
wikidata_repo=wikidata,
|
|
preferences_service=prefs,
|
|
memory_cache=mem_cache,
|
|
disk_cache=disk_cache,
|
|
audiodb_image_service=audiodb_img,
|
|
)
|
|
return svc
|
|
|
|
|
|
class TestArtistSingleflight:
|
|
@pytest.mark.asyncio
|
|
async def test_concurrent_calls_fetch_once(self):
|
|
"""Multiple concurrent get_artist_info calls should invoke
|
|
_do_get_artist_info once."""
|
|
svc = _make_service()
|
|
call_count = 0
|
|
fake = _fake_artist_info()
|
|
|
|
async def counting_fetch(*args, **kwargs):
|
|
nonlocal call_count
|
|
call_count += 1
|
|
await asyncio.sleep(0.05)
|
|
return fake
|
|
|
|
svc._do_get_artist_info = counting_fetch
|
|
|
|
results = await asyncio.gather(
|
|
svc.get_artist_info(MBID),
|
|
svc.get_artist_info(MBID),
|
|
svc.get_artist_info(MBID),
|
|
)
|
|
|
|
assert call_count == 1
|
|
assert all(r.name == "Test Artist" for r in results)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_singleflight_cleared_after_completion(self):
|
|
"""After completion, the in-flight dict should be empty."""
|
|
svc = _make_service()
|
|
fake = _fake_artist_info()
|
|
|
|
async def quick_fetch(*args, **kwargs):
|
|
return fake
|
|
|
|
svc._do_get_artist_info = quick_fetch
|
|
|
|
await svc.get_artist_info(MBID)
|
|
assert MBID not in svc._artist_in_flight
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_singleflight_propagates_exception(self):
|
|
"""If fetch raises, all concurrent callers should get the exception."""
|
|
svc = _make_service()
|
|
|
|
async def failing_fetch(*args, **kwargs):
|
|
await asyncio.sleep(0.05)
|
|
raise RuntimeError("upstream timeout")
|
|
|
|
svc._do_get_artist_info = failing_fetch
|
|
|
|
results = await asyncio.gather(
|
|
svc.get_artist_info(MBID),
|
|
svc.get_artist_info(MBID),
|
|
svc.get_artist_info(MBID),
|
|
return_exceptions=True,
|
|
)
|
|
|
|
assert all(isinstance(r, ResourceNotFoundError) for r in results)
|
|
assert MBID not in svc._artist_in_flight
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_cache_hit_bypasses_singleflight(self):
|
|
"""Cache hit should skip the fetch entirely."""
|
|
svc = _make_service()
|
|
fake = _fake_artist_info()
|
|
svc._cache.get = AsyncMock(return_value=fake)
|
|
call_count = 0
|
|
|
|
async def should_not_run(*args, **kwargs):
|
|
nonlocal call_count
|
|
call_count += 1
|
|
return fake
|
|
|
|
svc._do_get_artist_info = should_not_run
|
|
|
|
result = await svc.get_artist_info(MBID)
|
|
assert result.name == "Test Artist"
|
|
assert call_count == 0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_different_ids_run_independently(self):
|
|
"""Different artist_ids should run in parallel."""
|
|
svc = _make_service()
|
|
call_ids: list[str] = []
|
|
|
|
async def tracking_fetch(aid, *args, **kwargs):
|
|
call_ids.append(aid)
|
|
await asyncio.sleep(0.02)
|
|
return _fake_artist_info()
|
|
|
|
svc._do_get_artist_info = tracking_fetch
|
|
|
|
mbid_a = "aaaa1111-bbbb-cccc-dddd-eeeeeeeeeeee"
|
|
mbid_b = "bbbb2222-bbbb-cccc-dddd-eeeeeeeeeeee"
|
|
await asyncio.gather(
|
|
svc.get_artist_info(mbid_a),
|
|
svc.get_artist_info(mbid_b),
|
|
)
|
|
|
|
assert len(call_ids) == 2
|
|
assert mbid_a in call_ids
|
|
assert mbid_b in call_ids
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_basic_uses_separate_inflight_dict(self):
|
|
"""get_artist_info_basic should use a separate in-flight dict from
|
|
get_artist_info, so they don't cross-contaminate results."""
|
|
svc = _make_service()
|
|
fake = _fake_artist_info()
|
|
|
|
full_count = 0
|
|
basic_count = 0
|
|
|
|
async def full_fetch(*args, **kwargs):
|
|
nonlocal full_count
|
|
full_count += 1
|
|
await asyncio.sleep(0.05)
|
|
return fake
|
|
|
|
svc._do_get_artist_info = full_fetch
|
|
|
|
original_build = svc._build_artist_from_musicbrainz
|
|
|
|
async def basic_build(*args, **kwargs):
|
|
nonlocal basic_count
|
|
basic_count += 1
|
|
await asyncio.sleep(0.05)
|
|
return fake
|
|
|
|
svc._build_artist_from_musicbrainz = basic_build
|
|
|
|
results = await asyncio.gather(
|
|
svc.get_artist_info(MBID),
|
|
svc.get_artist_info_basic(MBID),
|
|
)
|
|
|
|
assert full_count == 1
|
|
assert basic_count == 1
|
|
assert all(r.name == "Test Artist" for r in results)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_basic_concurrent_calls_fetch_once(self):
|
|
"""Multiple concurrent get_artist_info_basic calls should only fetch once."""
|
|
svc = _make_service()
|
|
call_count = 0
|
|
fake = _fake_artist_info()
|
|
|
|
async def counting_build(*args, **kwargs):
|
|
nonlocal call_count
|
|
call_count += 1
|
|
await asyncio.sleep(0.05)
|
|
return fake
|
|
|
|
svc._build_artist_from_musicbrainz = counting_build
|
|
|
|
results = await asyncio.gather(
|
|
svc.get_artist_info_basic(MBID),
|
|
svc.get_artist_info_basic(MBID),
|
|
svc.get_artist_info_basic(MBID),
|
|
)
|
|
|
|
assert call_count == 1
|
|
assert all(r.name == "Test Artist" for r in results)
|