""" EU-Utility - Security Utilities Common security functions for input validation, path sanitization, and secure data handling across the application. """ import re import hashlib import hmac from pathlib import Path from typing import Optional, Union, Any class SecurityError(Exception): """Raised when a security violation is detected.""" pass class PathValidator: """Validates and sanitizes file paths to prevent traversal attacks.""" # Maximum allowed filename length MAX_FILENAME_LENGTH = 255 # Maximum path depth MAX_PATH_DEPTH = 10 @classmethod def sanitize_filename(cls, filename: str, replacement: str = '_') -> str: """ Sanitize a filename to remove dangerous characters. Args: filename: The filename to sanitize replacement: Character to replace dangerous chars with Returns: Sanitized filename """ if not filename: return f"file_{hashlib.md5(str(id(filename)).encode()).hexdigest()[:8]}" # Remove null bytes filename = filename.replace('\x00', replacement) # Replace path separators filename = filename.replace('/', replacement).replace('\\', replacement) # Remove parent directory references filename = filename.replace('..', replacement) # Remove dangerous characters dangerous = ['\x00', '\n', '\r', ':', '*', '?', '"', '<', '>', '|'] for char in dangerous: filename = filename.replace(char, replacement) # Limit length if len(filename) > cls.MAX_FILENAME_LENGTH: name, ext = cls._split_extension(filename) filename = name[:cls.MAX_FILENAME_LENGTH - len(ext)] + ext # Ensure filename is not empty or just dots filename = filename.strip('. ') if not filename: filename = f"file_{hashlib.md5(str(id(filename)).encode()).hexdigest()[:8]}" return filename @classmethod def _split_extension(cls, filename: str) -> tuple: """Split filename into name and extension.""" # Handle multiple extensions like .tar.gz parts = filename.split('.') if len(parts) > 1: return '.'.join(parts[:-1]), '.' + parts[-1] return filename, '' @classmethod def validate_path_within_base( cls, path: Union[str, Path], base_path: Union[str, Path], allow_symlinks: bool = False ) -> Path: """ Validate that a path is within a base directory. Args: path: The path to validate base_path: The allowed base directory allow_symlinks: Whether to allow symlinks (default: False) Returns: Resolved Path object Raises: SecurityError: If path traversal is detected """ try: path = Path(path) base_path = Path(base_path).resolve() # Resolve the path if allow_symlinks: resolved = path.resolve() else: # Check for symlinks before resolving for part in path.parts: if part == '..': continue check_path = base_path / part if check_path.is_symlink(): raise SecurityError(f"Symlink detected: {check_path}") resolved = path.resolve() # Check path depth try: relative = resolved.relative_to(base_path) if len(relative.parts) > cls.MAX_PATH_DEPTH: raise SecurityError(f"Path depth exceeds maximum: {len(relative.parts)}") except ValueError: raise SecurityError( f"Path traversal detected: {resolved} is outside {base_path}" ) return resolved except (OSError, ValueError) as e: raise SecurityError(f"Invalid path: {e}") from e @classmethod def safe_join(cls, base: Union[str, Path], *paths: str) -> Path: """ Safely join paths and validate result is within base. Args: base: Base directory *paths: Path components to join Returns: Validated Path object """ base = Path(base) # Sanitize each path component sanitized = [] for path in paths: # Remove dangerous components path = cls.sanitize_filename(path, '_') sanitized.append(path) result = base.joinpath(*sanitized) return cls.validate_path_within_base(result, base) class InputValidator: """Validates user inputs to prevent injection attacks.""" # Common dangerous patterns DANGEROUS_PATTERNS = [ r'__\w+__', # Dunder methods r'eval\s*\(', r'exec\s*\(', r'compile\s*\(', r'__import__\s*\(', r'os\.system\s*\(', r'subprocess\.', r'\.popen\s*\(', ] @classmethod def is_safe_identifier(cls, name: str) -> bool: """ Check if a string is a safe Python identifier. Args: name: String to check Returns: True if safe identifier """ if not name: return False # Must match Python identifier rules if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): return False # Must not be a keyword import keyword if keyword.iskeyword(name): return False return True @classmethod def sanitize_string(cls, text: str, max_length: int = 1000) -> str: """ Sanitize a string for safe use. Args: text: String to sanitize max_length: Maximum allowed length Returns: Sanitized string """ if not isinstance(text, str): text = str(text) # Remove null bytes text = text.replace('\x00', '') # Remove control characters except newlines and tabs text = ''.join( char for char in text if char == '\n' or char == '\t' or ord(char) >= 32 ) # Limit length if len(text) > max_length: text = text[:max_length] return text @classmethod def validate_json_key(cls, key: str) -> bool: """ Validate that a JSON key is safe. Args: key: Key to validate Returns: True if safe """ if not isinstance(key, str): return False # Reject keys starting with __ if key.startswith('__'): return False # Reject empty keys if not key: return False # Check for dangerous patterns for pattern in cls.DANGEROUS_PATTERNS: if re.search(pattern, key, re.IGNORECASE): return False return True @classmethod def validate_region_coordinates( cls, x: int, y: int, width: int, height: int, max_width: int = 7680, max_height: int = 4320 ) -> None: """ Validate screen region coordinates. Args: x: X coordinate y: Y coordinate width: Region width height: Region height max_width: Maximum allowed width (default 8K) max_height: Maximum allowed height (default 8K) Raises: SecurityError: If coordinates are invalid """ # Check types if not all(isinstance(v, int) for v in [x, y, width, height]): raise SecurityError("Region coordinates must be integers") # Check positive dimensions if width <= 0 or height <= 0: raise SecurityError("Region width and height must be positive") # Check maximum dimensions if width > max_width or height > max_height: raise SecurityError( f"Region dimensions exceed maximum ({max_width}x{max_height})" ) # Check reasonable bounds if x < -10000 or y < -10000: raise SecurityError("Region coordinates out of reasonable bounds") if x > 10000 or y > 10000: raise SecurityError("Region coordinates out of reasonable bounds") @classmethod def validate_url_endpoint(cls, endpoint: str) -> str: """ Validate a URL endpoint path. Args: endpoint: Endpoint path to validate Returns: Validated endpoint Raises: SecurityError: If endpoint is invalid """ if not endpoint: raise SecurityError("Endpoint cannot be empty") # Must not start with / if endpoint.startswith('/'): raise SecurityError("Endpoint must not start with /") # Must not contain path traversal if '..' in endpoint: raise SecurityError("Path traversal detected in endpoint") # Must match allowed characters if not re.match(r'^[a-zA-Z0-9_/-]+$', endpoint): raise SecurityError(f"Invalid characters in endpoint: {endpoint}") return endpoint class DataValidator: """Validates data structures for security issues.""" MAX_NESTING_DEPTH = 10 MAX_COLLECTION_SIZE = 10000 MAX_STRING_LENGTH = 100000 @classmethod def validate_data_structure(cls, data: Any, depth: int = 0) -> bool: """ Recursively validate a data structure for security issues. Args: data: Data to validate depth: Current nesting depth Returns: True if valid Raises: SecurityError: If data structure is dangerous """ # Check nesting depth if depth > cls.MAX_NESTING_DEPTH: raise SecurityError(f"Data nesting exceeds maximum depth: {depth}") if isinstance(data, dict): if len(data) > cls.MAX_COLLECTION_SIZE: raise SecurityError(f"Dictionary size exceeds maximum: {len(data)}") for key, value in data.items(): # Validate key if not InputValidator.validate_json_key(str(key)): raise SecurityError(f"Invalid dictionary key: {key}") # Recursively validate value cls.validate_data_structure(value, depth + 1) elif isinstance(data, (list, tuple, set)): if len(data) > cls.MAX_COLLECTION_SIZE: raise SecurityError(f"Collection size exceeds maximum: {len(data)}") for item in data: cls.validate_data_structure(item, depth + 1) elif isinstance(data, str): if len(data) > cls.MAX_STRING_LENGTH: raise SecurityError(f"String length exceeds maximum: {len(data)}") elif isinstance(data, (int, float, bool)): pass # Primitives are fine elif data is None: pass # None is fine else: raise SecurityError(f"Unsupported data type: {type(data)}") return True class IntegrityChecker: """Provides data integrity verification.""" @staticmethod def compute_hash(data: bytes, algorithm: str = 'sha256') -> str: """ Compute hash of data. Args: data: Data to hash algorithm: Hash algorithm to use Returns: Hex digest of hash """ if algorithm == 'sha256': return hashlib.sha256(data).hexdigest() elif algorithm == 'sha512': return hashlib.sha512(data).hexdigest() elif algorithm == 'md5': return hashlib.md5(data).hexdigest() else: raise ValueError(f"Unsupported hash algorithm: {algorithm}") @staticmethod def compute_hmac(key: bytes, data: bytes, algorithm: str = 'sha256') -> str: """ Compute HMAC of data. Args: key: HMAC key data: Data to sign algorithm: Hash algorithm to use Returns: Hex digest of HMAC """ if algorithm == 'sha256': return hmac.new(key, data, hashlib.sha256).hexdigest() elif algorithm == 'sha512': return hmac.new(key, data, hashlib.sha512).hexdigest() else: raise ValueError(f"Unsupported HMAC algorithm: {algorithm}") @staticmethod def verify_hmac(key: bytes, data: bytes, signature: str, algorithm: str = 'sha256') -> bool: """ Verify HMAC signature. Args: key: HMAC key data: Data that was signed signature: Expected signature algorithm: Hash algorithm used Returns: True if signature is valid """ expected = IntegrityChecker.compute_hmac(key, data, algorithm) return hmac.compare_digest(expected, signature) # Convenience functions def sanitize_filename(filename: str, replacement: str = '_') -> str: """Sanitize a filename.""" return PathValidator.sanitize_filename(filename, replacement) def validate_path(path: Union[str, Path], base: Union[str, Path]) -> Path: """Validate a path is within a base directory.""" return PathValidator.validate_path_within_base(path, base) def safe_path_join(base: Union[str, Path], *paths: str) -> Path: """Safely join paths.""" return PathValidator.safe_join(base, *paths)