EU-Utility/tests/unit/test_security_utils.py

600 lines
22 KiB
Python

"""
Unit tests for Security Utilities.
Tests cover:
- PathValidator - filename sanitization
- PathValidator - path traversal prevention
- InputValidator - identifier validation
- InputValidator - string sanitization
- DataValidator - data structure validation
- IntegrityChecker - hash and HMAC computation
"""
import sys
import unittest
import hashlib
import hmac
from pathlib import Path
from unittest.mock import MagicMock, patch
# Make the project root importable so that ``core.security_utils`` resolves
# even when this file is run directly rather than through a test runner that
# already sets up sys.path. Three ``.parent`` hops climb from this file's
# directory (tests/unit/) to the project root. Update this if the file moves.
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
    # Prepend so the local package takes priority over any installed copy.
    sys.path.insert(0, str(project_root))
from core.security_utils import (
PathValidator, InputValidator, DataValidator, IntegrityChecker,
SecurityError, sanitize_filename, validate_path, safe_path_join
)
class TestPathValidatorSanitizeFilename(unittest.TestCase):
    """Behavioural checks for ``PathValidator.sanitize_filename``."""

    def test_sanitize_simple_filename(self):
        """A harmless filename comes back unchanged."""
        self.assertEqual(PathValidator.sanitize_filename("test.txt"), "test.txt")

    def test_sanitize_with_path_separators(self):
        """Forward slashes and backslashes are both replaced by underscores."""
        for raw_name in ("path/to/file.txt", "path\\to\\file.txt"):
            cleaned = PathValidator.sanitize_filename(raw_name)
            self.assertEqual(cleaned, "path_to_file.txt")

    def test_sanitize_parent_directory_refs(self):
        """Parent-directory segments do not survive sanitization."""
        cleaned = PathValidator.sanitize_filename("../../../etc/passwd")
        self.assertNotIn("..", cleaned)
        self.assertEqual(cleaned, "_etc_passwd")

    def test_sanitize_dangerous_characters(self):
        """Characters such as ':' and '*' are each replaced by '_'."""
        cleaned = PathValidator.sanitize_filename("file:name*.txt")
        self.assertEqual(cleaned, "file_name_.txt")

    def test_sanitize_null_bytes(self):
        """An embedded NUL byte is replaced by '_'."""
        cleaned = PathValidator.sanitize_filename("file\x00name.txt")
        self.assertEqual(cleaned, "file_name.txt")

    def test_sanitize_empty_filename(self):
        """An empty name yields a name starting with ``file_``."""
        cleaned = PathValidator.sanitize_filename("")
        self.assertTrue(cleaned.startswith("file_"))

    def test_sanitize_dots_only(self):
        """A name made only of dots yields a name starting with ``file_``."""
        cleaned = PathValidator.sanitize_filename("...")
        self.assertTrue(cleaned.startswith("file_"))

    def test_sanitize_long_filename(self):
        """Over-long names are cut down to ``MAX_FILENAME_LENGTH`` or less."""
        oversized = "a" * 300 + ".txt"
        cleaned = PathValidator.sanitize_filename(oversized)
        self.assertLessEqual(len(cleaned), PathValidator.MAX_FILENAME_LENGTH)

    def test_sanitize_custom_replacement(self):
        """The ``replacement`` argument overrides the default substitute."""
        cleaned = PathValidator.sanitize_filename("file:name", replacement="-")
        self.assertEqual(cleaned, "file-name")
class TestPathValidatorSplitExtension(unittest.TestCase):
    """Checks for the private ``PathValidator._split_extension`` helper."""

    def _assert_split(self, filename, expected_stem, expected_ext):
        """Split *filename* and compare the stem and extension separately."""
        stem, ext = PathValidator._split_extension(filename)
        self.assertEqual(stem, expected_stem)
        self.assertEqual(ext, expected_ext)

    def test_split_simple_extension(self):
        """A single suffix is separated from the stem."""
        self._assert_split("file.txt", "file", ".txt")

    def test_split_multiple_extensions(self):
        """Only the last suffix is treated as the extension."""
        self._assert_split("archive.tar.gz", "archive.tar", ".gz")

    def test_split_no_extension(self):
        """A name without a dot has an empty extension."""
        self._assert_split("file", "file", "")
class TestPathValidatorValidatePathWithinBase(unittest.TestCase):
    """Test PathValidator.validate_path_within_base method."""

    def setUp(self):
        """Create a temporary base directory for each test.

        The directory is resolved up front. Some platforms keep their temp
        dir behind a symlink (e.g. macOS ``/var`` -> ``/private/var``). If
        it were left unresolved, resolved results could look as if they
        escaped the base, or the base itself could trip symlink checks.
        """
        import tempfile
        self.temp_dir = str(Path(tempfile.mkdtemp()).resolve())

    def tearDown(self):
        """Clean up temporary directory."""
        import shutil
        shutil.rmtree(self.temp_dir)

    def _assert_within(self, path, base):
        """Assert *path* is *base* itself or lies beneath it.

        Compares path components rather than string prefixes: a
        ``str.startswith`` check would also accept a sibling directory such
        as ``<base>-evil``, which is exactly the mistake path-containment
        code must not make.
        """
        path = Path(path)
        base = Path(base)
        self.assertTrue(
            path == base or base in path.parents,
            f"{path} is not contained in {base}",
        )

    def test_valid_path_within_base(self):
        """Test validating path within base directory."""
        base = Path(self.temp_dir)
        path = base / "subdir" / "file.txt"
        result = PathValidator.validate_path_within_base(path, base)
        self.assertIsInstance(result, Path)
        self._assert_within(result, base)

    def test_path_traversal_detection(self):
        """Test detecting path traversal attack."""
        base = Path(self.temp_dir)
        path = base / ".." / ".." / "etc" / "passwd"
        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(path, base)
        self.assertIn("Path traversal", str(context.exception))

    def test_absolute_path_outside_base(self):
        """Test absolute path outside base."""
        base = Path(self.temp_dir)
        path = Path("/etc/passwd")
        with self.assertRaises(SecurityError):
            PathValidator.validate_path_within_base(path, base)

    def test_symlink_detection(self):
        """Test detecting symlinks."""
        base = Path(self.temp_dir)
        symlink_path = base / "link"
        # Only symlink creation may legitimately be unsupported. Keep the
        # try narrow so that errors from the code under test are not turned
        # into skips.
        try:
            symlink_path.symlink_to(Path("/etc"))
        except (OSError, NotImplementedError):
            self.skipTest("Cannot create symlinks on this system")
        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(symlink_path / "passwd", base)
        self.assertIn("Symlink", str(context.exception))

    def test_path_depth_exceeded(self):
        """Test path depth limit."""
        base = Path(self.temp_dir)
        # Build a path 15 directories deep, intended to exceed the
        # validator's depth limit.
        path = base
        for i in range(15):
            path = path / f"dir{i}"
        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(path, base)
        self.assertIn("Path depth", str(context.exception))
class TestPathValidatorSafeJoin(unittest.TestCase):
    """Test PathValidator.safe_join method."""

    def setUp(self):
        """Create a temporary base directory for each test.

        Resolved so that a symlinked temp dir (e.g. macOS ``/var``) does not
        make a resolved join result appear to lie outside the base.
        """
        import tempfile
        self.temp_dir = str(Path(tempfile.mkdtemp()).resolve())

    def tearDown(self):
        """Clean up temporary directory."""
        import shutil
        shutil.rmtree(self.temp_dir)

    def _assert_within(self, path, base):
        """Assert *path* is *base* itself or lies beneath it.

        Uses path components rather than ``str.startswith``, which would
        also accept a sibling such as ``<base>-evil``.
        """
        path = Path(path)
        base = Path(base)
        self.assertTrue(
            path == base or base in path.parents,
            f"{path} is not contained in {base}",
        )

    def test_safe_join_simple(self):
        """Test safe join with simple paths."""
        base = Path(self.temp_dir)
        result = PathValidator.safe_join(base, "subdir", "file.txt")
        self._assert_within(result, base)
        # Only inspect the portion safe_join produced, not the base path.
        joined_part = str(Path(result).relative_to(base))
        self.assertIn("subdir", joined_part)
        self.assertIn("file.txt", joined_part)

    def test_safe_join_sanitizes_components(self):
        """Test that safe_join sanitizes path components."""
        base = Path(self.temp_dir)
        result = PathValidator.safe_join(base, "../etc", "passwd")
        self._assert_within(result, base)
        # The base directory's own name is outside this test's control, so
        # look for '..' only in the components safe_join appended.
        self.assertNotIn("..", str(Path(result).relative_to(base)))
class TestInputValidatorIsSafeIdentifier(unittest.TestCase):
    """Checks for ``InputValidator.is_safe_identifier``."""

    def test_valid_identifier(self):
        """Plain names, leading underscores and trailing digits are accepted."""
        for name in ("valid_name", "_private", "Name123"):
            self.assertTrue(InputValidator.is_safe_identifier(name))

    def test_invalid_identifier_starts_with_number(self):
        """A leading digit makes the name invalid."""
        self.assertFalse(InputValidator.is_safe_identifier("123name"))

    def test_invalid_identifier_with_spaces(self):
        """Whitespace inside the name makes it invalid."""
        self.assertFalse(InputValidator.is_safe_identifier("name with spaces"))

    def test_invalid_identifier_special_chars(self):
        """Dashes and dots are not permitted."""
        for name in ("name-with-dash", "name.with.dots"):
            self.assertFalse(InputValidator.is_safe_identifier(name))

    def test_python_keyword(self):
        """Reserved Python keywords are refused."""
        for keyword_name in ("class", "def", "import"):
            self.assertFalse(InputValidator.is_safe_identifier(keyword_name))

    def test_empty_identifier(self):
        """The empty string is not a valid identifier."""
        self.assertFalse(InputValidator.is_safe_identifier(""))
class TestInputValidatorSanitizeString(unittest.TestCase):
    """Checks for ``InputValidator.sanitize_string``."""

    def test_sanitize_simple_string(self):
        """Printable text is returned as-is."""
        self.assertEqual(InputValidator.sanitize_string("Hello World"), "Hello World")

    def test_sanitize_removes_null_bytes(self):
        """NUL bytes are dropped, not replaced."""
        self.assertEqual(InputValidator.sanitize_string("Hello\x00World"), "HelloWorld")

    def test_sanitize_removes_control_chars(self):
        """Other low control characters are dropped as well."""
        self.assertEqual(InputValidator.sanitize_string("Hello\x01\x02World"), "HelloWorld")

    def test_sanitize_keeps_newlines_and_tabs(self):
        """Newline and tab are whitelisted control characters."""
        text = "Line1\nLine2\tTabbed"
        self.assertEqual(InputValidator.sanitize_string(text), text)

    def test_sanitize_truncates_long_string(self):
        """Output is cut to exactly ``max_length`` characters."""
        truncated = InputValidator.sanitize_string("a" * 2000, max_length=100)
        self.assertEqual(len(truncated), 100)

    def test_sanitize_non_string_input(self):
        """Non-string input is converted to its string form."""
        self.assertEqual(InputValidator.sanitize_string(12345), "12345")
class TestInputValidatorValidateJsonKey(unittest.TestCase):
    """Checks for ``InputValidator.validate_json_key``."""

    def test_valid_key(self):
        """Simple alphanumeric keys are accepted."""
        for key in ("valid_key", "key123"):
            self.assertTrue(InputValidator.validate_json_key(key))

    def test_dunder_key_rejected(self):
        """Double-underscore names are refused."""
        for key in ("__class__", "__init__"):
            self.assertFalse(InputValidator.validate_json_key(key))

    def test_empty_key_rejected(self):
        """The empty string is refused."""
        self.assertFalse(InputValidator.validate_json_key(""))

    def test_dangerous_patterns_rejected(self):
        """Keys that look like code-execution calls are refused."""
        for key in ("eval(something)", "exec(", "os.system"):
            self.assertFalse(InputValidator.validate_json_key(key))

    def test_non_string_key(self):
        """Keys that are not strings are refused."""
        for key in (123, None):
            self.assertFalse(InputValidator.validate_json_key(key))
class TestInputValidatorValidateRegionCoordinates(unittest.TestCase):
    """Checks for ``InputValidator.validate_region_coordinates``."""

    def _assert_rejected(self, region, fragment):
        """Expect *region* (x, y, width, height) to raise a SecurityError mentioning *fragment*."""
        with self.assertRaises(SecurityError) as ctx:
            InputValidator.validate_region_coordinates(*region)
        self.assertIn(fragment, str(ctx.exception))

    def test_valid_coordinates(self):
        """In-range regions are accepted without raising."""
        for region in ((0, 0, 100, 100), (100, 200, 500, 300)):
            InputValidator.validate_region_coordinates(*region)

    def test_invalid_type(self):
        """A string coordinate is refused."""
        self._assert_rejected(("0", 0, 100, 100), "must be integers")

    def test_zero_width(self):
        """A zero width is refused."""
        self._assert_rejected((0, 0, 0, 100), "positive")

    def test_zero_height(self):
        """A zero height is refused."""
        self._assert_rejected((0, 0, 100, 0), "positive")

    def test_dimensions_too_large(self):
        """Oversized width and height are refused."""
        self._assert_rejected((0, 0, 10000, 10000), "exceed maximum")

    def test_coordinates_out_of_bounds(self):
        """A far-negative origin is refused."""
        self._assert_rejected((-20000, 0, 100, 100), "out of reasonable bounds")
class TestInputValidatorValidateUrlEndpoint(unittest.TestCase):
    """Checks for ``InputValidator.validate_url_endpoint``."""

    def _assert_rejected(self, endpoint, fragment):
        """Expect *endpoint* to raise a SecurityError whose text contains *fragment*."""
        with self.assertRaises(SecurityError) as ctx:
            InputValidator.validate_url_endpoint(endpoint)
        self.assertIn(fragment, str(ctx.exception))

    def test_valid_endpoint(self):
        """A clean relative endpoint is returned unchanged."""
        endpoint = "api/v1/users"
        self.assertEqual(InputValidator.validate_url_endpoint(endpoint), endpoint)

    def test_empty_endpoint(self):
        """An empty endpoint is refused."""
        self._assert_rejected("", "cannot be empty")

    def test_leading_slash(self):
        """Absolute-looking endpoints are refused."""
        self._assert_rejected("/api/users", "must not start with /")

    def test_path_traversal(self):
        """A '..' segment is refused."""
        self._assert_rejected("api/../admin", "Path traversal")

    def test_invalid_characters(self):
        """Query-string characters are refused."""
        self._assert_rejected("api/users?id=1", "Invalid characters")
class TestDataValidatorValidateDataStructure(unittest.TestCase):
    """Checks for ``DataValidator.validate_data_structure``."""

    def _assert_rejected(self, payload, fragment):
        """Expect *payload* to raise a SecurityError whose text contains *fragment*."""
        with self.assertRaises(SecurityError) as ctx:
            DataValidator.validate_data_structure(payload)
        self.assertIn(fragment, str(ctx.exception))

    def test_valid_dict(self):
        """A flat dict of simple values is accepted."""
        self.assertTrue(DataValidator.validate_data_structure({"key1": "value1", "key2": 42}))

    def test_valid_nested_dict(self):
        """A shallowly nested dict is accepted."""
        self.assertTrue(DataValidator.validate_data_structure({"outer": {"inner": "value"}}))

    def test_valid_list(self):
        """A list of mixed primitives is accepted."""
        self.assertTrue(DataValidator.validate_data_structure([1, 2, 3, "string"]))

    def test_empty_dict(self):
        """An empty dict is accepted."""
        self.assertTrue(DataValidator.validate_data_structure({}))

    def test_deep_nesting(self):
        """A chain of 15 nested dicts is refused."""
        # Wrap from the inside out: 15 "nested" levels ending in {}.
        payload = {}
        for _ in range(15):
            payload = {"nested": payload}
        self._assert_rejected(payload, "nesting exceeds")

    def test_dict_too_large(self):
        """A dict with 15000 entries is refused."""
        self._assert_rejected({f"key{i}": i for i in range(15000)}, "size exceeds")

    def test_list_too_large(self):
        """A list with 15000 entries is refused."""
        self._assert_rejected(list(range(15000)), "size exceeds")

    def test_string_too_long(self):
        """A 200000-character string value is refused."""
        self._assert_rejected({"key": "x" * 200000}, "length exceeds")

    def test_invalid_dunder_key(self):
        """A dunder dictionary key is refused."""
        self._assert_rejected({"__class__": "value"}, "Invalid dictionary key")

    def test_unsupported_type(self):
        """Arbitrary objects are not valid values."""
        self._assert_rejected({"key": MagicMock()}, "Unsupported data type")

    def test_none_value(self):
        """None on its own is accepted."""
        self.assertTrue(DataValidator.validate_data_structure(None))

    def test_primitive_types(self):
        """Bare int, float and bool values are accepted."""
        for primitive in (42, 3.14, True):
            self.assertTrue(DataValidator.validate_data_structure(primitive))
class TestIntegrityCheckerComputeHash(unittest.TestCase):
    """Checks that ``IntegrityChecker.compute_hash`` matches ``hashlib``."""

    def _assert_matches(self, algorithm, constructor):
        """Compare the checker's hex digest with *constructor*'s for a fixed payload."""
        payload = b"test data"
        digest = IntegrityChecker.compute_hash(payload, algorithm)
        self.assertEqual(digest, constructor(payload).hexdigest())

    def test_sha256(self):
        """SHA-256 output equals hashlib's."""
        self._assert_matches('sha256', hashlib.sha256)

    def test_sha512(self):
        """SHA-512 output equals hashlib's."""
        self._assert_matches('sha512', hashlib.sha512)

    def test_md5(self):
        """MD5 output equals hashlib's."""
        self._assert_matches('md5', hashlib.md5)

    def test_invalid_algorithm(self):
        """An unknown algorithm name raises ValueError."""
        with self.assertRaises(ValueError) as ctx:
            IntegrityChecker.compute_hash(b"data", 'invalid')
        self.assertIn("Unsupported hash algorithm", str(ctx.exception))
class TestIntegrityCheckerComputeHmac(unittest.TestCase):
    """Checks that ``IntegrityChecker.compute_hmac`` matches the ``hmac`` module."""

    def _assert_matches(self, algorithm, digestmod):
        """Compare the checker's HMAC with ``hmac.new(..., digestmod)`` for fixed inputs."""
        secret = b"secret key"
        message = b"test data"
        produced = IntegrityChecker.compute_hmac(secret, message, algorithm)
        self.assertEqual(produced, hmac.new(secret, message, digestmod).hexdigest())

    def test_hmac_sha256(self):
        """HMAC-SHA256 output equals the stdlib's."""
        self._assert_matches('sha256', hashlib.sha256)

    def test_hmac_sha512(self):
        """HMAC-SHA512 output equals the stdlib's."""
        self._assert_matches('sha512', hashlib.sha512)

    def test_hmac_invalid_algorithm(self):
        """An unknown algorithm name raises ValueError."""
        with self.assertRaises(ValueError):
            IntegrityChecker.compute_hmac(b"key", b"data", 'invalid')
class TestIntegrityCheckerVerifyHmac(unittest.TestCase):
    """Test IntegrityChecker.verify_hmac method."""

    KEY = b"secret key"
    DATA = b"test data"

    def _sign(self):
        """Return the SHA-256 HMAC hex signature of DATA under KEY."""
        return IntegrityChecker.compute_hmac(self.KEY, self.DATA, 'sha256')

    def test_valid_hmac(self):
        """Test verifying valid HMAC."""
        signature = self._sign()
        self.assertTrue(IntegrityChecker.verify_hmac(self.KEY, self.DATA, signature, 'sha256'))

    def test_invalid_hmac(self):
        """Test verifying invalid HMAC."""
        self.assertFalse(
            IntegrityChecker.verify_hmac(self.KEY, self.DATA, "invalid_signature", 'sha256')
        )

    def test_tampered_data(self):
        """Test HMAC verification with tampered data."""
        signature = self._sign()
        tampered_data = b"tampered data"
        self.assertFalse(IntegrityChecker.verify_hmac(self.KEY, tampered_data, signature, 'sha256'))

    def test_tampered_signature_same_length(self):
        """A signature differing only in its final character is rejected.

        Unlike ``"invalid_signature"``, this forgery has the same length as a
        genuine digest, so it cannot be caught by a length check alone. It
        exercises the actual digest comparison.
        """
        signature = self._sign()
        last = "0" if signature[-1] != "0" else "1"
        forged = signature[:-1] + last
        self.assertFalse(IntegrityChecker.verify_hmac(self.KEY, self.DATA, forged, 'sha256'))

    def test_wrong_key(self):
        """A valid signature does not verify under a different key."""
        signature = self._sign()
        self.assertFalse(IntegrityChecker.verify_hmac(b"other key", self.DATA, signature, 'sha256'))
class TestConvenienceFunctions(unittest.TestCase):
    """Test convenience functions."""

    def test_sanitize_filename_convenience(self):
        """Test sanitize_filename convenience function."""
        result = sanitize_filename("path/to/file.txt")
        self.assertEqual(result, "path_to_file.txt")

    def test_validate_path_convenience(self):
        """Test validate_path convenience function."""
        import tempfile
        # TemporaryDirectory guarantees cleanup. resolve() removes
        # platform temp-dir symlinks (e.g. macOS /var -> /private/var) that
        # could otherwise trip the validator's symlink or containment checks.
        with tempfile.TemporaryDirectory() as raw_dir:
            temp_dir = str(Path(raw_dir).resolve())
            result = validate_path(str(Path(temp_dir) / "subdir"), temp_dir)
            self.assertIsInstance(result, Path)

    def test_safe_path_join_convenience(self):
        """Test safe_path_join convenience function."""
        import tempfile
        with tempfile.TemporaryDirectory() as raw_dir:
            temp_dir = str(Path(raw_dir).resolve())
            result = safe_path_join(temp_dir, "subdir", "file.txt")
            self.assertIn("subdir", str(result))
            self.assertIn("file.txt", str(result))
            # Component-wise containment check; a string-prefix check would
            # also accept sibling directories.
            self.assertIn(Path(temp_dir), Path(result).parents)
class TestSecurityError(unittest.TestCase):
    """Basic contract of the ``SecurityError`` exception type."""

    def test_security_error_is_exception(self):
        """SecurityError can be caught as a generic Exception."""
        is_exception_subclass = issubclass(SecurityError, Exception)
        self.assertTrue(is_exception_subclass)

    def test_security_error_message(self):
        """``str()`` of the error returns the message it was built with."""
        message = "Test error message"
        self.assertEqual(str(SecurityError(message)), message)
if __name__ == "__main__":
    # Allow running this module directly, e.g. `python test_security_utils.py`.
    unittest.main()