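# Source-viewer metadata (original listing): 600 lines, 22 KiB, Python.
"""
Unit tests for the security utilities in ``core.security_utils``.

Tests cover:
- PathValidator - filename sanitization and extension splitting
- PathValidator - path traversal prevention and safe path joining
- InputValidator - identifier validation
- InputValidator - string sanitization
- InputValidator - JSON key, region coordinate, and URL endpoint validation
- DataValidator - data structure validation
- IntegrityChecker - hash and HMAC computation and verification
- Convenience functions and the SecurityError exception
"""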
import hashlib
import hmac
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

# Make the project root (three directories above this file) importable so
# that `core.security_utils` resolves when the tests are run directly.
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from core.security_utils import (
    PathValidator, InputValidator, DataValidator, IntegrityChecker,
    SecurityError, sanitize_filename, validate_path, safe_path_join
)
class TestPathValidatorSanitizeFilename(unittest.TestCase):
    """Test PathValidator.sanitize_filename method."""

    def test_sanitize_simple_filename(self):
        """A filename with no unsafe characters is returned unchanged."""
        result = PathValidator.sanitize_filename("test.txt")
        self.assertEqual(result, "test.txt")

    def test_sanitize_with_path_separators(self):
        """Both POSIX ('/') and Windows ('\\') separators become '_'."""
        for raw in ("path/to/file.txt", "path\\to\\file.txt"):
            # subTest reports which separator style failed, if any.
            with self.subTest(raw=raw):
                result = PathValidator.sanitize_filename(raw)
                self.assertEqual(result, "path_to_file.txt")

    def test_sanitize_parent_directory_refs(self):
        """Parent-directory references leave no '..' in the result."""
        result = PathValidator.sanitize_filename("../../../etc/passwd")
        self.assertNotIn("..", result)
        self.assertEqual(result, "_etc_passwd")

    def test_sanitize_dangerous_characters(self):
        """Characters such as ':' and '*' are replaced with '_'."""
        result = PathValidator.sanitize_filename("file:name*.txt")
        self.assertEqual(result, "file_name_.txt")

    def test_sanitize_null_bytes(self):
        """An embedded NUL byte is replaced with '_'."""
        result = PathValidator.sanitize_filename("file\x00name.txt")
        self.assertEqual(result, "file_name.txt")

    def test_sanitize_empty_filename(self):
        """An empty filename yields a name starting with 'file_'."""
        result = PathValidator.sanitize_filename("")
        self.assertTrue(result.startswith("file_"))

    def test_sanitize_dots_only(self):
        """A dots-only filename yields a name starting with 'file_'."""
        result = PathValidator.sanitize_filename("...")
        self.assertTrue(result.startswith("file_"))

    def test_sanitize_long_filename(self):
        """An overlong filename is cut to at most MAX_FILENAME_LENGTH chars."""
        long_name = "a" * 300 + ".txt"
        result = PathValidator.sanitize_filename(long_name)
        self.assertLessEqual(len(result), PathValidator.MAX_FILENAME_LENGTH)

    def test_sanitize_custom_replacement(self):
        """The `replacement` argument is used instead of the default '_'."""
        result = PathValidator.sanitize_filename("file:name", replacement="-")
        self.assertEqual(result, "file-name")
class TestPathValidatorSplitExtension(unittest.TestCase):
    """Test PathValidator._split_extension method."""

    def test_split_simple_extension(self):
        """A single extension is split off, dot included."""
        name, ext = PathValidator._split_extension("file.txt")
        self.assertEqual(name, "file")
        self.assertEqual(ext, ".txt")

    def test_split_multiple_extensions(self):
        """Only the last extension is split off; earlier ones stay in the name."""
        name, ext = PathValidator._split_extension("archive.tar.gz")
        self.assertEqual(name, "archive.tar")
        self.assertEqual(ext, ".gz")

    def test_split_no_extension(self):
        """A name without a dot yields an empty extension."""
        name, ext = PathValidator._split_extension("file")
        self.assertEqual(name, "file")
        self.assertEqual(ext, "")
class TestPathValidatorValidatePathWithinBase(unittest.TestCase):
    """Test PathValidator.validate_path_within_base method."""

    def setUp(self):
        """Create a resolved temporary base directory, removed after each test."""
        tmp = tempfile.TemporaryDirectory()
        # Registered immediately so the directory is removed even if the test
        # (or any later setup) fails.
        self.addCleanup(tmp.cleanup)
        # Resolve up front: if the temp root is reached through a symlink
        # (e.g. /var -> /private/var on macOS), a validator that resolves
        # paths would otherwise return results that do not live under the
        # unresolved base string.
        self.temp_dir = str(Path(tmp.name).resolve())

    def _assert_within(self, path, base):
        """Assert that `path` is `base` itself or lies underneath it."""
        # Component-wise comparison; a string-prefix check would wrongly
        # accept a sibling such as "<base>-other/file.txt".
        self.assertTrue(
            path == base or base in path.parents,
            f"{path} is not inside {base}",
        )

    def test_valid_path_within_base(self):
        """A path nested under the base directory is accepted."""
        base = Path(self.temp_dir)
        path = base / "subdir" / "file.txt"

        result = PathValidator.validate_path_within_base(path, base)

        self.assertIsInstance(result, Path)
        self._assert_within(result, base)

    def test_path_traversal_detection(self):
        """'..' components that climb out of the base are rejected."""
        base = Path(self.temp_dir)
        path = base / ".." / ".." / "etc" / "passwd"

        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(path, base)

        self.assertIn("Path traversal", str(context.exception))

    def test_absolute_path_outside_base(self):
        """An absolute path outside the base is rejected."""
        base = Path(self.temp_dir)
        path = Path("/etc/passwd")

        with self.assertRaises(SecurityError):
            PathValidator.validate_path_within_base(path, base)

    def test_symlink_detection(self):
        """A path through a symlink inside base that points outside is rejected."""
        base = Path(self.temp_dir)
        symlink_path = base / "link"
        try:
            symlink_path.symlink_to(Path("/etc"))
        except (OSError, NotImplementedError):
            # Symlink creation may be unsupported or need privileges
            # (e.g. on Windows).
            self.skipTest("Cannot create symlinks on this system")

        # Only link creation is guarded above, so an OSError raised by the
        # validator itself fails the test instead of silently skipping it.
        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(symlink_path / "passwd", base)

        self.assertIn("Symlink", str(context.exception))

    def test_path_depth_exceeded(self):
        """A path nested too deeply under the base is rejected."""
        base = Path(self.temp_dir)
        # 15 nested components; expected to exceed the validator's depth limit.
        path = base.joinpath(*(f"dir{i}" for i in range(15)))

        with self.assertRaises(SecurityError) as context:
            PathValidator.validate_path_within_base(path, base)

        self.assertIn("Path depth", str(context.exception))
class TestPathValidatorSafeJoin(unittest.TestCase):
    """Test PathValidator.safe_join method."""

    def setUp(self):
        """Create a resolved temporary base directory, removed after each test."""
        tmp = tempfile.TemporaryDirectory()
        # Registered immediately so the directory is removed even on failure.
        self.addCleanup(tmp.cleanup)
        # Resolve up front so a symlinked temp root (e.g. /var -> /private/var
        # on macOS) does not break the containment checks below if safe_join
        # returns resolved paths.
        self.temp_dir = str(Path(tmp.name).resolve())

    def _assert_within(self, path, base):
        """Assert that `path` is `base` itself or lies underneath it."""
        # Component-wise comparison; a string-prefix check would wrongly
        # accept a sibling such as "<base>-other/file.txt".
        self.assertTrue(
            path == base or base in path.parents,
            f"{path} is not inside {base}",
        )

    def test_safe_join_simple(self):
        """Plain components are joined underneath the base directory."""
        base = Path(self.temp_dir)

        result = PathValidator.safe_join(base, "subdir", "file.txt")

        self._assert_within(result, base)
        self.assertIn("subdir", str(result))
        self.assertIn("file.txt", str(result))

    def test_safe_join_sanitizes_components(self):
        """Traversal components are neutralised and the result stays in base."""
        base = Path(self.temp_dir)

        result = PathValidator.safe_join(base, "../etc", "passwd")

        self.assertNotIn("..", str(result))
        self._assert_within(result, base)
class TestInputValidatorIsSafeIdentifier(unittest.TestCase):
    """Test InputValidator.is_safe_identifier method."""

    def test_valid_identifier(self):
        """Ordinary Python identifiers are accepted."""
        for name in ("valid_name", "_private", "Name123"):
            with self.subTest(name=name):
                self.assertTrue(InputValidator.is_safe_identifier(name))

    def test_invalid_identifier_starts_with_number(self):
        """An identifier starting with a digit is rejected."""
        self.assertFalse(InputValidator.is_safe_identifier("123name"))

    def test_invalid_identifier_with_spaces(self):
        """An identifier containing spaces is rejected."""
        self.assertFalse(InputValidator.is_safe_identifier("name with spaces"))

    def test_invalid_identifier_special_chars(self):
        """Identifiers containing '-' or '.' are rejected."""
        for name in ("name-with-dash", "name.with.dots"):
            with self.subTest(name=name):
                self.assertFalse(InputValidator.is_safe_identifier(name))

    def test_python_keyword(self):
        """Python keywords are rejected even though they are valid identifiers."""
        for name in ("class", "def", "import"):
            with self.subTest(name=name):
                self.assertFalse(InputValidator.is_safe_identifier(name))

    def test_empty_identifier(self):
        """The empty string is rejected."""
        self.assertFalse(InputValidator.is_safe_identifier(""))
class TestInputValidatorSanitizeString(unittest.TestCase):
    """Test InputValidator.sanitize_string method."""

    def test_sanitize_simple_string(self):
        """Printable text is returned unchanged."""
        result = InputValidator.sanitize_string("Hello World")
        self.assertEqual(result, "Hello World")

    def test_sanitize_removes_null_bytes(self):
        """NUL bytes are removed (not replaced)."""
        result = InputValidator.sanitize_string("Hello\x00World")
        self.assertEqual(result, "HelloWorld")

    def test_sanitize_removes_control_chars(self):
        """Other control characters are removed (not replaced)."""
        result = InputValidator.sanitize_string("Hello\x01\x02World")
        self.assertEqual(result, "HelloWorld")

    def test_sanitize_keeps_newlines_and_tabs(self):
        """Newlines and tabs are whitelisted and survive sanitization."""
        result = InputValidator.sanitize_string("Line1\nLine2\tTabbed")
        self.assertEqual(result, "Line1\nLine2\tTabbed")

    def test_sanitize_truncates_long_string(self):
        """Input longer than `max_length` is truncated to exactly that length."""
        long_string = "a" * 2000
        result = InputValidator.sanitize_string(long_string, max_length=100)
        self.assertEqual(len(result), 100)

    def test_sanitize_non_string_input(self):
        """Non-string input is converted to its string form."""
        result = InputValidator.sanitize_string(12345)
        self.assertEqual(result, "12345")
class TestInputValidatorValidateJsonKey(unittest.TestCase):
    """Test InputValidator.validate_json_key method."""

    def test_valid_key(self):
        """Ordinary alphanumeric/underscore keys are accepted."""
        for key in ("valid_key", "key123"):
            with self.subTest(key=key):
                self.assertTrue(InputValidator.validate_json_key(key))

    def test_dunder_key_rejected(self):
        """Dunder keys are rejected."""
        for key in ("__class__", "__init__"):
            with self.subTest(key=key):
                self.assertFalse(InputValidator.validate_json_key(key))

    def test_empty_key_rejected(self):
        """The empty key is rejected."""
        self.assertFalse(InputValidator.validate_json_key(""))

    def test_dangerous_patterns_rejected(self):
        """Keys resembling code-execution calls are rejected."""
        for key in ("eval(something)", "exec(", "os.system"):
            with self.subTest(key=key):
                self.assertFalse(InputValidator.validate_json_key(key))

    def test_non_string_key(self):
        """Non-string keys are rejected rather than raising."""
        for key in (123, None):
            with self.subTest(key=key):
                self.assertFalse(InputValidator.validate_json_key(key))
class TestInputValidatorValidateRegionCoordinates(unittest.TestCase):
    """Test InputValidator.validate_region_coordinates method."""

    def test_valid_coordinates(self):
        """Well-formed regions pass without raising."""
        # No assertion needed: the test fails if either call raises.
        InputValidator.validate_region_coordinates(0, 0, 100, 100)
        InputValidator.validate_region_coordinates(100, 200, 500, 300)

    def test_invalid_type(self):
        """A non-integer coordinate raises SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates("0", 0, 100, 100)
        self.assertIn("must be integers", str(context.exception))

    def test_zero_width(self):
        """A zero width raises SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates(0, 0, 0, 100)
        self.assertIn("positive", str(context.exception))

    def test_zero_height(self):
        """A zero height raises SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates(0, 0, 100, 0)
        self.assertIn("positive", str(context.exception))

    def test_dimensions_too_large(self):
        """Width/height above the allowed maximum raise SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates(0, 0, 10000, 10000)
        self.assertIn("exceed maximum", str(context.exception))

    def test_coordinates_out_of_bounds(self):
        """An origin far outside the allowed range raises SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_region_coordinates(-20000, 0, 100, 100)
        self.assertIn("out of reasonable bounds", str(context.exception))
class TestInputValidatorValidateUrlEndpoint(unittest.TestCase):
    """Test InputValidator.validate_url_endpoint method."""

    def test_valid_endpoint(self):
        """A relative slash-separated endpoint is returned unchanged."""
        result = InputValidator.validate_url_endpoint("api/v1/users")
        self.assertEqual(result, "api/v1/users")

    def test_empty_endpoint(self):
        """An empty endpoint raises SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_url_endpoint("")
        self.assertIn("cannot be empty", str(context.exception))

    def test_leading_slash(self):
        """An endpoint starting with '/' raises SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_url_endpoint("/api/users")
        self.assertIn("must not start with /", str(context.exception))

    def test_path_traversal(self):
        """An endpoint containing '..' raises SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_url_endpoint("api/../admin")
        self.assertIn("Path traversal", str(context.exception))

    def test_invalid_characters(self):
        """Query-string characters such as '?' and '=' raise SecurityError."""
        with self.assertRaises(SecurityError) as context:
            InputValidator.validate_url_endpoint("api/users?id=1")
        self.assertIn("Invalid characters", str(context.exception))
class TestDataValidatorValidateDataStructure(unittest.TestCase):
    """Test DataValidator.validate_data_structure method."""

    def test_valid_dict(self):
        """A flat dict of primitives is accepted."""
        data = {"key1": "value1", "key2": 42}
        self.assertTrue(DataValidator.validate_data_structure(data))

    def test_valid_nested_dict(self):
        """A shallowly nested dict is accepted."""
        data = {"outer": {"inner": "value"}}
        self.assertTrue(DataValidator.validate_data_structure(data))

    def test_valid_list(self):
        """A list of mixed primitives is accepted."""
        data = [1, 2, 3, "string"]
        self.assertTrue(DataValidator.validate_data_structure(data))

    def test_empty_dict(self):
        """An empty dict is accepted."""
        self.assertTrue(DataValidator.validate_data_structure({}))

    def test_deep_nesting(self):
        """Nesting deeper than the validator's limit raises SecurityError."""
        # Build {"nested": {"nested": ...}} 15 levels deep.
        data = {}
        current = data
        for _ in range(15):
            current["nested"] = {}
            current = current["nested"]

        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("nesting exceeds", str(context.exception))

    def test_dict_too_large(self):
        """A dict with too many entries raises SecurityError."""
        data = {f"key{i}": i for i in range(15000)}

        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("size exceeds", str(context.exception))

    def test_list_too_large(self):
        """A list with too many items raises SecurityError."""
        data = list(range(15000))

        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("size exceeds", str(context.exception))

    def test_string_too_long(self):
        """A string value above the length limit raises SecurityError."""
        data = {"key": "x" * 200000}

        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("length exceeds", str(context.exception))

    def test_invalid_dunder_key(self):
        """A dunder dictionary key raises SecurityError."""
        data = {"__class__": "value"}

        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("Invalid dictionary key", str(context.exception))

    def test_unsupported_type(self):
        """A value of an arbitrary object type raises SecurityError."""
        data = {"key": MagicMock()}

        with self.assertRaises(SecurityError) as context:
            DataValidator.validate_data_structure(data)
        self.assertIn("Unsupported data type", str(context.exception))

    def test_none_value(self):
        """None on its own is accepted."""
        self.assertTrue(DataValidator.validate_data_structure(None))

    def test_primitive_types(self):
        """Top-level int, float and bool values are accepted."""
        for value in (42, 3.14, True):
            with self.subTest(value=value):
                self.assertTrue(DataValidator.validate_data_structure(value))
class TestIntegrityCheckerComputeHash(unittest.TestCase):
    """Test IntegrityChecker.compute_hash method."""

    def test_sha256(self):
        """SHA-256 output matches hashlib's hex digest."""
        data = b"test data"
        result = IntegrityChecker.compute_hash(data, 'sha256')
        expected = hashlib.sha256(data).hexdigest()
        self.assertEqual(result, expected)

    def test_sha512(self):
        """SHA-512 output matches hashlib's hex digest."""
        data = b"test data"
        result = IntegrityChecker.compute_hash(data, 'sha512')
        expected = hashlib.sha512(data).hexdigest()
        self.assertEqual(result, expected)

    def test_md5(self):
        """MD5 output matches hashlib's hex digest."""
        data = b"test data"
        result = IntegrityChecker.compute_hash(data, 'md5')
        expected = hashlib.md5(data).hexdigest()
        self.assertEqual(result, expected)

    def test_invalid_algorithm(self):
        """An unknown algorithm name raises ValueError."""
        with self.assertRaises(ValueError) as context:
            IntegrityChecker.compute_hash(b"data", 'invalid')
        self.assertIn("Unsupported hash algorithm", str(context.exception))
class TestIntegrityCheckerComputeHmac(unittest.TestCase):
    """Test IntegrityChecker.compute_hmac method."""

    def test_hmac_sha256(self):
        """HMAC-SHA256 output matches the stdlib hmac hex digest."""
        key = b"secret key"
        data = b"test data"
        result = IntegrityChecker.compute_hmac(key, data, 'sha256')
        expected = hmac.new(key, data, hashlib.sha256).hexdigest()
        self.assertEqual(result, expected)

    def test_hmac_sha512(self):
        """HMAC-SHA512 output matches the stdlib hmac hex digest."""
        key = b"secret key"
        data = b"test data"
        result = IntegrityChecker.compute_hmac(key, data, 'sha512')
        expected = hmac.new(key, data, hashlib.sha512).hexdigest()
        self.assertEqual(result, expected)

    def test_hmac_invalid_algorithm(self):
        """An unknown algorithm name raises ValueError."""
        with self.assertRaises(ValueError):
            IntegrityChecker.compute_hmac(b"key", b"data", 'invalid')
class TestIntegrityCheckerVerifyHmac(unittest.TestCase):
    """Test IntegrityChecker.verify_hmac method."""

    def test_valid_hmac(self):
        """A signature produced by compute_hmac verifies successfully."""
        key = b"secret key"
        data = b"test data"
        signature = IntegrityChecker.compute_hmac(key, data, 'sha256')

        self.assertTrue(IntegrityChecker.verify_hmac(key, data, signature, 'sha256'))

    def test_invalid_hmac(self):
        """A malformed signature is rejected (returns False, does not raise)."""
        key = b"secret key"
        data = b"test data"

        self.assertFalse(IntegrityChecker.verify_hmac(key, data, "invalid_signature", 'sha256'))

    def test_tampered_data(self):
        """A valid signature does not verify against modified data."""
        key = b"secret key"
        data = b"test data"
        signature = IntegrityChecker.compute_hmac(key, data, 'sha256')

        tampered_data = b"tampered data"
        self.assertFalse(IntegrityChecker.verify_hmac(key, tampered_data, signature, 'sha256'))
class TestConvenienceFunctions(unittest.TestCase):
    """Test the module-level convenience functions."""

    def test_sanitize_filename_convenience(self):
        """sanitize_filename replaces path separators like the class method."""
        result = sanitize_filename("path/to/file.txt")
        self.assertEqual(result, "path_to_file.txt")

    def test_validate_path_convenience(self):
        """validate_path accepts string paths and returns a Path inside base."""
        # The context manager removes the directory even if an assertion fails.
        with tempfile.TemporaryDirectory() as tmp:
            # Resolved so a symlinked temp root (e.g. /var on macOS) cannot
            # make the base and the validated path disagree.
            base = Path(tmp).resolve()
            # Join with pathlib instead of string concatenation so the
            # separator is correct on every platform.
            result = validate_path(str(base / "subdir"), str(base))
            self.assertIsInstance(result, Path)

    def test_safe_path_join_convenience(self):
        """safe_path_join accepts a string base and joins the components."""
        with tempfile.TemporaryDirectory() as tmp:
            base = Path(tmp).resolve()
            result = safe_path_join(str(base), "subdir", "file.txt")
            self.assertIn("subdir", str(result))
            self.assertIn("file.txt", str(result))
class TestSecurityError(unittest.TestCase):
    """Test the SecurityError exception type."""

    def test_security_error_is_exception(self):
        """SecurityError derives from Exception, so `except Exception` catches it."""
        self.assertTrue(issubclass(SecurityError, Exception))

    def test_security_error_message(self):
        """str() of a SecurityError is the message it was constructed with."""
        error = SecurityError("Test error message")
        self.assertEqual(str(error), "Test error message")
# Allow running this module directly: `python <this file>`.
if __name__ == '__main__':
    unittest.main()