"""
EU-Utility Premium - State Management
======================================

Redux-inspired state store with:
- Immutable state updates
- Time-travel debugging
- Selectors for derived state
- Middleware support
- State persistence

Example:
    from premium.core.state import StateStore, ActionBase, Action

    # Define actions
    class IncrementAction(ActionBase):
        type = "INCREMENT"

    # Define reducer
    def counter_reducer(state: int, action: Action) -> int:
        if action.type == "INCREMENT":
            return state + 1
        return state

    # Create store
    store = StateStore(counter_reducer, initial_state=0)
    store.dispatch(IncrementAction())
"""

from __future__ import annotations

import copy
import json
import logging
import threading
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import (
    Any, Callable, Dict, Generic, List, Optional, Set, TypeVar, Union,
    Protocol, runtime_checkable
)


# =============================================================================
# TYPE DEFINITIONS
# =============================================================================

T = TypeVar('T')
State = TypeVar('State')


@runtime_checkable
class Action(Protocol):
    """Structural protocol for actions.

    Any object exposing ``type``, ``payload`` and ``meta`` attributes passes
    ``isinstance(obj, Action)``. ``to_dict`` is deliberately NOT part of the
    protocol, so code that serialises actions must not assume it exists.
    """

    type: str
    payload: Optional[Any] = None
    meta: Optional[Dict[str, Any]] = None


class ActionBase:
    """Concrete base class for actions.

    Subclasses normally override the class attribute ``type``. Every instance
    gets its own ``meta`` dict and records its construction time (naive local
    ``datetime``) in ``timestamp``.
    """

    type: str = "UNKNOWN"
    payload: Any = None
    meta: Optional[Dict[str, Any]] = None
    # Assigned per instance in __init__. A class-level
    # ``field(default_factory=...)`` does nothing outside a dataclass and only
    # left a stray ``dataclasses.Field`` object on the class.
    timestamp: datetime

    def __init__(self, payload: Any = None, meta: Optional[Dict[str, Any]] = None):
        self.payload = payload
        # Fresh dict per instance so callers can mutate meta without sharing.
        self.meta = meta or {}
        self.timestamp = datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Convert the action to a plain dictionary (timestamp as ISO string)."""
        return {
            'type': self.type,
            'payload': self.payload,
            'meta': self.meta,
            'timestamp': self.timestamp.isoformat(),
        }


# Reducer type: (state, action) -> new_state
Reducer = Callable[[State, Action], State]

# Selector type: state -> derived_value
Selector = Callable[[State], T]

# Subscriber type: (new_state, old_state) -> None
Subscriber = Callable[[State, State], None]

# Middleware type: (store, next, action) -> result
Middleware = Callable[['StateStore', Callable[[Action], Any], Action], Any]


# =============================================================================
# BUILT-IN ACTIONS
# =============================================================================

class StateResetAction(ActionBase):
    """Reset state to its initial value."""
    type = "@@STATE/RESET"


class StateRestoreAction(ActionBase):
    """Restore state from a snapshot; the state travels in ``payload``."""
    type = "@@STATE/RESTORE"

    def __init__(self, state: State, meta: Optional[Dict[str, Any]] = None):
        super().__init__(payload=state, meta=meta)


class StateBatchAction(ActionBase):
    """Batch multiple actions; the list of actions travels in ``payload``."""
    type = "@@STATE/BATCH"

    def __init__(self, actions: List[Action], meta: Optional[Dict[str, Any]] = None):
        super().__init__(payload=actions, meta=meta)


class StateHydrateAction(ActionBase):
    """Hydrate state from persisted data; the state travels in ``payload``."""
    type = "@@STATE/HYDRATE"

    def __init__(self, state: State, meta: Optional[Dict[str, Any]] = None):
        super().__init__(payload=state, meta=meta)


# =============================================================================
# STATE SNAPSHOT
# =============================================================================

@dataclass
class StateSnapshot:
    """Snapshot of state at a point in time, with the action that produced it."""

    state: Any
    action: Optional[Action] = None
    timestamp: datetime = field(default_factory=datetime.now)
    # Short random identifier (first 8 chars of a UUID4).
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])

    def to_dict(self) -> Dict[str, Any]:
        """Convert the snapshot to a plain dictionary."""
        return {
            'id': self.id,
            'state': self.state,
            'action': self._action_to_dict(self.action),
            'timestamp': self.timestamp.isoformat(),
        }

    @staticmethod
    def _action_to_dict(action: Optional[Action]) -> Optional[Dict[str, Any]]:
        """Serialise ``action``, tolerating protocol-only actions without ``to_dict``."""
        if action is None:
            return None
        to_dict = getattr(action, 'to_dict', None)
        if callable(to_dict):
            return to_dict()
        # Plain objects that merely satisfy the Action protocol.
        return {
            'type': getattr(action, 'type', None),
            'payload': getattr(action, 'payload', None),
            'meta': getattr(action, 'meta', None),
        }
# =============================================================================
# STORE SLICE
# =============================================================================

@dataclass
class StoreSlice(Generic[T]):
    """A slice of the store with its own reducer, initial state and selectors."""

    name: str
    reducer: Reducer
    initial_state: T
    selectors: Dict[str, Selector] = field(default_factory=dict)


# =============================================================================
# COMBINED REDUCER
# =============================================================================

def combine_reducers(reducers: Dict[str, Reducer]) -> Reducer:
    """Combine multiple reducers into one.

    Each reducer manages the slice of the state stored under its key.

    Args:
        reducers: Dict mapping slice names to reducers

    Returns:
        Combined reducer function. It returns the incoming ``state`` object
        unchanged when no slice changed and the state is already a dict with
        exactly the reducer keys; otherwise it returns a new dict.

    Example:
        root_reducer = combine_reducers({
            'counter': counter_reducer,
            'todos': todos_reducer,
        })
    """
    def combined(state: Dict[str, Any], action: Action) -> Dict[str, Any]:
        previous = state if isinstance(state, dict) else {}
        new_state: Dict[str, Any] = {}
        has_changed = False

        for key, reducer in reducers.items():
            previous_slice = previous.get(key)
            new_slice = reducer(previous_slice, action)
            new_state[key] = new_slice
            if new_slice is not previous_slice:
                has_changed = True

        # A non-dict state, or one with missing/extra keys, must be replaced by
        # the freshly assembled dict even if every slice reducer was a no-op.
        if not isinstance(state, dict) or state.keys() != reducers.keys():
            has_changed = True

        return new_state if has_changed else state

    return combined


# =============================================================================
# MIDDLEWARE
# =============================================================================

def logging_middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
    """Middleware that prints every action and the resulting state to stdout."""
    print(f"[State] Action: {action.type}")
    result = next(action)
    print(f"[State] New state: {store.get_state()}")
    return result


def thunk_middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
    """Middleware that allows thunk actions (plain callables).

    A callable that is not an ``ActionBase`` is invoked as
    ``action(store.dispatch, store.get_state)`` instead of reaching the reducer.
    """
    if callable(action) and not isinstance(action, ActionBase):
        return action(store.dispatch, store.get_state)
    return next(action)


def persistence_middleware(
    storage_path: Path,
    debounce_ms: int = 1000
) -> Middleware:
    """Create middleware that persists state to disk.

    The first action after a quiet period is saved immediately. Actions that
    arrive within ``debounce_ms`` of the last save schedule a single trailing
    save, so the final state of a burst is still written. Files are written
    through ``StateStore.save_to_disk`` and can therefore be read back with
    ``StateStore.load_from_disk``.

    The trailing save runs on a daemon ``threading.Timer``; it may be lost if
    the interpreter exits before the timer fires.

    Args:
        storage_path: Path to store state (``str`` is accepted as well)
        debounce_ms: Debounce time in milliseconds

    Returns:
        Middleware function
    """
    target = Path(storage_path)
    debounce_s = max(debounce_ms, 0) / 1000.0
    guard = threading.Lock()
    # -inf guarantees the very first action is saved immediately.
    last_save = float('-inf')
    timer: Optional[threading.Timer] = None

    def save(store: StateStore) -> None:
        nonlocal last_save
        with guard:
            last_save = time.monotonic()
        # save_to_disk logs its own failures.
        store.save_to_disk(target)

    def flush(store: StateStore) -> None:
        nonlocal timer
        with guard:
            timer = None
        save(store)

    def middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
        nonlocal timer
        result = next(action)

        save_now = False
        with guard:
            # Monotonic clock: immune to wall-clock adjustments.
            elapsed = time.monotonic() - last_save
            if elapsed > debounce_s:
                save_now = True
                if timer is not None:
                    timer.cancel()
                    timer = None
            elif timer is None:
                timer = threading.Timer(debounce_s - elapsed, flush, args=(store,))
                timer.daemon = True
                timer.start()

        if save_now:
            save(store)
        return result

    return middleware


# =============================================================================
# STATE STORE
# =============================================================================

class StateStore(Generic[State]):
    """Redux-inspired state store with time-travel debugging.

    Features:
        - Immutable state updates (reducers receive a deep copy of the state)
        - Action history for debugging
        - State snapshots for time travel
        - Selectors for derived state
        - Middleware support
        - State persistence

    Built-in ``@@STATE/*`` actions are offered to the reducer first; if the
    reducer leaves the state unchanged, the store applies their default
    meaning (hydrate/restore -> replace with payload, reset -> initial state,
    batch -> apply each contained action in order).

    Example:
        store = StateStore(
            reducer=root_reducer,
            initial_state={'count': 0},
            middleware=[logging_middleware]
        )

        # Subscribe to changes
        unsubscribe = store.subscribe(lambda new, old: print(f"Changed: {old} -> {new}"))

        # Dispatch actions
        store.dispatch(IncrementAction())

        # Use selectors
        count = store.select(lambda state: state['count'])

        # Time travel
        store.undo()  # Undo last action
        store.jump_to_snapshot(0)  # Jump to initial state
    """

    def __init__(
        self,
        reducer: Reducer[State, Action],
        initial_state: Optional[State] = None,
        middleware: Optional[List[Middleware]] = None,
        max_history: int = 1000,
        enable_time_travel: bool = True
    ):
        """Initialize state store.

        Args:
            reducer: Root reducer function
            initial_state: Initial state value (deep-copied)
            middleware: List of middleware functions; ``middleware[0]`` runs first
            max_history: Maximum number of snapshots kept for time travel
            enable_time_travel: Enable time-travel debugging
        """
        self._reducer = reducer
        # Copy so later mutation of the caller's object cannot leak into the store.
        self._state: State = copy.deepcopy(initial_state)
        self._initial_state = copy.deepcopy(initial_state)
        self._middleware = middleware or []
        self._max_history = max_history
        self._enable_time_travel = enable_time_travel

        # Subscribers
        self._subscribers: Dict[str, Subscriber] = {}
        self._subscriber_counter = 0

        # History for time travel
        self._history: List[StateSnapshot] = []
        self._current_index = -1

        # Re-entrant so subscribers/middleware may call back into the store.
        self._lock = threading.RLock()

        self._logger = logging.getLogger("StateStore")

        # Create initial snapshot
        if enable_time_travel:
            self._add_snapshot(StateSnapshot(
                state=copy.deepcopy(self._state),
                action=None
            ))

    # ========== Core Methods ==========

    def get_state(self) -> State:
        """Return a deep copy of the current state."""
        with self._lock:
            return copy.deepcopy(self._state)

    def dispatch(self, action: Action) -> Action:
        """Dispatch an action through the middleware chain to the reducer.

        Args:
            action: Action to dispatch

        Returns:
            Whatever the outermost middleware returns; without middleware this
            is the dispatched action itself.
        """
        def dispatch_action(a: Action) -> Action:
            return self._apply_reducer(a)

        # Wrap from the innermost outwards so middleware[0] sees the action first.
        chain = dispatch_action
        for mw in reversed(self._middleware):
            chain = lambda a, mw=mw, downstream=chain: mw(self, downstream, a)

        return chain(action)

    @staticmethod
    def _has_changed(old_state: Any, new_state: Any) -> bool:
        """Return True when ``new_state`` differs from ``old_state``.

        Identity is checked first; otherwise ``!=`` decides, because reducers
        work on a deep copy and an untouched state comes back as a different
        but equal object. States whose comparison fails count as changed.
        """
        if new_state is old_state:
            return False
        try:
            return bool(new_state != old_state)
        except Exception:
            return True

    def _compute_next_state(self, state: Any, action: Action) -> Any:
        """Run the reducer, then apply default handling for built-in actions."""
        next_state = self._reducer(copy.deepcopy(state), action)
        if self._has_changed(state, next_state):
            # Normal action, or the reducer handled the built-in itself.
            return next_state

        action_type = getattr(action, 'type', None)
        if action_type in (StateHydrateAction.type, StateRestoreAction.type):
            return copy.deepcopy(getattr(action, 'payload', None))
        if action_type == StateResetAction.type:
            return copy.deepcopy(self._initial_state)
        if action_type == StateBatchAction.type:
            for sub_action in getattr(action, 'payload', None) or []:
                next_state = self._compute_next_state(next_state, sub_action)
        return next_state

    def _apply_reducer(self, action: Action) -> Action:
        """Compute the next state and, if it changed, record and publish it.

        Runs under the store lock; subscribers are notified while it is held.
        """
        with self._lock:
            old_state = self._state
            new_state = self._compute_next_state(old_state, action)

            if self._has_changed(old_state, new_state):
                self._state = new_state

                if self._enable_time_travel:
                    # Discard any "future" snapshots left behind by undo/jump.
                    if self._current_index < len(self._history) - 1:
                        del self._history[self._current_index + 1:]
                    self._add_snapshot(StateSnapshot(
                        state=copy.deepcopy(new_state),
                        action=action
                    ))

                self._notify_subscribers(new_state, old_state)

        return action

    def _add_snapshot(self, snapshot: StateSnapshot) -> None:
        """Append a snapshot, make it current, and trim history to max_history."""
        self._history.append(snapshot)
        self._current_index = len(self._history) - 1

        if len(self._history) > self._max_history:
            self._history = self._history[-self._max_history:]
            self._current_index = len(self._history) - 1

    # ========== Subscription ==========

    def subscribe(self, callback: Subscriber, selector: Optional[Selector] = None) -> Callable[[], None]:
        """Subscribe to state changes.

        Args:
            callback: Function called with ``(new_state, old_state)``
            selector: Optional selector; when given, ``callback`` only fires
                when the selected value changes (compared with ``!=``)

        Returns:
            Unsubscribe function
        """
        with self._lock:
            self._subscriber_counter += 1
            sub_id = f"sub_{self._subscriber_counter}"

            if selector:
                last_value = selector(self._state)

                def wrapped_callback(new_state: State, old_state: State) -> None:
                    nonlocal last_value
                    new_value = selector(new_state)
                    if new_value != last_value:
                        last_value = new_value
                        callback(new_state, old_state)

                self._subscribers[sub_id] = wrapped_callback
            else:
                self._subscribers[sub_id] = callback

        def unsubscribe() -> None:
            with self._lock:
                self._subscribers.pop(sub_id, None)

        return unsubscribe

    def _notify_subscribers(self, new_state: State, old_state: State) -> None:
        """Call every subscriber; one failing subscriber does not stop the rest."""
        # Iterate over a copy: callbacks may (un)subscribe while being notified.
        for callback in list(self._subscribers.values()):
            try:
                callback(new_state, old_state)
            except Exception as e:
                self._logger.exception("Error in subscriber: %s", e)

    # ========== Selectors ==========

    def select(self, selector: Selector[T]) -> T:
        """Apply ``selector`` to a deep copy of the current state and return the result."""
        with self._lock:
            return selector(copy.deepcopy(self._state))

    def create_selector(self, *input_selectors: Selector, combiner: Callable) -> Selector:
        """Create a memoized selector.

        Args:
            input_selectors: Selectors that provide input values
            combiner: Function that combines inputs into output

        Returns:
            Selector that recomputes ``combiner`` only when the list of input
            values differs (``!=``) from the previous call.
        """
        # Sentinel: the first call must always compute, even if every input is None.
        not_computed = object()
        last_inputs: Any = not_computed
        last_result: Any = None

        def memoized_selector(state: State) -> Any:
            nonlocal last_inputs, last_result
            inputs = [select_input(state) for select_input in input_selectors]
            if last_inputs is not_computed or inputs != last_inputs:
                last_inputs = inputs
                last_result = combiner(*inputs)
            return last_result

        return memoized_selector

    # ========== Time Travel ==========

    def get_history(self) -> List[StateSnapshot]:
        """Return a shallow copy of the snapshot history."""
        with self._lock:
            return self._history.copy()

    def get_current_index(self) -> int:
        """Get current position in history."""
        return self._current_index

    def can_undo(self) -> bool:
        """Check if undo is possible."""
        return self._current_index > 0

    def can_redo(self) -> bool:
        """Check if redo is possible."""
        return self._current_index < len(self._history) - 1

    def _move_to(self, index: int) -> None:
        """Make snapshot ``index`` current and notify subscribers (lock held by caller)."""
        old_state = self._state
        self._current_index = index
        self._state = copy.deepcopy(self._history[index].state)
        self._notify_subscribers(self._state, old_state)

    def undo(self) -> bool:
        """Undo last action.

        Returns:
            True if undo was performed
        """
        # Check and move under one lock acquisition to avoid a TOCTOU race.
        with self._lock:
            if not self.can_undo():
                return False
            self._move_to(self._current_index - 1)
        return True

    def redo(self) -> bool:
        """Redo last undone action.

        Returns:
            True if redo was performed
        """
        with self._lock:
            if not self.can_redo():
                return False
            self._move_to(self._current_index + 1)
        return True

    def jump_to_snapshot(self, index: int) -> bool:
        """Jump to a specific snapshot in history.

        Args:
            index: Snapshot index (0-based, negative indices rejected)

        Returns:
            True if jump was successful
        """
        with self._lock:
            if index < 0 or index >= len(self._history):
                return False
            self._move_to(index)
        return True

    def reset(self) -> None:
        """Reset to the initial state and restart history from it.

        A ``StateResetAction`` is dispatched first so middleware sees it;
        subscribers are then notified only if the state still differs from
        the initial state.
        """
        self.dispatch(StateResetAction())

        with self._lock:
            old_state = self._state
            self._state = copy.deepcopy(self._initial_state)

            if self._enable_time_travel:
                self._history.clear()
                self._current_index = -1
                self._add_snapshot(StateSnapshot(
                    state=copy.deepcopy(self._state),
                    action=None
                ))

            if self._has_changed(old_state, self._state):
                self._notify_subscribers(self._state, old_state)

    # ========== Persistence ==========

    def save_to_disk(self, path: Path) -> bool:
        """Save current state to disk as JSON.

        The file contains ``{'state', 'timestamp', 'history_count'}``; values
        JSON cannot encode are written via ``str()``. Parent directories are
        created as needed.

        Args:
            path: File path to save to (``str`` is accepted as well)

        Returns:
            True if saved successfully
        """
        try:
            target = Path(path)
            with self._lock:
                state_data = {
                    'state': self._state,
                    'timestamp': datetime.now().isoformat(),
                    'history_count': len(self._history),
                }

            target.parent.mkdir(parents=True, exist_ok=True)
            with open(target, 'w', encoding='utf-8') as f:
                json.dump(state_data, f, indent=2, default=str)

            return True
        except Exception as e:
            self._logger.error(f"Failed to save state: {e}")
            return False

    def load_from_disk(self, path: Path) -> bool:
        """Load state written by ``save_to_disk`` and dispatch it as a hydrate action.

        Args:
            path: File path to load from

        Returns:
            True if a ``state`` entry was found and dispatched; False if the
            file is missing/unreadable, lacks a ``state`` key, or dispatch
            raised (errors are logged, never propagated).
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if not isinstance(data, dict) or 'state' not in data:
                self._logger.warning(f"No state found in {path}")
                return False

            self.dispatch(StateHydrateAction(data['state']))
            return True
        except Exception as e:
            self._logger.error(f"Failed to load state: {e}")
            return False


# =============================================================================
# MODULE-LEVEL STORE (Application-wide state)
# =============================================================================

_module_stores: Dict[str, StateStore] = {}


def create_store(
    name: str,
    reducer: Reducer,
    initial_state: Optional[Any] = None,
    **kwargs
) -> StateStore:
    """Create or get a named store.

    If a store with ``name`` already exists it is returned as-is and the
    other arguments are ignored.

    Args:
        name: Unique store name
        reducer: Reducer function
        initial_state: Initial state
        **kwargs: Additional arguments for StateStore

    Returns:
        StateStore instance
    """
    if name not in _module_stores:
        _module_stores[name] = StateStore(
            reducer=reducer,
            initial_state=initial_state,
            **kwargs
        )
    return _module_stores[name]


def get_store(name: str) -> Optional[StateStore]:
    """Return the named store, or None if it does not exist."""
    return _module_stores.get(name)


def remove_store(name: str) -> bool:
    """Remove a named store; return True if it existed."""
    if name in _module_stores:
        del _module_stores[name]
        return True
    return False


# =============================================================================
# EXPORTS
# =============================================================================

__all__ = [
    # Types
    'Action',
    'ActionBase',
    'Reducer',
    'Selector',
    'Subscriber',
    'Middleware',
    # Actions
    'StateResetAction',
    'StateRestoreAction',
    'StateBatchAction',
    'StateHydrateAction',
    # Classes
    'StateSnapshot',
    'StoreSlice',
    'StateStore',
    # Functions
    'combine_reducers',
    'create_store',
    'get_store',
    'remove_store',
    # Middleware
    'logging_middleware',
    'thunk_middleware',
    'persistence_middleware',
]