229 lines
7.1 KiB
Python
229 lines
7.1 KiB
Python
# Description: Database initialization and connection management
# Implements SQLite setup with Data Principle support
# Standards: Python 3.11+, type hints, async where possible
|
|
|
|
import sqlite3
|
|
import os
|
|
from pathlib import Path
|
|
from decimal import Decimal
|
|
from datetime import datetime
|
|
from typing import Optional, Any
|
|
import json
|
|
import logging
|
|
|
|
# Module-wide logger shared by everything in this file.
logger = logging.getLogger(name=__name__)

# Root logging setup at DEBUG verbosity, performed once at import time.
# basicConfig() only takes effect if the root logger has no handlers yet.
# NOTE(review): configuring root logging from a library module overrides the
# host application's setup — consider moving this to the program entry point.
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
|
class DatabaseManager:
    """
    Manages a SQLite database connection and schema initialization.

    Implements the Data Principle by providing data persistence with
    foreign-key enforcement and explicit transaction control. A single
    connection is opened lazily on first use and cached on the instance
    (there is no pooling). Release it with ``close()``, or use the manager
    as a context manager, which commits or rolls back and then closes.
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize database manager.

        Only the database's parent directory is created here. The
        connection itself is opened lazily by ``get_connection()``.

        Args:
            db_path: Path to the SQLite database (a ``str`` or ``Path``).
                Defaults to ``<project_root>/data/lemontropia.db``. Here
                ``project_root`` is the parent of the directory containing
                this module, not the current working directory.
        """
        if db_path is None:
            # Get project root: the parent of the directory holding this
            # module (presumably core/).
            core_dir = Path(__file__).parent
            project_root = core_dir.parent
            db_path = project_root / "data" / "lemontropia.db"

        self.db_path = Path(db_path)
        self._connection: Optional[sqlite3.Connection] = None

        # Ensure data directory exists
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info("DatabaseManager initialized: %s", self.db_path)

    def initialize(self) -> bool:
        """
        Initialize database with schema.

        Reads ``schema.sql`` from the directory containing this module. It
        runs the file with ``executescript()`` and then commits.

        Returns:
            True if successful. False otherwise, in which case the error and
            its traceback are logged.
        """
        try:
            conn = self.get_connection()

            # Read the schema explicitly as UTF-8 so the result does not
            # depend on the platform's default locale encoding.
            schema_path = Path(__file__).parent / "schema.sql"
            schema = schema_path.read_text(encoding="utf-8")

            # executescript() runs several ';'-separated statements at once.
            conn.executescript(schema)
            conn.commit()

            logger.info("Database schema initialized successfully")
            return True

        except Exception as e:
            # Boundary method: report failure through the return value, but
            # keep the traceback in the log for diagnosis.
            logger.exception("Failed to initialize database: %s", e)
            return False

    def get_connection(self) -> sqlite3.Connection:
        """
        Get or create database connection.

        The first call opens the database with declared-type and column-name
        type detection, so registered converters apply. It also sets
        ``sqlite3.Row`` rows, enables foreign keys and applies the
        performance PRAGMAs. Later calls return the same cached connection.

        Returns:
            SQLite connection with row factory and type detection

        Raises:
            sqlite3.Error: if the database cannot be opened or configured. In
                that case no connection is cached.
        """
        if self._connection is None:
            conn = sqlite3.connect(
                self.db_path,
                detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES,
            )
            try:
                conn.row_factory = sqlite3.Row

                # Enable foreign keys (SQLite turns them on per connection).
                conn.execute("PRAGMA foreign_keys = ON")

                # Performance optimizations (Rule #3: 60+ FPS)
                conn.execute("PRAGMA journal_mode = WAL")
                conn.execute("PRAGMA synchronous = NORMAL")
                # A negative cache_size is in KiB: 64000 KiB (~64 MB).
                conn.execute("PRAGMA cache_size = -64000")
            except Exception:
                # Don't leak, or cache, a half-configured connection.
                conn.close()
                raise

            self._connection = conn
            logger.debug("Database connection established")

        return self._connection

    def close(self) -> None:
        """
        Close the cached database connection, if one is open.

        Pending changes are not committed implicitly. Call ``commit()``
        first if they should be kept.
        """
        if self._connection is not None:
            try:
                self._connection.close()
            finally:
                # Forget the handle even if close() raised, so the next
                # get_connection() opens a fresh connection.
                self._connection = None
            logger.debug("Database connection closed")

    def execute(self, query: str, parameters: tuple = ()) -> sqlite3.Cursor:
        """
        Execute a query with parameters.

        Args:
            query: SQL query string
            parameters: Query parameters, bound by SQLite rather than
                interpolated into the SQL text (this prevents SQL injection).

        Returns:
            Cursor object
        """
        conn = self.get_connection()
        return conn.execute(query, parameters)

    def executemany(self, query: str, parameters: list) -> sqlite3.Cursor:
        """
        Execute query with multiple parameter sets.

        Args:
            query: SQL query string
            parameters: List of parameter tuples

        Returns:
            Cursor object
        """
        conn = self.get_connection()
        return conn.executemany(query, parameters)

    def commit(self) -> None:
        """Commit the current transaction. Does nothing if no connection is open."""
        if self._connection is not None:
            self._connection.commit()

    def rollback(self) -> None:
        """Roll back the current transaction. Does nothing if no connection is open."""
        if self._connection is not None:
            self._connection.rollback()

    def get_schema_version(self) -> int:
        """
        Get current schema version.

        Returns:
            The highest ``version`` in the ``schema_version`` table. Returns 0
            if the table is missing or empty.
        """
        try:
            cursor = self.execute(
                "SELECT MAX(version) FROM schema_version"
            )
            result = cursor.fetchone()
            return result[0] if result and result[0] else 0
        except sqlite3.OperationalError:
            # Table doesn't exist
            return 0

    def backup(self, backup_path: Optional[str] = None) -> bool:
        """
        Create database backup using the SQLite online backup API.

        Args:
            backup_path: Path for the backup file. Defaults to
                ``<db dir>/backups/lemontropia_backup_<YYYYmmdd_HHMMSS>.db``.
                The timestamp has one-second resolution, so two default
                backups taken within the same second target the same file.

        Returns:
            True if successful. False otherwise, in which case the error is
            logged.
        """
        try:
            if backup_path is None:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                backup_dir = self.db_path.parent / "backups"
                backup_dir.mkdir(parents=True, exist_ok=True)
                backup_path = backup_dir / f"lemontropia_backup_{timestamp}.db"

            # Use SQLite backup API
            source = self.get_connection()
            dest = sqlite3.connect(backup_path)
            try:
                with dest:
                    source.backup(dest)
            finally:
                # Always release the destination handle, even when the backup
                # fails part-way.
                dest.close()

            logger.info("Database backed up to: %s", backup_path)
            return True

        except Exception as e:
            logger.exception("Backup failed: %s", e)
            return False

    def __enter__(self) -> "DatabaseManager":
        """Context manager entry; returns the manager itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Context manager exit with automatic commit/rollback.

        Commits if the ``with`` body finished normally and rolls back if it
        raised. The connection is always closed. Exceptions are never
        suppressed.
        """
        try:
            if exc_type is None:
                self.commit()
            else:
                self.rollback()
        finally:
            # Close even if commit()/rollback() raised, e.g. "database is
            # locked" or a deferred constraint violation at COMMIT time.
            self.close()
|
|
|
|
|
|
# ============================================================================
|
|
# DECIMAL HANDLING (Rule #4: Precision)
|
|
# ============================================================================
|
|
|
|
def adapt_decimal(d: Decimal) -> str:
    """
    Serialize a Decimal for binding as a SQLite parameter.

    The value is bound as its exact ``str()`` text, so no binary-float
    rounding happens on the Python side.

    NOTE(review): SQLite gives columns declared ``DECIMAL`` NUMERIC affinity.
    That affinity can convert well-formed numeric text to INTEGER or REAL on
    insert, which can lose digits beyond ~15 significant figures and trailing
    zeros such as ``1.10``. If exact precision matters, confirm that the
    schema stores these values in TEXT columns.
    """
    text = str(d)
    return text
|
|
|
|
def convert_decimal(s: bytes) -> Decimal:
    """
    Deserialize a SQLite column value back into a Decimal.

    sqlite3 passes converters the raw value as ``bytes``. The bytes are
    decoded as UTF-8 and parsed with the Decimal constructor.

    Raises:
        UnicodeDecodeError: if the bytes are not valid UTF-8.
        decimal.InvalidOperation: if the text is not a valid number.
    """
    text = s.decode('utf-8')
    return Decimal(text)
|
|
|
|
# Install the Decimal <-> text round-trip in the sqlite3 module. Both hooks
# are process-global. The converter fires for columns declared DECIMAL (via
# PARSE_DECLTYPES) or aliased "[DECIMAL]" (via PARSE_COLNAMES), but only on
# connections opened with detect_types.
sqlite3.register_converter("DECIMAL", convert_decimal)
sqlite3.register_adapter(Decimal, adapt_decimal)
|
|
|
|
|
|
# ============================================================================
|
|
# MODULE EXPORTS
|
|
# ============================================================================
|
|
|
|
# Public API of this module.
__all__ = [
    "DatabaseManager",
    "adapt_decimal",
    "convert_decimal",
]
|