"""
Unit tests for Security Utilities.

Tests cover:
- PathValidator - filename sanitization and extension splitting
- PathValidator - path traversal prevention and safe path joining
- InputValidator - identifier validation
- InputValidator - string sanitization, JSON keys, region coordinates,
  URL endpoints
- DataValidator - data structure validation
- IntegrityChecker - hash and HMAC computation / verification
- Module-level convenience functions and SecurityError
"""

import sys
import unittest
import hashlib
import hmac
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

# Add project root to path so ``core`` is importable when run directly.
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from core.security_utils import (  # noqa: E402  (needs the sys.path tweak above)
    PathValidator,
    InputValidator,
    DataValidator,
    IntegrityChecker,
    SecurityError,
    sanitize_filename,
    validate_path,
    safe_path_join,
)


def _resolved(path):
    """Return *path* fully resolved, as a string.

    Temporary directories may live behind a symlink (e.g. macOS ``/var`` is a
    link to ``/private/var``). Resolving up front keeps containment and
    symlink checks in the code under test from tripping on the temp dir
    itself.
    """
    return str(Path(path).resolve())


class _TempDirTestCase(unittest.TestCase):
    """Base class giving each test a fresh, resolved temp dir in ``self.temp_dir``."""

    def setUp(self):
        """Create the temporary directory and register its removal."""
        tmp = tempfile.TemporaryDirectory()
        # addCleanup runs even if later setUp code fails, unlike tearDown.
        self.addCleanup(tmp.cleanup)
        self.temp_dir = _resolved(tmp.name)

    def assertPathWithin(self, path, base):
        """Assert *path* is *base* or lies beneath it.

        Comparison is component-wise, so a sibling such as ``/tmp/abc_evil``
        is NOT considered inside ``/tmp/abc`` (a plain string-prefix check
        would accept it).
        """
        path, base = Path(path), Path(base)
        self.assertTrue(
            path == base or base in path.parents,
            f"{path} is not within {base}",
        )


class TestPathValidatorSanitizeFilename(unittest.TestCase):
    """Test PathValidator.sanitize_filename method."""

    def test_sanitize_simple_filename(self):
        """Test sanitizing a simple filename."""
        result = PathValidator.sanitize_filename("test.txt")
        self.assertEqual(result, "test.txt")

    def test_sanitize_with_path_separators(self):
        """Test sanitizing filename with path separators."""
        result = PathValidator.sanitize_filename("path/to/file.txt")
        self.assertEqual(result, "path_to_file.txt")

        result = PathValidator.sanitize_filename("path\\to\\file.txt")
        self.assertEqual(result, "path_to_file.txt")

    def test_sanitize_parent_directory_refs(self):
        """Test sanitizing filename with parent directory references."""
        result = PathValidator.sanitize_filename("../../../etc/passwd")
        self.assertNotIn("..", result)
        self.assertEqual(result, "_etc_passwd")

    def test_sanitize_dangerous_characters(self):
        """Test sanitizing filename with dangerous characters."""
        result = PathValidator.sanitize_filename("file:name*.txt")
        self.assertEqual(result, "file_name_.txt")

    def test_sanitize_null_bytes(self):
        """Test sanitizing filename with null bytes."""
        result = PathValidator.sanitize_filename("file\x00name.txt")
        self.assertEqual(result, "file_name.txt")

    def test_sanitize_empty_filename(self):
        """Test sanitizing empty filename."""
        result = PathValidator.sanitize_filename("")
        self.assertTrue(result.startswith("file_"))

    def test_sanitize_dots_only(self):
        """Test sanitizing filename with only dots."""
        result = PathValidator.sanitize_filename("...")
        self.assertTrue(result.startswith("file_"))

    def test_sanitize_long_filename(self):
        """Test sanitizing very long filename."""
        long_name = "a" * 300 + ".txt"
        result = PathValidator.sanitize_filename(long_name)
        self.assertLessEqual(len(result), PathValidator.MAX_FILENAME_LENGTH)

    def test_sanitize_custom_replacement(self):
        """Test sanitizing with custom replacement character."""
        result = PathValidator.sanitize_filename("file:name", replacement="-")
        self.assertEqual(result, "file-name")


class TestPathValidatorSplitExtension(unittest.TestCase):
    """Test PathValidator._split_extension method."""

    def test_split_simple_extension(self):
        """Test splitting simple extension."""
        name, ext = PathValidator._split_extension("file.txt")
        self.assertEqual(name, "file")
        self.assertEqual(ext, ".txt")

    def test_split_multiple_extensions(self):
        """Test splitting multiple extensions."""
        name, ext = PathValidator._split_extension("archive.tar.gz")
        self.assertEqual(name, "archive.tar")
        self.assertEqual(ext, ".gz")

    def test_split_no_extension(self):
        """Test splitting file with no extension."""
        name, ext = PathValidator._split_extension("file")
        self.assertEqual(name, "file")
        self.assertEqual(ext, "")


class TestPathValidatorValidatePathWithinBase(_TempDirTestCase):
    """Test PathValidator.validate_path_within_base method."""

    def test_valid_path_within_base(self):
        """Test validating path within base directory."""
        base = Path(self.temp_dir)
        path = base / "subdir" / "file.txt"

        result = PathValidator.validate_path_within_base(path, base)
        self.assertIsInstance(result, Path)
        self.assertPathWithin(result, base)

    def test_path_traversal_detection(self):
        """Test detecting path traversal attack."""
        base = Path(self.temp_dir)
        path = base / ".." / ".." / "etc" / "passwd"

        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(path, base)
        self.assertIn("Path traversal", str(context.exception))

    def test_absolute_path_outside_base(self):
        """Test absolute path outside base."""
        base = Path(self.temp_dir)
        path = Path("/etc/passwd")

        with self.assertRaises(SecurityError):
            PathValidator.validate_path_within_base(path, base)

    def test_symlink_detection(self):
        """Test detecting symlinks."""
        base = Path(self.temp_dir)
        symlink_path = base / "link"

        # Only symlink *creation* may legitimately be unsupported; errors from
        # the code under test must surface as failures, not skips.
        try:
            symlink_path.symlink_to(Path("/etc"))
        except (OSError, NotImplementedError):
            self.skipTest("Cannot create symlinks on this system")

        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(symlink_path / "passwd", base)
        self.assertIn("Symlink", str(context.exception))

    def test_path_depth_exceeded(self):
        """Test path depth limit."""
        base = Path(self.temp_dir)
        # Build a path that's too deep: base/dir0/dir1/.../dir14
        path = base.joinpath(*(f"dir{i}" for i in range(15)))

        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(path, base)
        self.assertIn("Path depth", str(context.exception))


class TestPathValidatorSafeJoin(_TempDirTestCase):
    """Test PathValidator.safe_join method."""

    def test_safe_join_simple(self):
        """Test safe join with simple paths."""
        base = Path(self.temp_dir)
        result = PathValidator.safe_join(base, "subdir", "file.txt")
        self.assertPathWithin(result, base)
        self.assertIn("subdir", str(result))
        self.assertIn("file.txt", str(result))

    def test_safe_join_sanitizes_components(self):
        """Test that safe_join sanitizes path components."""
        base = Path(self.temp_dir)
        result = PathValidator.safe_join(base, "../etc", "passwd")
        self.assertNotIn("..", str(result))
        self.assertPathWithin(result, base)


class TestInputValidatorIsSafeIdentifier(unittest.TestCase):
    """Test InputValidator.is_safe_identifier method."""

    def test_valid_identifier(self):
        """Test valid Python identifiers."""
        self.assertTrue(InputValidator.is_safe_identifier("valid_name"))
        self.assertTrue(InputValidator.is_safe_identifier("_private"))
        self.assertTrue(InputValidator.is_safe_identifier("Name123"))

    def test_invalid_identifier_starts_with_number(self):
        """Test identifier starting with number."""
        self.assertFalse(InputValidator.is_safe_identifier("123name"))

    def test_invalid_identifier_with_spaces(self):
        """Test identifier with spaces."""
        self.assertFalse(InputValidator.is_safe_identifier("name with spaces"))

    def test_invalid_identifier_special_chars(self):
        """Test identifier with special characters."""
        self.assertFalse(InputValidator.is_safe_identifier("name-with-dash"))
        self.assertFalse(InputValidator.is_safe_identifier("name.with.dots"))

    def test_python_keyword(self):
        """Test Python keywords are rejected."""
        self.assertFalse(InputValidator.is_safe_identifier("class"))
        self.assertFalse(InputValidator.is_safe_identifier("def"))
        self.assertFalse(InputValidator.is_safe_identifier("import"))

    def test_empty_identifier(self):
        """Test empty identifier."""
        self.assertFalse(InputValidator.is_safe_identifier(""))


class TestInputValidatorSanitizeString(unittest.TestCase):
    """Test InputValidator.sanitize_string method."""

    def test_sanitize_simple_string(self):
        """Test sanitizing simple string."""
        result = InputValidator.sanitize_string("Hello World")
        self.assertEqual(result, "Hello World")

    def test_sanitize_removes_null_bytes(self):
        """Test removing null bytes."""
        result = InputValidator.sanitize_string("Hello\x00World")
        self.assertEqual(result, "HelloWorld")

    def test_sanitize_removes_control_chars(self):
        """Test removing control characters."""
        result = InputValidator.sanitize_string("Hello\x01\x02World")
        self.assertEqual(result, "HelloWorld")

    def test_sanitize_keeps_newlines_and_tabs(self):
        """Test that newlines and tabs are preserved."""
        result = InputValidator.sanitize_string("Line1\nLine2\tTabbed")
        self.assertEqual(result, "Line1\nLine2\tTabbed")

    def test_sanitize_truncates_long_string(self):
        """Test truncating long strings."""
        long_string = "a" * 2000
        result = InputValidator.sanitize_string(long_string, max_length=100)
        self.assertEqual(len(result), 100)

    def test_sanitize_non_string_input(self):
        """Test sanitizing non-string input."""
        result = InputValidator.sanitize_string(12345)
        self.assertEqual(result, "12345")


class TestInputValidatorValidateJsonKey(unittest.TestCase):
    """Test InputValidator.validate_json_key method."""

    def test_valid_key(self):
        """Test valid JSON keys."""
        self.assertTrue(InputValidator.validate_json_key("valid_key"))
        self.assertTrue(InputValidator.validate_json_key("key123"))

    def test_dunder_key_rejected(self):
        """Test dunder keys are rejected."""
        self.assertFalse(InputValidator.validate_json_key("__class__"))
        self.assertFalse(InputValidator.validate_json_key("__init__"))

    def test_empty_key_rejected(self):
        """Test empty keys are rejected."""
        self.assertFalse(InputValidator.validate_json_key(""))

    def test_dangerous_patterns_rejected(self):
        """Test dangerous patterns are rejected."""
        self.assertFalse(InputValidator.validate_json_key("eval(something)"))
        self.assertFalse(InputValidator.validate_json_key("exec("))
        self.assertFalse(InputValidator.validate_json_key("os.system"))

    def test_non_string_key(self):
        """Test non-string keys are rejected."""
        self.assertFalse(InputValidator.validate_json_key(123))
        self.assertFalse(InputValidator.validate_json_key(None))


class TestInputValidatorValidateRegionCoordinates(unittest.TestCase):
    """Test InputValidator.validate_region_coordinates method."""

    def test_valid_coordinates(self):
        """Test valid region coordinates."""
        # Should not raise
        InputValidator.validate_region_coordinates(0, 0, 100, 100)
        InputValidator.validate_region_coordinates(100, 200, 500, 300)

    def test_invalid_type(self):
        """Test invalid coordinate types."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates("0", 0, 100, 100)
        self.assertIn("must be integers", str(context.exception))

    def test_zero_width(self):
        """Test zero width."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates(0, 0, 0, 100)
        self.assertIn("positive", str(context.exception))

    def test_zero_height(self):
        """Test zero height."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates(0, 0, 100, 0)
        self.assertIn("positive", str(context.exception))

    def test_dimensions_too_large(self):
        """Test dimensions exceeding maximum."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates(0, 0, 10000, 10000)
        self.assertIn("exceed maximum", str(context.exception))

    def test_coordinates_out_of_bounds(self):
        """Test coordinates out of reasonable bounds."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates(-20000, 0, 100, 100)
        self.assertIn("out of reasonable bounds", str(context.exception))


class TestInputValidatorValidateUrlEndpoint(unittest.TestCase):
    """Test InputValidator.validate_url_endpoint method."""

    def test_valid_endpoint(self):
        """Test valid endpoints."""
        result = InputValidator.validate_url_endpoint("api/v1/users")
        self.assertEqual(result, "api/v1/users")

    def test_empty_endpoint(self):
        """Test empty endpoint."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_url_endpoint("")
        self.assertIn("cannot be empty", str(context.exception))

    def test_leading_slash(self):
        """Test endpoint with leading slash."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_url_endpoint("/api/users")
        self.assertIn("must not start with /", str(context.exception))

    def test_path_traversal(self):
        """Test endpoint with path traversal."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_url_endpoint("api/../admin")
        self.assertIn("Path traversal", str(context.exception))

    def test_invalid_characters(self):
        """Test endpoint with invalid characters."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_url_endpoint("api/users?id=1")
        self.assertIn("Invalid characters", str(context.exception))


class TestDataValidatorValidateDataStructure(unittest.TestCase):
    """Test DataValidator.validate_data_structure method."""

    def test_valid_dict(self):
        """Test valid dictionary."""
        data = {"key1": "value1", "key2": 42}
        self.assertTrue(DataValidator.validate_data_structure(data))

    def test_valid_nested_dict(self):
        """Test valid nested dictionary."""
        data = {"outer": {"inner": "value"}}
        self.assertTrue(DataValidator.validate_data_structure(data))

    def test_valid_list(self):
        """Test valid list."""
        data = [1, 2, 3, "string"]
        self.assertTrue(DataValidator.validate_data_structure(data))

    def test_empty_dict(self):
        """Test empty dictionary."""
        self.assertTrue(DataValidator.validate_data_structure({}))

    def test_deep_nesting(self):
        """Test deeply nested structure."""
        # Create structure exceeding max depth: {"nested": {"nested": ...}}
        data = {}
        current = data
        for _ in range(15):
            current["nested"] = {}
            current = current["nested"]

        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("nesting exceeds", str(context.exception))

    def test_dict_too_large(self):
        """Test dictionary exceeding max size."""
        data = {f"key{i}": i for i in range(15000)}
        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("size exceeds", str(context.exception))

    def test_list_too_large(self):
        """Test list exceeding max size."""
        data = list(range(15000))
        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("size exceeds", str(context.exception))

    def test_string_too_long(self):
        """Test string exceeding max length."""
        data = {"key": "x" * 200000}
        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("length exceeds", str(context.exception))

    def test_invalid_dunder_key(self):
        """Test dictionary with dunder key."""
        data = {"__class__": "value"}
        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("Invalid dictionary key", str(context.exception))

    def test_unsupported_type(self):
        """Test unsupported data type."""
        data = {"key": MagicMock()}
        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("Unsupported data type", str(context.exception))

    def test_none_value(self):
        """Test None value."""
        self.assertTrue(DataValidator.validate_data_structure(None))

    def test_primitive_types(self):
        """Test primitive types."""
        self.assertTrue(DataValidator.validate_data_structure(42))
        self.assertTrue(DataValidator.validate_data_structure(3.14))
        self.assertTrue(DataValidator.validate_data_structure(True))


class TestIntegrityCheckerComputeHash(unittest.TestCase):
    """Test IntegrityChecker.compute_hash method."""

    def test_sha256(self):
        """Test SHA256 hash computation."""
        data = b"test data"
        result = IntegrityChecker.compute_hash(data, 'sha256')
        expected = hashlib.sha256(data).hexdigest()
        self.assertEqual(result, expected)

    def test_sha512(self):
        """Test SHA512 hash computation."""
        data = b"test data"
        result = IntegrityChecker.compute_hash(data, 'sha512')
        expected = hashlib.sha512(data).hexdigest()
        self.assertEqual(result, expected)

    def test_md5(self):
        """Test MD5 hash computation."""
        data = b"test data"
        result = IntegrityChecker.compute_hash(data, 'md5')
        expected = hashlib.md5(data).hexdigest()
        self.assertEqual(result, expected)

    def test_invalid_algorithm(self):
        """Test invalid hash algorithm."""
        with self.assertRaises(ValueError) as context:
            IntegrityChecker.compute_hash(b"data", 'invalid')
        self.assertIn("Unsupported hash algorithm", str(context.exception))


class TestIntegrityCheckerComputeHmac(unittest.TestCase):
    """Test IntegrityChecker.compute_hmac method."""

    def test_hmac_sha256(self):
        """Test HMAC SHA256 computation."""
        key = b"secret key"
        data = b"test data"
        result = IntegrityChecker.compute_hmac(key, data, 'sha256')
        expected = hmac.new(key, data, hashlib.sha256).hexdigest()
        self.assertEqual(result, expected)

    def test_hmac_sha512(self):
        """Test HMAC SHA512 computation."""
        key = b"secret key"
        data = b"test data"
        result = IntegrityChecker.compute_hmac(key, data, 'sha512')
        expected = hmac.new(key, data, hashlib.sha512).hexdigest()
        self.assertEqual(result, expected)

    def test_hmac_invalid_algorithm(self):
        """Test invalid HMAC algorithm."""
        with self.assertRaises(ValueError):
            IntegrityChecker.compute_hmac(b"key", b"data", 'invalid')


class TestIntegrityCheckerVerifyHmac(unittest.TestCase):
    """Test IntegrityChecker.verify_hmac method."""

    def test_valid_hmac(self):
        """Test verifying valid HMAC."""
        key = b"secret key"
        data = b"test data"
        signature = IntegrityChecker.compute_hmac(key, data, 'sha256')
        self.assertTrue(IntegrityChecker.verify_hmac(key, data, signature, 'sha256'))

    def test_invalid_hmac(self):
        """Test verifying invalid HMAC."""
        key = b"secret key"
        data = b"test data"
        self.assertFalse(
            IntegrityChecker.verify_hmac(key, data, "invalid_signature", 'sha256')
        )

    def test_tampered_data(self):
        """Test HMAC verification with tampered data."""
        key = b"secret key"
        data = b"test data"
        signature = IntegrityChecker.compute_hmac(key, data, 'sha256')
        tampered_data = b"tampered data"
        self.assertFalse(
            IntegrityChecker.verify_hmac(key, tampered_data, signature, 'sha256')
        )


class TestConvenienceFunctions(unittest.TestCase):
    """Test convenience functions."""

    def test_sanitize_filename_convenience(self):
        """Test sanitize_filename convenience function."""
        result = sanitize_filename("path/to/file.txt")
        self.assertEqual(result, "path_to_file.txt")

    def test_validate_path_convenience(self):
        """Test validate_path convenience function."""
        with tempfile.TemporaryDirectory() as raw_dir:
            temp_dir = _resolved(raw_dir)
            # Pass plain strings, as the original call site did.
            result = validate_path(str(Path(temp_dir) / "subdir"), temp_dir)
            self.assertIsInstance(result, Path)

    def test_safe_path_join_convenience(self):
        """Test safe_path_join convenience function."""
        with tempfile.TemporaryDirectory() as raw_dir:
            temp_dir = _resolved(raw_dir)
            result = safe_path_join(temp_dir, "subdir", "file.txt")
            self.assertIn("subdir", str(result))
            self.assertIn("file.txt", str(result))


class TestSecurityError(unittest.TestCase):
    """Test SecurityError exception."""

    def test_security_error_is_exception(self):
        """Test that SecurityError is an Exception."""
        self.assertTrue(issubclass(SecurityError, Exception))

    def test_security_error_message(self):
        """Test SecurityError message."""
        error = SecurityError("Test error message")
        self.assertEqual(str(error), "Test error message")


if __name__ == '__main__':
    unittest.main()