470 lines
14 KiB
Python
470 lines
14 KiB
Python
"""
|
|
EU-Utility - Security Utilities
|
|
|
|
Common security functions for input validation, path sanitization,
|
|
and secure data handling across the application.
|
|
"""
|
|
|
|
import re
|
|
import hashlib
|
|
import hmac
|
|
from pathlib import Path
|
|
from typing import Optional, Union, Any
|
|
|
|
|
|
class SecurityError(Exception):
    """Signal that a security check failed.

    The validators in this module raise it for path traversal, disallowed
    symlinks, malformed input, and oversized or unsupported data.
    """
|
|
|
|
|
|
class PathValidator:
    """Validates and sanitizes file paths to prevent traversal attacks."""

    # Maximum allowed filename length
    MAX_FILENAME_LENGTH = 255

    # Maximum number of path components allowed below the base directory
    MAX_PATH_DEPTH = 10

    # Characters replaced by sanitize_filename (besides path separators):
    # every ASCII control character (NUL, TAB, CR, LF, ...) plus the
    # characters Windows does not allow in file names.
    _DANGEROUS_CHARS = frozenset([chr(code) for code in range(32)] + list(':*?"<>|'))

    @classmethod
    def sanitize_filename(cls, filename: str, replacement: str = '_') -> str:
        """
        Sanitize a filename to remove dangerous characters.

        NUL bytes, path separators, ``..`` sequences, ASCII control
        characters and the characters ``: * ? " < > |`` are replaced with
        *replacement*. The result is limited to ``MAX_FILENAME_LENGTH``
        characters, keeping the last extension when it fits. Leading and
        trailing dots and spaces are then stripped.

        Args:
            filename: The filename to sanitize
            replacement: String substituted for each dangerous character

        Returns:
            Sanitized filename. If nothing usable remains, a placeholder of
            the form ``file_<8 hex chars>`` is returned instead.
        """
        if not filename:
            return cls._fallback_name(filename)

        # Remove null bytes
        filename = filename.replace('\x00', replacement)

        # Replace path separators so the name cannot address another directory
        filename = filename.replace('/', replacement).replace('\\', replacement)

        # Remove parent directory references
        filename = filename.replace('..', replacement)

        # Replace control characters and characters reserved on Windows
        dangerous = cls._DANGEROUS_CHARS
        filename = ''.join(
            replacement if char in dangerous else char for char in filename
        )

        # Limit length
        limit = cls.MAX_FILENAME_LENGTH
        if len(filename) > limit:
            name, ext = cls._split_extension(filename)
            if len(ext) > limit:
                # The "extension" alone is over the limit (e.g. "x.<300 chars>").
                # Slicing the name with a negative bound would not shorten the
                # result, so hard-truncate the whole name instead.
                filename = filename[:limit]
            else:
                filename = name[:limit - len(ext)] + ext

        # Ensure filename is not empty or just dots
        filename = filename.strip('. ')
        if not filename:
            filename = cls._fallback_name(filename)

        return filename

    @staticmethod
    def _fallback_name(seed: object) -> str:
        """Return the placeholder name used when sanitizing leaves nothing.

        The digest is derived from ``id(seed)``. For the empty input this is
        usually the same object on every call, so the placeholder is NOT a
        unique name.
        """
        return f"file_{hashlib.md5(str(id(seed)).encode()).hexdigest()[:8]}"

    @classmethod
    def _split_extension(cls, filename: str) -> tuple:
        """Split *filename* at its last dot into ``(name, ext)``.

        Only the final suffix counts as the extension (``a.tar.gz`` gives
        ``('a.tar', '.gz')``). *ext* includes the leading dot and is ``''``
        when the name contains no dot.
        """
        name, dot, ext = filename.rpartition('.')
        if dot:
            return name, dot + ext
        return filename, ''

    @classmethod
    def validate_path_within_base(
        cls,
        path: Union[str, Path],
        base_path: Union[str, Path],
        allow_symlinks: bool = False
    ) -> Path:
        """
        Validate that a path is within a base directory.

        A relative *path* is resolved against the current working directory
        (as ``Path.resolve`` does), not against *base_path*.

        Args:
            path: The path to validate
            base_path: The allowed base directory
            allow_symlinks: Whether to allow symlinks below the base
                (default: False). The base directory itself may be a symlink.

        Returns:
            Resolved Path object

        Raises:
            SecurityError: If the resolved path is outside the base, a
                symlink is found below the base while *allow_symlinks* is
                False, the path is more than ``MAX_PATH_DEPTH`` components
                below the base, or the path cannot be processed
                (OSError/ValueError).
        """
        try:
            path = Path(path)
            base_raw = Path(base_path)
            base_path = base_raw.resolve()

            if not allow_symlinks:
                # Must run on the unresolved path: resolve() follows links.
                cls._reject_symlinks_below_base(path, base_raw, base_path)

            resolved = path.resolve()

            try:
                relative = resolved.relative_to(base_path)
            except ValueError:
                raise SecurityError(
                    f"Path traversal detected: {resolved} is outside {base_path}"
                ) from None

            # Check path depth
            if len(relative.parts) > cls.MAX_PATH_DEPTH:
                raise SecurityError(f"Path depth exceeds maximum: {len(relative.parts)}")

            return resolved

        except (OSError, ValueError) as e:
            raise SecurityError(f"Invalid path: {e}") from e

    @staticmethod
    def _reject_symlinks_below_base(path: Path, base_raw: Path, base_resolved: Path) -> None:
        """Raise SecurityError if a component of *path* below the base is a symlink.

        The absolute path is walked one component at a time, as written.
        ``..`` is applied lexically. That is sound here because the walk
        raises at the first symlink, so ``..`` only ever follows real
        directories. Components at or above the base directory are not
        checked, which lets a base that lives behind a symlink be used.
        """
        import os

        # The caller may spell the base either as given or in resolved form.
        anchors = (Path(os.path.abspath(base_raw)), base_resolved)

        absolute = path.absolute()
        current = Path(absolute.anchor)
        for part in absolute.parts[1:]:
            if part == '..':
                current = current.parent
                continue
            current = current / part
            below_base = any(anchor in current.parents for anchor in anchors)
            if below_base and current.is_symlink():
                raise SecurityError(f"Symlink detected: {current}")

    @classmethod
    def safe_join(cls, base: Union[str, Path], *paths: str) -> Path:
        """
        Safely join paths and validate result is within base.

        Each component is sanitized as a single filename, so separators and
        ``..`` inside a component cannot add directory levels.

        Args:
            base: Base directory
            *paths: Path components to join

        Returns:
            Validated Path object

        Raises:
            SecurityError: If the joined path fails validate_path_within_base.
        """
        base = Path(base)
        sanitized = [cls.sanitize_filename(component, '_') for component in paths]
        return cls.validate_path_within_base(base.joinpath(*sanitized), base)
|
|
|
|
|
|
class InputValidator:
    """Validates user inputs to prevent injection attacks."""

    # Regexes searched case-insensitively by validate_json_key; a key that
    # matches any of them is rejected.
    DANGEROUS_PATTERNS = [
        r'__\w+__',  # Dunder methods
        r'eval\s*\(',
        r'exec\s*\(',
        r'compile\s*\(',
        r'__import__\s*\(',
        r'os\.system\s*\(',
        r'subprocess\.',
        r'\.popen\s*\(',
    ]

    # Largest absolute value accepted for a region's x / y coordinate.
    COORDINATE_LIMIT = 10000

    @classmethod
    def is_safe_identifier(cls, name: str) -> bool:
        """
        Check if a string is a safe Python identifier.

        A safe identifier is a non-empty ASCII name of letters, digits and
        underscores that does not start with a digit and is not a keyword.

        Args:
            name: String to check

        Returns:
            True if safe identifier
        """
        if not name:
            return False

        # fullmatch anchors at the true end of the string. `^...$` with
        # re.match would also accept a trailing newline ("abc\n").
        if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', name):
            return False

        # Must not be a keyword
        import keyword
        return not keyword.iskeyword(name)

    @classmethod
    def sanitize_string(cls, text: str, max_length: int = 1000) -> str:
        """
        Sanitize a string for safe use.

        Non-string input is converted with ``str()``. Characters below
        code point 32 are removed, including NUL, except newline and tab.
        The result is then truncated to *max_length*.

        Args:
            text: String to sanitize
            max_length: Maximum allowed length

        Returns:
            Sanitized string
        """
        if not isinstance(text, str):
            text = str(text)

        # Remove null bytes
        text = text.replace('\x00', '')

        # Remove control characters except newlines and tabs
        text = ''.join(
            char for char in text
            if char == '\n' or char == '\t' or ord(char) >= 32
        )

        # Limit length
        if len(text) > max_length:
            text = text[:max_length]

        return text

    @classmethod
    def validate_json_key(cls, key: str) -> bool:
        """
        Validate that a JSON key is safe.

        A key is rejected if it is not a string, is empty, starts with
        ``__``, or matches any of ``DANGEROUS_PATTERNS`` (case-insensitive).

        Args:
            key: Key to validate

        Returns:
            True if safe
        """
        if not isinstance(key, str):
            return False

        # Reject keys starting with __
        if key.startswith('__'):
            return False

        # Reject empty keys
        if not key:
            return False

        # Check for dangerous patterns
        for pattern in cls.DANGEROUS_PATTERNS:
            if re.search(pattern, key, re.IGNORECASE):
                return False

        return True

    @classmethod
    def validate_region_coordinates(
        cls,
        x: int,
        y: int,
        width: int,
        height: int,
        max_width: int = 7680,
        max_height: int = 4320
    ) -> None:
        """
        Validate screen region coordinates.

        Args:
            x: X coordinate
            y: Y coordinate
            width: Region width
            height: Region height
            max_width: Maximum allowed width (default 7680, i.e. 8K UHD)
            max_height: Maximum allowed height (default 4320, i.e. 8K UHD)

        Raises:
            SecurityError: If coordinates are invalid
        """
        # Check types. bool is a subclass of int but is never a meaningful
        # coordinate, so reject it explicitly.
        if not all(
            isinstance(v, int) and not isinstance(v, bool)
            for v in (x, y, width, height)
        ):
            raise SecurityError("Region coordinates must be integers")

        # Check positive dimensions
        if width <= 0 or height <= 0:
            raise SecurityError("Region width and height must be positive")

        # Check maximum dimensions
        if width > max_width or height > max_height:
            raise SecurityError(
                f"Region dimensions exceed maximum ({max_width}x{max_height})"
            )

        # Check reasonable bounds
        limit = cls.COORDINATE_LIMIT
        if not (-limit <= x <= limit and -limit <= y <= limit):
            raise SecurityError("Region coordinates out of reasonable bounds")

    @classmethod
    def validate_url_endpoint(cls, endpoint: str) -> str:
        """
        Validate a URL endpoint path.

        The endpoint must be non-empty, must not start with ``/``, must not
        contain ``..``, and must consist only of ASCII letters, digits,
        ``_``, ``-`` and ``/``.

        Args:
            endpoint: Endpoint path to validate

        Returns:
            Validated endpoint

        Raises:
            SecurityError: If endpoint is invalid
        """
        if not endpoint:
            raise SecurityError("Endpoint cannot be empty")

        # Must not start with /
        if endpoint.startswith('/'):
            raise SecurityError("Endpoint must not start with /")

        # Must not contain path traversal
        if '..' in endpoint:
            raise SecurityError("Path traversal detected in endpoint")

        # Must match allowed characters. fullmatch also rejects a trailing
        # newline, which `^...$` with re.match would let through.
        if not re.fullmatch(r'[a-zA-Z0-9_/-]+', endpoint):
            raise SecurityError(f"Invalid characters in endpoint: {endpoint}")

        return endpoint
|
|
|
|
|
|
class DataValidator:
    """Structural safety checks for nested, JSON-like data."""

    # Deepest nesting level accepted; the top-level value sits at depth 0.
    MAX_NESTING_DEPTH = 10

    # Largest number of entries allowed in any dict, list, tuple or set.
    MAX_COLLECTION_SIZE = 10000

    # Longest string allowed anywhere in the structure.
    MAX_STRING_LENGTH = 100000

    @classmethod
    def validate_data_structure(cls, data: Any, depth: int = 0) -> bool:
        """
        Walk *data* recursively and reject anything unsafe.

        Accepted values are dicts (with keys that pass
        ``InputValidator.validate_json_key`` after ``str()``), lists, tuples,
        sets, strings, ints, floats, bools and None. The walk enforces the
        size and depth limits defined on the class.

        Args:
            data: Value to inspect
            depth: Nesting level of *data* (callers normally leave it at 0)

        Returns:
            True when the whole structure is acceptable

        Raises:
            SecurityError: On excessive depth/size/length, a rejected dict
                key, or an unsupported value type
        """
        if depth > cls.MAX_NESTING_DEPTH:
            raise SecurityError(f"Data nesting exceeds maximum depth: {depth}")

        child_depth = depth + 1

        if isinstance(data, dict):
            entry_count = len(data)
            if entry_count > cls.MAX_COLLECTION_SIZE:
                raise SecurityError(f"Dictionary size exceeds maximum: {entry_count}")
            for k, v in data.items():
                if not InputValidator.validate_json_key(str(k)):
                    raise SecurityError(f"Invalid dictionary key: {k}")
                cls.validate_data_structure(v, child_depth)
            return True

        if isinstance(data, (list, tuple, set)):
            element_count = len(data)
            if element_count > cls.MAX_COLLECTION_SIZE:
                raise SecurityError(f"Collection size exceeds maximum: {element_count}")
            for element in data:
                cls.validate_data_structure(element, child_depth)
            return True

        if isinstance(data, str):
            char_count = len(data)
            if char_count > cls.MAX_STRING_LENGTH:
                raise SecurityError(f"String length exceeds maximum: {char_count}")
            return True

        # Scalars and None need no further inspection.
        if data is None or isinstance(data, (int, float, bool)):
            return True

        raise SecurityError(f"Unsupported data type: {type(data)}")
|
|
|
|
|
|
class IntegrityChecker:
    """Provides data integrity verification."""

    @staticmethod
    def compute_hash(data: bytes, algorithm: str = 'sha256') -> str:
        """
        Compute hash of data.

        Args:
            data: Data to hash
            algorithm: Hash algorithm to use: 'sha256', 'sha512' or 'md5'

        Returns:
            Hex digest of hash

        Raises:
            ValueError: If *algorithm* is not supported
        """
        if algorithm == 'sha256':
            return hashlib.sha256(data).hexdigest()
        elif algorithm == 'sha512':
            return hashlib.sha512(data).hexdigest()
        elif algorithm == 'md5':
            # MD5 is not collision resistant; suitable only for
            # non-adversarial checksums.
            return hashlib.md5(data).hexdigest()
        else:
            raise ValueError(f"Unsupported hash algorithm: {algorithm}")

    @staticmethod
    def compute_hmac(key: bytes, data: bytes, algorithm: str = 'sha256') -> str:
        """
        Compute HMAC of data.

        Args:
            key: HMAC key
            data: Data to sign
            algorithm: Hash algorithm to use: 'sha256' or 'sha512'

        Returns:
            Hex digest of HMAC

        Raises:
            ValueError: If *algorithm* is not supported
        """
        if algorithm == 'sha256':
            return hmac.new(key, data, hashlib.sha256).hexdigest()
        elif algorithm == 'sha512':
            return hmac.new(key, data, hashlib.sha512).hexdigest()
        else:
            raise ValueError(f"Unsupported HMAC algorithm: {algorithm}")

    @staticmethod
    def verify_hmac(key: bytes, data: bytes, signature: str, algorithm: str = 'sha256') -> bool:
        """
        Verify HMAC signature using a constant-time comparison.

        Args:
            key: HMAC key
            data: Data that was signed
            signature: Expected signature (hex string, typically untrusted)
            algorithm: Hash algorithm used

        Returns:
            True if signature is valid. False for a mismatching signature,
            including one that is not a comparable ASCII string.

        Raises:
            ValueError: If *algorithm* is not supported
        """
        # Computed first so an unsupported algorithm still raises ValueError.
        expected = IntegrityChecker.compute_hmac(key, data, algorithm)
        try:
            return hmac.compare_digest(expected, signature)
        except TypeError:
            # compare_digest raises TypeError for str containing non-ASCII
            # characters or for a str/bytes mix. A malformed untrusted
            # signature is simply invalid, not a crash.
            return False
|
|
|
|
|
|
# Convenience functions
|
|
def sanitize_filename(filename: str, replacement: str = '_') -> str:
    """Return a filesystem-safe version of *filename*.

    Module-level shortcut that delegates to ``PathValidator.sanitize_filename``.
    """
    validator = PathValidator
    return validator.sanitize_filename(filename, replacement=replacement)
|
|
|
|
|
|
def validate_path(path: Union[str, Path], base: Union[str, Path]) -> Path:
    """Return *path* resolved, raising SecurityError if it escapes *base*.

    Delegates to ``PathValidator.validate_path_within_base`` with its
    default ``allow_symlinks=False``.
    """
    checked = PathValidator.validate_path_within_base(path, base_path=base)
    return checked
|
|
|
|
|
|
def safe_path_join(base: Union[str, Path], *paths: str) -> Path:
    """Join *paths* onto *base*, sanitizing each piece, and validate the result.

    Shortcut for ``PathValidator.safe_join``.
    """
    joined = PathValidator.safe_join(base, *paths)
    return joined
|