# EU-Utility/core/security_utils.py
# (source-listing metadata: 470 lines, 14 KiB, Python)

"""
EU-Utility - Security Utilities
Common security functions for input validation, path sanitization,
and secure data handling across the application.
"""
import re
import hashlib
import hmac
from pathlib import Path
from typing import Optional, Union, Any
class SecurityError(Exception):
"""Raised when a security violation is detected."""
pass
class PathValidator:
"""Validates and sanitizes file paths to prevent traversal attacks."""
# Maximum allowed filename length
MAX_FILENAME_LENGTH = 255
# Maximum path depth
MAX_PATH_DEPTH = 10
@classmethod
def sanitize_filename(cls, filename: str, replacement: str = '_') -> str:
"""
Sanitize a filename to remove dangerous characters.
Args:
filename: The filename to sanitize
replacement: Character to replace dangerous chars with
Returns:
Sanitized filename
"""
if not filename:
return f"file_{hashlib.md5(str(id(filename)).encode()).hexdigest()[:8]}"
# Remove null bytes
filename = filename.replace('\x00', replacement)
# Replace path separators
filename = filename.replace('/', replacement).replace('\\', replacement)
# Remove parent directory references
filename = filename.replace('..', replacement)
# Remove dangerous characters
dangerous = ['\x00', '\n', '\r', ':', '*', '?', '"', '<', '>', '|']
for char in dangerous:
filename = filename.replace(char, replacement)
# Limit length
if len(filename) > cls.MAX_FILENAME_LENGTH:
name, ext = cls._split_extension(filename)
filename = name[:cls.MAX_FILENAME_LENGTH - len(ext)] + ext
# Ensure filename is not empty or just dots
filename = filename.strip('. ')
if not filename:
filename = f"file_{hashlib.md5(str(id(filename)).encode()).hexdigest()[:8]}"
return filename
@classmethod
def _split_extension(cls, filename: str) -> tuple:
"""Split filename into name and extension."""
# Handle multiple extensions like .tar.gz
parts = filename.split('.')
if len(parts) > 1:
return '.'.join(parts[:-1]), '.' + parts[-1]
return filename, ''
@classmethod
def validate_path_within_base(
cls,
path: Union[str, Path],
base_path: Union[str, Path],
allow_symlinks: bool = False
) -> Path:
"""
Validate that a path is within a base directory.
Args:
path: The path to validate
base_path: The allowed base directory
allow_symlinks: Whether to allow symlinks (default: False)
Returns:
Resolved Path object
Raises:
SecurityError: If path traversal is detected
"""
try:
path = Path(path)
base_path = Path(base_path).resolve()
# Resolve the path
if allow_symlinks:
resolved = path.resolve()
else:
# Check for symlinks before resolving
for part in path.parts:
if part == '..':
continue
check_path = base_path / part
if check_path.is_symlink():
raise SecurityError(f"Symlink detected: {check_path}")
resolved = path.resolve()
# Check path depth
try:
relative = resolved.relative_to(base_path)
if len(relative.parts) > cls.MAX_PATH_DEPTH:
raise SecurityError(f"Path depth exceeds maximum: {len(relative.parts)}")
except ValueError:
raise SecurityError(
f"Path traversal detected: {resolved} is outside {base_path}"
)
return resolved
except (OSError, ValueError) as e:
raise SecurityError(f"Invalid path: {e}") from e
@classmethod
def safe_join(cls, base: Union[str, Path], *paths: str) -> Path:
"""
Safely join paths and validate result is within base.
Args:
base: Base directory
*paths: Path components to join
Returns:
Validated Path object
"""
base = Path(base)
# Sanitize each path component
sanitized = []
for path in paths:
# Remove dangerous components
path = cls.sanitize_filename(path, '_')
sanitized.append(path)
result = base.joinpath(*sanitized)
return cls.validate_path_within_base(result, base)
class InputValidator:
    """Validates user inputs to prevent injection attacks."""

    # Code-like constructs rejected in keys by validate_json_key
    # (searched case-insensitively anywhere in the key).
    DANGEROUS_PATTERNS = [
        r'__\w+__',  # Dunder methods
        r'eval\s*\(',
        r'exec\s*\(',
        r'compile\s*\(',
        r'__import__\s*\(',
        r'os\.system\s*\(',
        r'subprocess\.',
        r'\.popen\s*\(',
    ]

    @classmethod
    def is_safe_identifier(cls, name: str) -> bool:
        """
        Check if a string is a safe Python identifier.

        Only ASCII letters, digits and underscores are accepted, the first
        character must not be a digit, and Python keywords are rejected.

        Args:
            name: String to check

        Returns:
            True if safe identifier
        """
        if not name:
            return False
        # fullmatch: with re.match, ``$`` would also accept a trailing "\n".
        if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', name):
            return False
        import keyword
        return not keyword.iskeyword(name)

    @classmethod
    def sanitize_string(cls, text: str, max_length: int = 1000) -> str:
        """
        Sanitize a string for safe use.

        Non-strings are converted with ``str()``. Characters below U+0020 are
        dropped except newline and tab (so ``\\r`` and NUL are removed), then
        the result is truncated to ``max_length``.

        Args:
            text: String to sanitize
            max_length: Maximum allowed length

        Returns:
            Sanitized string
        """
        if not isinstance(text, str):
            text = str(text)
        # Remove null bytes
        text = text.replace('\x00', '')
        # Remove control characters except newlines and tabs
        text = ''.join(
            char for char in text
            if char == '\n' or char == '\t' or ord(char) >= 32
        )
        if len(text) > max_length:
            text = text[:max_length]
        return text

    @classmethod
    def validate_json_key(cls, key: str) -> bool:
        """
        Validate that a JSON key is safe.

        Args:
            key: Key to validate

        Returns:
            True if ``key`` is a non-empty str that does not start with
            ``__`` and matches none of DANGEROUS_PATTERNS
        """
        if not isinstance(key, str):
            return False
        if key.startswith('__'):
            return False
        if not key:
            return False
        for pattern in cls.DANGEROUS_PATTERNS:
            if re.search(pattern, key, re.IGNORECASE):
                return False
        return True

    @classmethod
    def validate_region_coordinates(
        cls,
        x: int,
        y: int,
        width: int,
        height: int,
        max_width: int = 7680,
        max_height: int = 4320
    ) -> None:
        """
        Validate screen region coordinates.

        Args:
            x: X coordinate (must lie in [-10000, 10000])
            y: Y coordinate (must lie in [-10000, 10000])
            width: Region width
            height: Region height
            max_width: Maximum allowed width (default 7680, 8K UHD)
            max_height: Maximum allowed height (default 4320, 8K UHD)

        Raises:
            SecurityError: If coordinates are invalid
        """
        # bool is an int subclass; True/False are not meaningful coordinates.
        if not all(
            isinstance(v, int) and not isinstance(v, bool)
            for v in (x, y, width, height)
        ):
            raise SecurityError("Region coordinates must be integers")
        if width <= 0 or height <= 0:
            raise SecurityError("Region width and height must be positive")
        if width > max_width or height > max_height:
            raise SecurityError(
                f"Region dimensions exceed maximum ({max_width}x{max_height})"
            )
        if x < -10000 or y < -10000:
            raise SecurityError("Region coordinates out of reasonable bounds")
        if x > 10000 or y > 10000:
            raise SecurityError("Region coordinates out of reasonable bounds")

    @classmethod
    def validate_url_endpoint(cls, endpoint: str) -> str:
        """
        Validate a URL endpoint path.

        Args:
            endpoint: Relative endpoint path, e.g. ``api/v1/items``

        Returns:
            The endpoint unchanged

        Raises:
            SecurityError: If endpoint is empty, starts with ``/``, contains
                ``..``, or has characters outside ``[a-zA-Z0-9_/-]``
        """
        if not endpoint:
            raise SecurityError("Endpoint cannot be empty")
        if endpoint.startswith('/'):
            raise SecurityError("Endpoint must not start with /")
        if '..' in endpoint:
            raise SecurityError("Path traversal detected in endpoint")
        # fullmatch so a trailing newline (header/request injection) is rejected.
        if not re.fullmatch(r'[a-zA-Z0-9_/-]+', endpoint):
            raise SecurityError(f"Invalid characters in endpoint: {endpoint}")
        return endpoint
class DataValidator:
    """Validates data structures for security issues."""

    # Deepest container nesting accepted (the top-level value is depth 0).
    MAX_NESTING_DEPTH = 10
    # Maximum number of entries in any single dict/list/tuple/set.
    MAX_COLLECTION_SIZE = 10000
    # Maximum length of any string value or (stringified) dictionary key.
    MAX_STRING_LENGTH = 100000

    @classmethod
    def validate_data_structure(cls, data: Any, depth: int = 0) -> bool:
        """
        Recursively validate a data structure for security issues.

        Accepts dicts, lists, tuples, sets, str, int, float, bool and None.
        Dict keys are stringified, length-checked and passed to
        ``InputValidator.validate_json_key``.

        Args:
            data: Data to validate
            depth: Current nesting depth

        Returns:
            True if valid

        Raises:
            SecurityError: If data structure is dangerous
        """
        if depth > cls.MAX_NESTING_DEPTH:
            raise SecurityError(f"Data nesting exceeds maximum depth: {depth}")

        if isinstance(data, dict):
            if len(data) > cls.MAX_COLLECTION_SIZE:
                raise SecurityError(f"Dictionary size exceeds maximum: {len(data)}")
            for key, value in data.items():
                key_text = str(key)
                # Keys are never revisited by the recursion, so the string
                # limit must be enforced here or oversized keys slip through.
                if len(key_text) > cls.MAX_STRING_LENGTH:
                    raise SecurityError(
                        f"Dictionary key length exceeds maximum: {len(key_text)}"
                    )
                if not InputValidator.validate_json_key(key_text):
                    raise SecurityError(f"Invalid dictionary key: {key}")
                cls.validate_data_structure(value, depth + 1)
        elif isinstance(data, (list, tuple, set)):
            if len(data) > cls.MAX_COLLECTION_SIZE:
                raise SecurityError(f"Collection size exceeds maximum: {len(data)}")
            for item in data:
                cls.validate_data_structure(item, depth + 1)
        elif isinstance(data, str):
            if len(data) > cls.MAX_STRING_LENGTH:
                raise SecurityError(f"String length exceeds maximum: {len(data)}")
        elif isinstance(data, (int, float, bool)) or data is None:
            pass  # Primitives and None are fine
        else:
            raise SecurityError(f"Unsupported data type: {type(data)}")

        return True
class IntegrityChecker:
    """Provides data integrity verification."""

    @staticmethod
    def compute_hash(data: bytes, algorithm: str = 'sha256') -> str:
        """
        Compute hash of data.

        Args:
            data: Data to hash
            algorithm: One of 'sha256', 'sha512' or 'md5'

        Returns:
            Hex digest of hash

        Raises:
            ValueError: If the algorithm is not supported
        """
        if algorithm == 'sha256':
            return hashlib.sha256(data).hexdigest()
        elif algorithm == 'sha512':
            return hashlib.sha512(data).hexdigest()
        elif algorithm == 'md5':
            # NOTE: md5 is not collision-resistant; use only for checksums.
            return hashlib.md5(data).hexdigest()
        else:
            raise ValueError(f"Unsupported hash algorithm: {algorithm}")

    @staticmethod
    def compute_hmac(key: bytes, data: bytes, algorithm: str = 'sha256') -> str:
        """
        Compute HMAC of data.

        Args:
            key: HMAC key
            data: Data to sign
            algorithm: 'sha256' or 'sha512'

        Returns:
            Hex digest of HMAC

        Raises:
            ValueError: If the algorithm is not supported
        """
        if algorithm == 'sha256':
            return hmac.new(key, data, hashlib.sha256).hexdigest()
        elif algorithm == 'sha512':
            return hmac.new(key, data, hashlib.sha512).hexdigest()
        else:
            raise ValueError(f"Unsupported HMAC algorithm: {algorithm}")

    @staticmethod
    def verify_hmac(key: bytes, data: bytes, signature: str, algorithm: str = 'sha256') -> bool:
        """
        Verify HMAC signature in constant time.

        Args:
            key: HMAC key
            data: Data that was signed
            signature: Expected hex signature (str, or ASCII bytes)
            algorithm: Hash algorithm used

        Returns:
            True if signature is valid; False for a mismatching or malformed
            signature (including non-str/bytes values such as None)

        Raises:
            ValueError: If the algorithm is not supported
        """
        expected = IntegrityChecker.compute_hmac(key, data, algorithm)
        if isinstance(signature, str):
            # compare_digest raises TypeError for str arguments containing
            # non-ASCII characters; compare UTF-8 bytes so malformed,
            # attacker-supplied signatures simply fail verification.
            provided = signature.encode('utf-8')
        elif isinstance(signature, (bytes, bytearray)):
            provided = bytes(signature)
        else:
            return False
        return hmac.compare_digest(expected.encode('ascii'), provided)
# Convenience functions: module-level shortcuts onto the PathValidator API.


def sanitize_filename(filename: str, replacement: str = '_') -> str:
    """Return ``filename`` cleaned by ``PathValidator.sanitize_filename``."""
    sanitize = PathValidator.sanitize_filename
    return sanitize(filename, replacement=replacement)


def validate_path(path: Union[str, Path], base: Union[str, Path]) -> Path:
    """Return the resolved ``path`` after checking it stays inside ``base``.

    Delegates to ``PathValidator.validate_path_within_base`` with its default
    symlink policy.
    """
    validate = PathValidator.validate_path_within_base
    return validate(path, base_path=base)


def safe_path_join(base: Union[str, Path], *paths: str) -> Path:
    """Join ``paths`` onto ``base`` via ``PathValidator.safe_join``."""
    join = PathValidator.safe_join
    return join(base, *paths)