477 lines
15 KiB
Python
477 lines
15 KiB
Python
"""
|
|
Unit tests for Plugin API service.
|
|
|
|
Tests cover:
|
|
- Singleton pattern
|
|
- API registration and calling
|
|
- Service registration (OCR, Screenshot, Log, Audio, Nexus)
|
|
- Shared data management
|
|
- Utility functions (PED/PEC formatting, DPP calculation)
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from typing import Any
|
|
|
|
from core.plugin_api import PluginAPI, APIType, APIEndpoint, get_api
|
|
|
|
|
|
@pytest.mark.unit
class TestPluginAPISingleton:
    """Verify that get_api() always hands back one shared PluginAPI object."""

    def test_singleton_instance(self, reset_singletons):
        """Two consecutive get_api() calls yield the identical PluginAPI."""
        first, second = get_api(), get_api()

        assert first is second
        assert isinstance(first, PluginAPI)

    def test_singleton_initialized_once(self, reset_singletons):
        """Attributes attached to the instance are visible on a later get_api().

        If get_api() built a fresh object on the second call, the ad-hoc
        attribute attached below would not be present.
        """
        original = get_api()
        # Presumably mirrors PluginAPI's own init-guard flag — confirm
        # against core.plugin_api.
        original._initialized = True
        original.custom_attr = "test"

        again = get_api()

        assert hasattr(again, 'custom_attr')
        assert again.custom_attr == "test"
|
|
|
|
|
|
@pytest.mark.unit
class TestAPIRegistration:
    """Test API endpoint registration, lookup, invocation and removal."""

    @staticmethod
    def _endpoint(name, description, handler, *, plugin_id="test_plugin",
                  api_type=APIType.UTILITY):
        """Build an APIEndpoint for these tests (UTILITY type, "test_plugin" by default)."""
        return APIEndpoint(
            name=name,
            api_type=api_type,
            description=description,
            handler=handler,
            plugin_id=plugin_id,
        )

    def test_register_api(self, reset_singletons):
        """Registering an endpoint returns True and stores it as "<plugin_id>:<name>"."""
        api = get_api()

        def handler_func():
            return "test result"

        result = api.register_api(self._endpoint("test_api", "Test API", handler_func))

        assert result is True
        assert "test_plugin:test_api" in api.apis

    def test_register_api_error_handling(self, reset_singletons):
        """Registration does not validate the handler.

        An endpoint whose handler is None is still accepted: register_api()
        does not raise and reports success (returns True).
        """
        api = get_api()

        # A None handler could never be invoked, yet registration is expected
        # to accept it without raising.
        result = api.register_api(self._endpoint("bad_api", "Bad API", None))

        assert result is True

    def test_unregister_api_single(self, reset_singletons):
        """unregister_api(plugin_id, name) removes exactly that endpoint."""
        api = get_api()

        api.register_api(self._endpoint("test_api", "Test API", lambda: None))
        assert "test_plugin:test_api" in api.apis

        api.unregister_api("test_plugin", "test_api")
        assert "test_plugin:test_api" not in api.apis

    def test_unregister_api_all_for_plugin(self, reset_singletons):
        """unregister_api(plugin_id) with no name removes all of that plugin's APIs."""
        api = get_api()

        # Register multiple APIs for the same plugin.
        for i in range(3):
            api.register_api(self._endpoint(f"api_{i}", f"API {i}", lambda: None))

        assert len(api.apis) == 3

        # Omitting the API name unregisters everything for the plugin.
        api.unregister_api("test_plugin")
        assert len(api.apis) == 0

    def test_call_api_success(self, reset_singletons):
        """call_api() forwards positional args to the handler and returns its result."""
        api = get_api()

        def handler_func(arg1, arg2):
            return f"Result: {arg1}, {arg2}"

        api.register_api(self._endpoint("test_api", "Test API", handler_func))
        result = api.call_api("test_plugin", "test_api", "hello", "world")

        assert result == "Result: hello, world"

    def test_call_api_not_found(self, reset_singletons):
        """Calling an unregistered API raises ValueError("API not found ...")."""
        api = get_api()

        with pytest.raises(ValueError, match="API not found"):
            api.call_api("nonexistent", "api")

    def test_call_api_error_propagation(self, reset_singletons):
        """Exceptions raised by a handler propagate out of call_api() unchanged."""
        api = get_api()

        def error_handler():
            raise ValueError("Test error")

        api.register_api(self._endpoint("error_api", "Error API", error_handler))

        with pytest.raises(ValueError, match="Test error"):
            api.call_api("test_plugin", "error_api")

    def test_find_apis_by_type(self, reset_singletons):
        """find_apis(type) filters by APIType; find_apis() returns every endpoint."""
        api = get_api()

        # Register APIs of different types across three plugins.
        api.register_api(self._endpoint(
            "ocr_api", "OCR API", lambda: None,
            plugin_id="plugin1", api_type=APIType.OCR,
        ))
        api.register_api(self._endpoint(
            "log_api", "Log API", lambda: None,
            plugin_id="plugin2", api_type=APIType.LOG,
        ))
        api.register_api(self._endpoint(
            "another_ocr", "Another OCR", lambda: None,
            plugin_id="plugin3", api_type=APIType.OCR,
        ))

        ocr_apis = api.find_apis(APIType.OCR)
        assert len(ocr_apis) == 2

        all_apis = api.find_apis()
        assert len(all_apis) == 3
|
|
|
|
|
|
@pytest.mark.unit
class TestServiceRegistration:
    """Test service registration and access through PluginAPI wrappers."""

    def test_register_ocr_service(self, reset_singletons):
        """Registering an OCR callable stores it under services['ocr']."""
        api = get_api()
        mock_ocr = MagicMock()
        mock_ocr.return_value = {"text": "test", "confidence": 0.9}

        api.register_ocr_service(mock_ocr)

        assert api.services['ocr'] is mock_ocr

    def test_ocr_capture(self, reset_singletons):
        """ocr_capture() calls the OCR service with the region and returns its result."""
        api = get_api()
        mock_ocr = MagicMock()
        mock_ocr.return_value = {"text": "captured", "confidence": 0.95}

        api.register_ocr_service(mock_ocr)
        result = api.ocr_capture(region=(0, 0, 100, 100))

        assert result["text"] == "captured"
        # The region is forwarded positionally to the service callable.
        mock_ocr.assert_called_once_with((0, 0, 100, 100))

    def test_ocr_capture_no_service(self, reset_singletons):
        """ocr_capture() raises RuntimeError when no OCR service is registered."""
        api = get_api()

        with pytest.raises(RuntimeError, match="OCR service not available"):
            api.ocr_capture()

    def test_register_screenshot_service(self, reset_singletons):
        """capture_screen()/capture_region() delegate to the screenshot service."""
        api = get_api()
        mock_screenshot = MagicMock()
        mock_image = MagicMock()
        mock_screenshot.capture.return_value = mock_image
        mock_screenshot.capture_region.return_value = mock_image

        api.register_screenshot_service(mock_screenshot)

        # capture_screen forwards full_screen as a keyword argument.
        result = api.capture_screen(full_screen=True)
        assert result is mock_image
        mock_screenshot.capture.assert_called_once_with(full_screen=True)

        # capture_region forwards its coordinates positionally.
        result = api.capture_region(10, 20, 100, 200)
        assert result is mock_image
        mock_screenshot.capture_region.assert_called_once_with(10, 20, 100, 200)

    def test_register_log_service(self, reset_singletons):
        """read_log() calls the log service with (lines, filter_text) positionally."""
        api = get_api()
        mock_log = MagicMock()
        mock_log.return_value = ["line1", "line2"]

        api.register_log_service(mock_log)
        result = api.read_log(lines=10, filter_text="test")

        assert result == ["line1", "line2"]
        mock_log.assert_called_once_with(10, "test")

    def test_register_audio_service(self, reset_singletons):
        """Audio helpers on PluginAPI delegate to the registered audio service."""
        api = get_api()
        mock_audio = MagicMock()
        mock_audio.play_sound.return_value = True
        mock_audio.get_volume.return_value = 0.8
        mock_audio.is_muted.return_value = False
        mock_audio.is_available.return_value = True

        api.register_audio_service(mock_audio)

        # play_sound forwards the sound name plus a second positional False.
        result = api.play_sound("hof")
        assert result is True
        mock_audio.play_sound.assert_called_once_with("hof", False)

        # Volume control
        api.set_volume(0.5)
        mock_audio.set_volume.assert_called_once_with(0.5)
        assert api.get_volume() == 0.8

        # Mute / unmute
        api.mute_audio()
        mock_audio.mute.assert_called_once()

        api.unmute_audio()
        mock_audio.unmute.assert_called_once()

        # Compare booleans by identity (PEP 8 / ruff E712): `== False` would
        # also accept 0 and `== True` would accept 1.
        assert api.toggle_mute_audio() is False
        assert api.is_audio_muted() is False
        assert api.is_audio_available() is True
|
|
|
|
|
|
@pytest.mark.unit
class TestSharedData:
    """Exercise the key/value store exposed via set_data() and get_data()."""

    def test_get_set_data(self, reset_singletons):
        """A value stored under a key reads back unchanged."""
        plugin_api = get_api()

        plugin_api.set_data("test_key", "test_value")
        fetched = plugin_api.get_data("test_key")

        assert fetched == "test_value"

    def test_get_data_default(self, reset_singletons):
        """An explicit default is returned for a key that was never stored."""
        fetched = get_api().get_data("nonexistent_key", "default")

        assert fetched == "default"

    def test_get_data_none_default(self, reset_singletons):
        """With no default supplied, a missing key reads back as None."""
        assert get_api().get_data("nonexistent_key") is None
|
|
|
|
|
|
@pytest.mark.unit
class TestUtilityFunctions:
    """Test utility helper functions (currency formatting, DPP, markup)."""

    def test_format_ped(self, reset_singletons):
        """format_ped() renders two decimal places followed by " PED"."""
        api = get_api()

        assert api.format_ped(10.5) == "10.50 PED"
        assert api.format_ped(0.0) == "0.00 PED"
        assert api.format_ped(100.123) == "100.12 PED"

    def test_format_pec(self, reset_singletons):
        """format_pec() rounds to a whole number followed by " PEC"."""
        api = get_api()

        assert api.format_pec(50.0) == "50 PEC"
        assert api.format_pec(0.0) == "0 PEC"
        assert api.format_pec(100.7) == "101 PEC"

    def test_calculate_dpp(self, reset_singletons):
        """calculate_dpp() divides damage by total cost, returning 0.0 on degenerate input."""
        api = get_api()

        # Normal case
        dpp = api.calculate_dpp(damage=100, ammo=50, decay=0.5)
        # ammo_cost = 50 * 0.01 = 0.5 PEC
        # total_cost = 0.5 + 0.5 = 1.0 PEC = 0.01 PED
        # dpp = 100 / 0.01 = 10000
        # The result is produced by float arithmetic, so compare approximately:
        # exact equality would depend on the implementation's operation order.
        assert dpp == pytest.approx(10000.0)

        # Zero damage
        assert api.calculate_dpp(damage=0, ammo=50, decay=0.5) == 0.0

        # Zero cost
        assert api.calculate_dpp(damage=100, ammo=0, decay=0) == 0.0

    def test_calculate_markup(self, reset_singletons):
        """calculate_markup() yields percent of TT, and 0.0 for non-positive TT."""
        api = get_api()

        # Normal case: 110 for price 110 over TT 100. Float division is
        # involved (e.g. 110 / 100 * 100 == 110.00000000000001), so compare
        # approximately.
        assert api.calculate_markup(price=110, tt=100) == pytest.approx(110.0)

        # Zero TT
        assert api.calculate_markup(price=110, tt=0) == 0.0

        # Negative TT
        assert api.calculate_markup(price=110, tt=-10) == 0.0
|
|
|
|
|
|
@pytest.mark.unit
class TestLegacyEventSystem:
    """Backward-compatibility checks for the string-keyed (legacy) event API."""

    def test_publish_event(self, reset_singletons):
        """publish_event() mirrors the payload into shared data under "event:<name>"."""
        plugin_api = get_api()

        plugin_api.publish_event("test_event", {"key": "value"})

        # The stored record wraps the payload and carries a timestamp.
        stored = plugin_api.get_data("event:test_event")
        assert stored is not None
        assert stored['data'] == {"key": "value"}
        assert 'timestamp' in stored

    def test_subscribe_and_receive_legacy(self, reset_singletons):
        """A legacy subscriber is invoked with the published payload."""
        plugin_api = get_api()
        captured = []

        def on_event(data):
            captured.append(data)

        plugin_api.subscribe("test_event", on_event)
        plugin_api.publish_event("test_event", {"message": "hello"})

        # Legacy delivery is synchronous, so the callback has already run.
        assert len(captured) == 1
        assert captured[0] == {"message": "hello"}
|
|
|
|
|
|
@pytest.mark.unit
class TestTypedEventIntegration:
    """Check PluginAPI's typed-event helpers against a fresh event bus."""

    def test_publish_typed(self, reset_singletons, fresh_event_bus):
        """Publishing a SkillGainEvent completes without raising."""
        from core.event_bus import SkillGainEvent

        plugin_api = get_api()
        skill_gain = SkillGainEvent(skill_name="Rifle", skill_value=25.0, gain_amount=0.01)

        # Nothing is asserted: the test passes as long as this call does not raise.
        plugin_api.publish_typed(skill_gain)

    def test_subscribe_typed(self, reset_singletons, fresh_event_bus):
        """subscribe_typed() hands back a non-empty string subscription id."""
        from core.event_bus import SkillGainEvent

        plugin_api = get_api()
        seen = []

        def on_skill_gain(event):
            seen.append(event)

        subscription_id = plugin_api.subscribe_typed(SkillGainEvent, on_skill_gain)

        # NOTE(review): only the returned id is checked; nothing is published,
        # so `seen` is never inspected.
        assert isinstance(subscription_id, str)
        assert len(subscription_id) != 0

    def test_get_recent_events(self, reset_singletons, fresh_event_bus):
        """get_recent_events() never returns more than the requested count."""
        from core.event_bus import SkillGainEvent

        plugin_api = get_api()

        # Publish more events than will be requested.
        for idx in range(5):
            plugin_api.publish_typed(
                SkillGainEvent(skill_name=f"Skill{idx}", skill_value=float(idx), gain_amount=0.01)
            )

        recent = plugin_api.get_recent_events(SkillGainEvent, count=3)

        assert len(recent) <= 3

    def test_get_event_stats(self, reset_singletons, fresh_event_bus):
        """get_event_stats() includes 'total_published' and 'total_delivered' keys."""
        from core.event_bus import SkillGainEvent

        plugin_api = get_api()
        plugin_api.publish_typed(
            SkillGainEvent(skill_name="Rifle", skill_value=25.0, gain_amount=0.01)
        )

        stats = plugin_api.get_event_stats()

        for required_key in ('total_published', 'total_delivered'):
            assert required_key in stats
|