EU-Utility/premium/core/state/store.py

712 lines
21 KiB
Python

"""
EU-Utility Premium - State Management
======================================

Redux-inspired state store with:
- Immutable state updates
- Time-travel debugging
- Selectors for derived state
- Middleware support
- State persistence

Example:
    from premium.core.state import StateStore, Action, Reducer

    # Define actions
    class IncrementAction(Action):
        type = "INCREMENT"

    # Define reducer
    def counter_reducer(state: int, action: Action) -> int:
        if action.type == "INCREMENT":
            return state + 1
        return state

    # Create store
    store = StateStore(counter_reducer, initial_state=0)
    store.dispatch(IncrementAction())
"""
from __future__ import annotations
import copy
import json
import logging
import threading
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import (
Any, Callable, Dict, Generic, List, Optional, Set, TypeVar, Union,
Protocol, runtime_checkable
)
# =============================================================================
# TYPE DEFINITIONS
# =============================================================================
# General-purpose type variable (selector results, slice state).
T = TypeVar("T")
# Type variable standing for the state held by a store.
State = TypeVar("State")


@runtime_checkable
class Action(Protocol):
    """Structural type describing a dispatchable action.

    Any object exposing the attributes below satisfies the protocol; because
    the class is ``runtime_checkable`` this can also be tested with
    ``isinstance``.
    """

    # Identifier that reducers switch on.
    type: str
    # Optional data carried by the action.
    payload: Optional[Any] = None
    # Optional side-channel information about the action.
    meta: Optional[Dict[str, Any]] = None
class ActionBase:
    """Concrete base class for actions.

    Subclasses normally only override the class-level ``type``. Every
    instance records its creation time in ``timestamp``.
    """

    type: str = "UNKNOWN"
    payload: Any = None
    meta: Optional[Dict[str, Any]] = None
    # Instance attribute assigned in __init__. Declared without a value:
    # this class is not a dataclass, so a dataclasses.field(...) default
    # would just leave a stray Field object on the class.
    timestamp: datetime

    def __init__(self, payload: Any = None, meta: Optional[Dict[str, Any]] = None):
        """Create an action.

        Args:
            payload: Data carried by the action
            meta: Optional metadata; an empty dict is used when omitted
        """
        self.payload = payload
        self.meta = meta or {}
        self.timestamp = datetime.now()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(type={self.type!r}, payload={self.payload!r})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert action to dictionary.

        Returns:
            Dict with ``type``, ``payload``, ``meta`` and the ISO-formatted
            ``timestamp``
        """
        return {
            'type': self.type,
            'payload': self.payload,
            'meta': self.meta,
            'timestamp': self.timestamp.isoformat(),
        }
# Reducer: called as reducer(state, action) and returns the next state.
Reducer = Callable[[State, Action], State]

# Selector: called as selector(state) and returns a derived value.
Selector = Callable[[State], T]

# Subscriber: called as subscriber(new_state, old_state); returns nothing.
Subscriber = Callable[[State, State], None]

# Middleware: called as middleware(store, next, action) and returns the
# result; ``next`` passes the action on to the rest of the chain.
Middleware = Callable[["StateStore", Callable[[Action], Any], Action], Any]
# =============================================================================
# BUILT-IN ACTIONS
# =============================================================================
class StateResetAction(ActionBase):
    """Built-in action announcing a reset to the initial state.

    ``StateStore.reset`` dispatches an instance of this action before it
    restores the initial state.
    """

    type = "@@STATE/RESET"
class StateRestoreAction(ActionBase):
    """Built-in action carrying a snapshot state to restore."""

    type = "@@STATE/RESTORE"

    def __init__(self, state: State, meta: Optional[Dict[str, Any]] = None):
        """Create the action.

        Args:
            state: State to restore; stored as the action payload
            meta: Optional metadata
        """
        super().__init__(state, meta)
class StateBatchAction(ActionBase):
    """Built-in action that bundles several actions into one."""

    type = "@@STATE/BATCH"

    def __init__(self, actions: List[Action], meta: Optional[Dict[str, Any]] = None):
        """Create the action.

        Args:
            actions: Actions to bundle; stored as the action payload
            meta: Optional metadata
        """
        super().__init__(actions, meta)
class StateHydrateAction(ActionBase):
    """Built-in action carrying persisted state to load into the store.

    ``StateStore.load_from_disk`` dispatches it with the loaded state.
    """

    type = "@@STATE/HYDRATE"

    def __init__(self, state: State, meta: Optional[Dict[str, Any]] = None):
        """Create the action.

        Args:
            state: Persisted state; stored as the action payload
            meta: Optional metadata
        """
        super().__init__(state, meta)
# =============================================================================
# STATE SNAPSHOT
# =============================================================================
@dataclass
class StateSnapshot:
    """Record of the state at a point in time.

    Holds the state, the action that produced it (``None`` for snapshots
    not caused by an action), a creation timestamp and a short random id.
    The dataclass is not frozen, so immutability is by convention only.
    """

    state: Any
    action: Optional[Action] = None
    timestamp: datetime = field(default_factory=datetime.now)
    # First 8 characters of a random UUID4.
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])

    def to_dict(self) -> Dict[str, Any]:
        """Convert snapshot to dictionary.

        Returns:
            Dict with ``id``, ``state``, ``action`` (a dict, or ``None``
            when the snapshot has no action) and ISO-formatted ``timestamp``
        """
        return {
            'id': self.id,
            'state': self.state,
            'action': self._action_to_dict(),
            'timestamp': self.timestamp.isoformat(),
        }

    def _action_to_dict(self) -> Optional[Dict[str, Any]]:
        """Serialize the action, tolerating actions without ``to_dict``."""
        action = self.action
        if action is None:
            return None
        converter = getattr(action, 'to_dict', None)
        if callable(converter):
            return converter()
        # The Action protocol only promises type/payload/meta, so build the
        # dict from those attributes instead of raising AttributeError.
        return {
            'type': getattr(action, 'type', None),
            'payload': getattr(action, 'payload', None),
            'meta': getattr(action, 'meta', None),
        }
# =============================================================================
# STORE SLICE
# =============================================================================
@dataclass
class StoreSlice(Generic[T]):
    """Description of one named slice of the store.

    Bundles the slice's reducer and initial state with any selectors
    registered for it.
    """

    # Name of the slice.
    name: str
    # Reducer for this slice.
    reducer: Reducer
    # Initial state of the slice.
    initial_state: T
    # Selectors keyed by name; each instance starts with its own empty dict.
    selectors: Dict[str, Selector] = field(default_factory=dict)
# =============================================================================
# COMBINED REDUCER
# =============================================================================
def combine_reducers(reducers: Dict[str, Reducer]) -> Reducer:
    """Build a single reducer out of per-slice reducers.

    Every reducer is responsible for the state stored under its own key.

    Args:
        reducers: Dict mapping slice names to reducers

    Returns:
        Reducer that runs each slice reducer and returns a new dict when at
        least one slice changed identity, otherwise the incoming state

    Example:
        root_reducer = combine_reducers({
            'counter': counter_reducer,
            'todos': todos_reducer,
        })
    """
    def combined(state: Dict[str, Any], action: Action) -> Dict[str, Any]:
        # A non-dict state (for example None) gives every slice None.
        readable = isinstance(state, dict)
        next_state = {}
        changed = False
        for name, slice_reducer in reducers.items():
            before = state.get(name) if readable else None
            after = slice_reducer(before, action)
            next_state[name] = after
            # Identity comparison: a slice counts as changed only when its
            # reducer returned a different object.
            changed = changed or (after is not before)
        if changed:
            return next_state
        return state

    return combined
# =============================================================================
# MIDDLEWARE
# =============================================================================
def logging_middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
    """Print the action type, pass the action on, then print the new state.

    Args:
        store: Store the action is dispatched on
        next: Next handler in the middleware chain
        action: Action being dispatched

    Returns:
        Whatever the next handler returned
    """
    print(f"[State] Action: {action.type}")
    outcome = next(action)
    # Read the state only after the rest of the chain has run.
    print(f"[State] New state: {store.get_state()}")
    return outcome
def thunk_middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
    """Run callable actions (thunks) instead of passing them to the reducer.

    Args:
        store: Store the action is dispatched on
        next: Next handler in the middleware chain
        action: Action or thunk being dispatched

    Returns:
        The thunk's return value for thunks, otherwise the result of the
        next handler
    """
    # Callable ActionBase instances are still ordinary actions.
    is_thunk = callable(action) and not isinstance(action, ActionBase)
    if not is_thunk:
        return next(action)
    # Thunks are called with (dispatch, get_state).
    return action(store.dispatch, store.get_state)
def persistence_middleware(
    storage_path: Path,
    debounce_ms: int = 1000
) -> Middleware:
    """Create middleware that persists state to disk.

    After each action has passed through the rest of the chain, the store's
    state is written to ``storage_path`` as JSON, at most once per
    ``debounce_ms`` milliseconds. The first action always triggers a write.

    Note:
        There is no deferred flush. A change that arrives inside the
        debounce window is only written once a later action arrives after
        the window has elapsed.

    Args:
        storage_path: Path to store state (``str`` is accepted as well);
            missing parent directories are created
        debounce_ms: Debounce time in milliseconds

    Returns:
        Middleware function
    """
    target = Path(storage_path)
    logger = logging.getLogger("StateStore")
    # Monotonic time of the last successful write; None until one happened.
    last_save_ms: Optional[float] = None

    def middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
        nonlocal last_save_ms
        result = next(action)

        # Use the monotonic clock: wall-clock adjustments must not stretch
        # or shrink the debounce window.
        now_ms = time.monotonic() * 1000
        if last_save_ms is not None and now_ms - last_save_ms <= debounce_ms:
            return result

        try:
            state = store.get_state()
            target.parent.mkdir(parents=True, exist_ok=True)
            with open(target, 'w', encoding='utf-8') as f:
                # default=str stringifies values json cannot encode natively
                json.dump(state, f, indent=2, default=str)
            last_save_ms = now_ms
        except Exception as e:
            # Best effort: a failed write must not break the dispatch.
            logger.error(f"Failed to persist state: {e}")
        return result

    return middleware
# =============================================================================
# STATE STORE
# =============================================================================
class StateStore(Generic[State]):
    """Redux-inspired state store with time-travel debugging.

    Features:
    - Immutable state updates
    - Action history for debugging
    - State snapshots for time travel
    - Selectors for derived state
    - Middleware support
    - State persistence

    State reads and writes are guarded by a re-entrant lock. Subscribers are
    called while that lock is held, so a subscriber may dispatch from its
    callback on the same thread.

    Example:
        store = StateStore(
            reducer=root_reducer,
            initial_state={'count': 0},
            middleware=[logging_middleware]
        )

        # Subscribe to changes
        unsubscribe = store.subscribe(lambda new, old: print(f"Changed: {old} -> {new}"))

        # Dispatch actions
        store.dispatch(IncrementAction())

        # Use selectors
        count = store.select(lambda state: state['count'])

        # Time travel
        store.undo()  # Undo last action
        store.jump_to_snapshot(0)  # Jump to initial state
    """

    def __init__(
        self,
        reducer: Reducer,
        initial_state: Optional[State] = None,
        middleware: Optional[List[Middleware]] = None,
        max_history: int = 1000,
        enable_time_travel: bool = True
    ):
        """Initialize state store.

        Args:
            reducer: Root reducer function
            initial_state: Initial state value
            middleware: List of middleware functions
            max_history: Maximum action history size
            enable_time_travel: Enable time-travel debugging
        """
        self._reducer = reducer
        # Keep private copies so later mutation of the caller's object can
        # neither change the live state nor the state reset() restores.
        self._state: State = copy.deepcopy(initial_state)
        self._initial_state = copy.deepcopy(initial_state)
        self._middleware = middleware or []
        self._max_history = max_history
        self._enable_time_travel = enable_time_travel

        # Subscribers
        self._subscribers: Dict[str, Subscriber] = {}
        self._subscriber_counter = 0

        # History for time travel
        self._history: List[StateSnapshot] = []
        self._current_index = -1

        # Lock for thread safety
        self._lock = threading.RLock()
        self._logger = logging.getLogger("StateStore")

        # Create initial snapshot
        if enable_time_travel:
            self._add_snapshot(StateSnapshot(
                state=copy.deepcopy(self._state),
                action=None
            ))

    # ========== Core Methods ==========

    def get_state(self) -> State:
        """Get a deep copy of the current state."""
        with self._lock:
            return copy.deepcopy(self._state)

    def dispatch(self, action: Action) -> Action:
        """Dispatch an action to update state.

        The action travels through the middleware in list order; the
        innermost handler applies the reducer.

        Args:
            action: Action to dispatch

        Returns:
            Whatever the outermost handler returns; without middleware this
            is the dispatched action
        """
        chain: Callable[[Action], Any] = self._apply_reducer
        # Wrap from the inside out so the first middleware runs first.
        for mw in reversed(self._middleware):
            chain = self._wrap_middleware(mw, chain)
        return chain(action)

    def _wrap_middleware(
        self,
        mw: Middleware,
        next_handler: Callable[[Action], Any]
    ) -> Callable[[Action], Any]:
        """Bind one middleware to the handler that follows it."""
        def handler(action: Action) -> Any:
            return mw(self, next_handler, action)
        return handler

    @staticmethod
    def _states_equal(left: Any, right: Any) -> bool:
        """Compare two states; states that cannot be compared count as different."""
        if left is right:
            return True
        try:
            return bool(left == right)
        except Exception:
            return False

    def _apply_reducer(self, action: Action) -> Action:
        """Apply reducer and update state."""
        with self._lock:
            old_state = self._state
            # The reducer works on a copy so it cannot mutate the live state.
            new_state = self._reducer(copy.deepcopy(old_state), action)
            # Compare by value: the reducer was handed a copy, so an identity
            # check would report a change for every action on mutable state.
            if not self._states_equal(new_state, old_state):
                self._commit(new_state, old_state, action)
            return action

    def _commit(self, new_state: State, old_state: State, action: Optional[Action]) -> None:
        """Install a new state, record it in history and notify subscribers.

        Must be called with the lock held.
        """
        self._state = new_state
        if self._enable_time_travel:
            # Drop any "future" snapshots left over from undo/jump.
            if self._current_index < len(self._history) - 1:
                del self._history[self._current_index + 1:]
            self._add_snapshot(StateSnapshot(
                state=copy.deepcopy(new_state),
                action=action
            ))
        self._notify_subscribers(new_state, old_state)

    def _add_snapshot(self, snapshot: StateSnapshot) -> None:
        """Add snapshot to history and make it the current one."""
        self._history.append(snapshot)
        # Trim the oldest entries. Deleting by count also works when
        # max_history is 0, where a negative-index slice would keep everything.
        overflow = len(self._history) - self._max_history
        if overflow > 0:
            del self._history[:overflow]
        self._current_index = len(self._history) - 1

    # ========== Subscription ==========

    def subscribe(self, callback: Subscriber, selector: Optional[Selector] = None) -> Callable[[], None]:
        """Subscribe to state changes.

        Args:
            callback: Function called as callback(new_state, old_state)
            selector: Optional selector; when given, the callback only fires
                when the selected value changes

        Returns:
            Unsubscribe function
        """
        with self._lock:
            self._subscriber_counter += 1
            sub_id = f"sub_{self._subscriber_counter}"

            if selector is not None:
                last_value = selector(self._state)

                def wrapped_callback(new_state: State, old_state: State) -> None:
                    nonlocal last_value
                    new_value = selector(new_state)
                    if new_value != last_value:
                        last_value = new_value
                        callback(new_state, old_state)

                self._subscribers[sub_id] = wrapped_callback
            else:
                self._subscribers[sub_id] = callback

        def unsubscribe() -> None:
            with self._lock:
                self._subscribers.pop(sub_id, None)

        return unsubscribe

    def _notify_subscribers(self, new_state: State, old_state: State) -> None:
        """Notify all subscribers of state change.

        Exceptions raised by a subscriber are logged and do not stop the
        remaining subscribers from being called.
        """
        # Iterate over a copy so subscribers may (un)subscribe while notified.
        for callback in list(self._subscribers.values()):
            try:
                callback(new_state, old_state)
            except Exception as e:
                self._logger.error(f"Error in subscriber: {e}")

    # ========== Selectors ==========

    def select(self, selector: Selector[T]) -> T:
        """Select a derived value from state.

        Args:
            selector: Function that extracts value from state; it receives a
                deep copy of the state

        Returns:
            Selected value
        """
        with self._lock:
            return selector(copy.deepcopy(self._state))

    def create_selector(self, *input_selectors: Selector, combiner: Callable) -> Selector:
        """Create a memoized selector.

        The combiner is re-run only when the list of input values differs
        from the one seen on the previous call.

        Args:
            input_selectors: Selectors that provide input values
            combiner: Function that combines inputs into output

        Returns:
            Memoized selector function
        """
        # Sentinel for "nothing cached yet", so inputs that are all None (or
        # an empty selector list) still reach the combiner on first use.
        unset = object()
        last_inputs: Any = unset
        last_result: Any = None

        def memoized_selector(state: State) -> Any:
            nonlocal last_inputs, last_result
            inputs = [s(state) for s in input_selectors]
            if last_inputs is unset or inputs != last_inputs:
                # Store the inputs only after the combiner succeeded so a
                # failed computation is never treated as cached.
                last_result = combiner(*inputs)
                last_inputs = inputs
            return last_result

        return memoized_selector

    # ========== Time Travel ==========

    def get_history(self) -> List[StateSnapshot]:
        """Get a shallow copy of the snapshot history."""
        with self._lock:
            return self._history.copy()

    def get_current_index(self) -> int:
        """Get current position in history."""
        with self._lock:
            return self._current_index

    def can_undo(self) -> bool:
        """Check if undo is possible."""
        with self._lock:
            return self._current_index > 0

    def can_redo(self) -> bool:
        """Check if redo is possible."""
        with self._lock:
            return self._current_index < len(self._history) - 1

    def _travel_to(self, index: int) -> None:
        """Make the snapshot at ``index`` current; the lock must be held."""
        old_state = self._state
        self._current_index = index
        self._state = copy.deepcopy(self._history[index].state)
        self._notify_subscribers(self._state, old_state)

    def undo(self) -> bool:
        """Undo last action.

        Returns:
            True if undo was performed
        """
        # Check and move under one lock so the index cannot change in between.
        with self._lock:
            if not self.can_undo():
                return False
            self._travel_to(self._current_index - 1)
            return True

    def redo(self) -> bool:
        """Redo last undone action.

        Returns:
            True if redo was performed
        """
        with self._lock:
            if not self.can_redo():
                return False
            self._travel_to(self._current_index + 1)
            return True

    def jump_to_snapshot(self, index: int) -> bool:
        """Jump to a specific snapshot in history.

        Args:
            index: Snapshot index

        Returns:
            True if jump was successful
        """
        with self._lock:
            if index < 0 or index >= len(self._history):
                return False
            self._travel_to(index)
            return True

    def reset(self) -> None:
        """Reset to initial state.

        Dispatches a ``StateResetAction`` first, then restores the initial
        state, restarts the history with a single snapshot and notifies
        subscribers.
        """
        self.dispatch(StateResetAction())
        with self._lock:
            old_state = self._state
            self._state = copy.deepcopy(self._initial_state)
            if self._enable_time_travel:
                self._history.clear()
                self._current_index = -1
                self._add_snapshot(StateSnapshot(
                    state=copy.deepcopy(self._state),
                    action=None
                ))
            self._notify_subscribers(self._state, old_state)

    # ========== Persistence ==========

    def save_to_disk(self, path: Path) -> bool:
        """Save current state to disk.

        Writes a JSON object with the keys ``state``, ``timestamp`` and
        ``history_count``. Values json cannot encode are stringified.

        Args:
            path: File path to save to; parent directories are created

        Returns:
            True if saved successfully
        """
        try:
            path = Path(path)
            with self._lock:
                state_data = {
                    'state': self._state,
                    'timestamp': datetime.now().isoformat(),
                    'history_count': len(self._history),
                }
            path.parent.mkdir(parents=True, exist_ok=True)
            with open(path, 'w', encoding='utf-8') as f:
                json.dump(state_data, f, indent=2, default=str)
            return True
        except Exception as e:
            self._logger.error(f"Failed to save state: {e}")
            return False

    def load_from_disk(self, path: Path) -> bool:
        """Load state from disk.

        Expects the JSON layout written by ``save_to_disk``. The loaded
        state is dispatched as a ``StateHydrateAction``; if the reducer
        leaves the state untouched, the store installs the loaded state
        itself.

        Args:
            path: File path to load from

        Returns:
            True if loaded successfully
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, dict) or 'state' not in data:
                raise ValueError("file does not contain a 'state' entry")
            state = data['state']
            if state is not None:
                self._hydrate(state)
            return True
        except Exception as e:
            self._logger.error(f"Failed to load state: {e}")
            return False

    def _hydrate(self, state: Any) -> None:
        """Dispatch a hydrate action and apply it directly if nothing handled it."""
        action = StateHydrateAction(state)
        # Hold the (re-entrant) lock across dispatch and fallback so no other
        # thread can change the state between the two steps.
        with self._lock:
            before = self._state
            self.dispatch(action)
            # _apply_reducer rebinds self._state only when the reducer changed
            # it, so an identical object means the action was not handled.
            if self._state is before and not self._states_equal(state, before):
                self._commit(copy.deepcopy(state), before, action)
# =============================================================================
# MODULE-LEVEL STORE (Application-wide state)
# =============================================================================
# Registry of named stores shared by the whole module.
_module_stores: Dict[str, StateStore] = {}


def create_store(
    name: str,
    reducer: Reducer,
    initial_state: Optional[Any] = None,
    **kwargs
) -> StateStore:
    """Return the store registered under ``name``, creating it if needed.

    When a store with that name already exists it is returned unchanged and
    the other arguments are ignored.

    Args:
        name: Unique store name
        reducer: Reducer function
        initial_state: Initial state
        **kwargs: Additional arguments for StateStore

    Returns:
        StateStore instance
    """
    if name in _module_stores:
        return _module_stores[name]

    store = StateStore(
        reducer=reducer,
        initial_state=initial_state,
        **kwargs
    )
    _module_stores[name] = store
    return store
def get_store(name: str) -> Optional[StateStore]:
    """Look up a store in the module-level registry.

    Args:
        name: Store name

    Returns:
        StateStore or None if not found
    """
    try:
        return _module_stores[name]
    except KeyError:
        return None
def remove_store(name: str) -> bool:
    """Delete a store from the module-level registry.

    Args:
        name: Store name

    Returns:
        True if removed, False when no store had that name
    """
    try:
        del _module_stores[name]
    except KeyError:
        return False
    return True
# =============================================================================
# EXPORTS
# =============================================================================
# Public API of the module, grouped by kind.
__all__ = [
    # Types
    'Action',
    'ActionBase',
    'Reducer',
    'Selector',
    'Subscriber',
    'Middleware',
    # Actions
    'StateResetAction',
    'StateRestoreAction',
    'StateBatchAction',
    'StateHydrateAction',
    # Classes
    'StateSnapshot',
    'StoreSlice',
    'StateStore',
    # Functions
    'combine_reducers',
    'create_store',
    'get_store',
    'remove_store',
    # Middleware
    'logging_middleware',
    'thunk_middleware',
    'persistence_middleware',
]