712 lines
21 KiB
Python
712 lines
21 KiB
Python
"""
|
|
EU-Utility Premium - State Management
|
|
======================================
|
|
|
|
Redux-inspired state store with:
|
|
- Immutable state updates
|
|
- Time-travel debugging
|
|
- Selectors for derived state
|
|
- Middleware support
|
|
- State persistence
|
|
|
|
Example:
|
|
from premium.core.state import StateStore, Action, Reducer
|
|
|
|
# Define actions
|
|
class IncrementAction(Action):
|
|
type = "INCREMENT"
|
|
|
|
# Define reducer
|
|
def counter_reducer(state: int, action: Action) -> int:
|
|
if action.type == "INCREMENT":
|
|
return state + 1
|
|
return state
|
|
|
|
# Create store
|
|
store = StateStore(counter_reducer, initial_state=0)
|
|
store.dispatch(IncrementAction())
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import json
|
|
import logging
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field, asdict
|
|
from datetime import datetime
|
|
from enum import Enum, auto
|
|
from pathlib import Path
|
|
from typing import (
|
|
Any, Callable, Dict, Generic, List, Optional, Set, TypeVar, Union,
|
|
Protocol, runtime_checkable
|
|
)
|
|
|
|
|
|
# =============================================================================
|
|
# TYPE DEFINITIONS
|
|
# =============================================================================
|
|
|
|
# Generic placeholder used for selector results and store-slice state.
T = TypeVar("T")
# Generic placeholder for the type of state a store manages.
State = TypeVar("State")
|
|
|
|
|
|
@runtime_checkable
class Action(Protocol):
    """Structural interface for anything that can be dispatched.

    An object satisfies this protocol when it exposes ``type``, ``payload``
    and ``meta`` attributes. Because the protocol is runtime-checkable,
    ``isinstance(obj, Action)`` works, but it only checks that the
    attributes exist, not their types.
    """

    # Identifier that reducers branch on.
    type: str
    # Optional data carried by the action.
    payload: Optional[Any] = None
    # Optional free-form metadata.
    meta: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
class ActionBase:
    """Concrete base class for actions.

    Subclasses normally only override the ``type`` class attribute::

        class IncrementAction(ActionBase):
            type = "INCREMENT"

    Args:
        payload: Arbitrary data carried by the action.
        meta: Optional metadata mapping; ``None`` (or any falsy value)
            becomes a fresh empty dict.
    """

    # Action identifier; subclasses override it.
    type: str = "UNKNOWN"
    payload: Any = None
    meta: Optional[Dict[str, Any]] = None
    # Assigned per instance in __init__. There is deliberately no class-level
    # default: dataclasses.field() only works inside a @dataclass and would
    # otherwise leave a Field object here.
    timestamp: datetime

    def __init__(self, payload: Any = None, meta: Optional[Dict[str, Any]] = None):
        self.payload = payload
        # New dict per instance so actions never share metadata.
        self.meta = meta or {}
        # Naive local time of creation.
        self.timestamp = datetime.now()

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(type={self.type!r}, "
            f"payload={self.payload!r}, meta={self.meta!r})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the action as a plain dict (timestamp in ISO-8601 form)."""
        return {
            'type': self.type,
            'payload': self.payload,
            'meta': self.meta,
            'timestamp': self.timestamp.isoformat(),
        }
|
|
|
|
|
|
# Reducer: maps (current_state, action) to the next state.
Reducer = Callable[[State, Action], State]

# Selector: derives a value from the state.
Selector = Callable[[State], T]

# Subscriber: invoked as (new_state, old_state) after a state change.
Subscriber = Callable[[State, State], None]

# Middleware: receives the store, the next dispatcher in the chain and the
# action; it may forward to the next dispatcher or handle the action itself.
Middleware = Callable[["StateStore", Callable[[Action], Any], Action], Any]
|
|
|
|
|
|
# =============================================================================
|
|
# BUILT-IN ACTIONS
|
|
# =============================================================================
|
|
|
|
class StateResetAction(ActionBase):
    """Built-in action signalling a reset to the initial state value.

    ``StateStore.reset()`` dispatches this action.
    """

    type = "@@STATE/RESET"
|
|
|
|
|
|
class StateRestoreAction(ActionBase):
    """Built-in action carrying a state snapshot to restore.

    Args:
        state: Snapshot value; stored as the action's ``payload``.
        meta: Optional metadata forwarded to ``ActionBase``.
    """

    type = "@@STATE/RESTORE"

    def __init__(self, state: State, meta: Optional[Dict[str, Any]] = None):
        super().__init__(meta=meta, payload=state)
|
|
|
|
|
|
class StateBatchAction(ActionBase):
    """Built-in action grouping several actions into one dispatch.

    Args:
        actions: Actions to apply; stored as the action's ``payload``.
        meta: Optional metadata forwarded to ``ActionBase``.
    """

    type = "@@STATE/BATCH"

    def __init__(self, actions: List[Action], meta: Optional[Dict[str, Any]] = None):
        super().__init__(meta=meta, payload=actions)
|
|
|
|
|
|
class StateHydrateAction(ActionBase):
    """Built-in action carrying previously persisted state.

    ``StateStore.load_from_disk`` dispatches it with the loaded data.

    Args:
        state: Persisted state; stored as the action's ``payload``.
        meta: Optional metadata forwarded to ``ActionBase``.
    """

    type = "@@STATE/HYDRATE"

    def __init__(self, state: State, meta: Optional[Dict[str, Any]] = None):
        super().__init__(meta=meta, payload=state)
|
|
|
|
|
|
# =============================================================================
|
|
# STATE SNAPSHOT
|
|
# =============================================================================
|
|
|
|
@dataclass
class StateSnapshot:
    """Record of the state at one point in a store's history.

    Attributes:
        state: State value captured for this snapshot (StateStore passes a
            deep copy here).
        action: Action that produced ``state``; ``None`` for the initial or
            reset snapshot.
        timestamp: Creation time (naive local time).
        id: Short random identifier (first 8 characters of a UUID4 string).
    """

    state: Any
    action: Optional[Action] = None
    timestamp: datetime = field(default_factory=datetime.now)
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])

    def to_dict(self) -> Dict[str, Any]:
        """Convert the snapshot to a plain dictionary.

        Actions that provide ``to_dict()`` (e.g. ``ActionBase``) are serialized
        with it. Any other ``Action`` (the protocol does not require
        ``to_dict``) is reduced to its ``type``/``payload``/``meta`` attributes
        instead of raising ``AttributeError``.
        """
        return {
            'id': self.id,
            'state': self.state,
            'action': self._action_to_dict(),
            'timestamp': self.timestamp.isoformat(),
        }

    def _action_to_dict(self) -> Optional[Dict[str, Any]]:
        """Serialize ``self.action``, tolerating actions without ``to_dict``."""
        action = self.action
        if action is None:
            return None
        to_dict = getattr(action, 'to_dict', None)
        if callable(to_dict):
            return to_dict()
        return {
            'type': getattr(action, 'type', None),
            'payload': getattr(action, 'payload', None),
            'meta': getattr(action, 'meta', None),
        }
|
|
|
|
|
|
# =============================================================================
|
|
# STORE SLICE
|
|
# =============================================================================
|
|
|
|
@dataclass
class StoreSlice(Generic[T]):
    """Description of one slice of the store.

    Bundles a slice's reducer, its starting value and any named selectors.

    Attributes:
        name: Identifier of the slice.
        reducer: Reducer responsible for this slice.
        initial_state: Value the slice starts with.
        selectors: Named selector functions for this slice.
    """

    name: str
    reducer: Reducer
    initial_state: T
    # default_factory gives every instance its own dict.
    selectors: Dict[str, Selector] = field(default_factory=dict)
|
|
|
|
|
|
# =============================================================================
|
|
# COMBINED REDUCER
|
|
# =============================================================================
|
|
|
|
def combine_reducers(reducers: Dict[str, Reducer]) -> Reducer:
    """Combine multiple slice reducers into a single root reducer.

    Each reducer manages the value stored under its key in a dict-shaped
    state. The combined reducer returns the *same* state object when no
    slice changed (compared by identity) and the state already has exactly
    the managed keys. Otherwise it returns a new dict containing only the
    managed keys.

    Args:
        reducers: Dict mapping slice names to reducers.

    Returns:
        Combined reducer function ``(state, action) -> new_state``.

    Example:
        root_reducer = combine_reducers({
            'counter': counter_reducer,
            'todos': todos_reducer,
        })
    """
    def combined(state: Optional[Dict[str, Any]], action: Action) -> Dict[str, Any]:
        is_dict = isinstance(state, dict)
        new_state: Dict[str, Any] = {}
        # A non-dict input (e.g. None on first run), or a state whose keys
        # differ from the managed slices, must always produce a fresh dict.
        # Otherwise stale or missing keys would persist only when no slice
        # happened to change.
        has_changed = not is_dict or state.keys() != reducers.keys()

        for key, reducer in reducers.items():
            previous_slice = state.get(key) if is_dict else None
            next_slice = reducer(previous_slice, action)
            new_state[key] = next_slice
            if next_slice is not previous_slice:
                has_changed = True

        return new_state if has_changed else state

    return combined
|
|
|
|
|
|
# =============================================================================
|
|
# MIDDLEWARE
|
|
# =============================================================================
|
|
|
|
def logging_middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
    """Print each action's type, forward it, then print the resulting state.

    Output goes to stdout via ``print``. The value returned by the rest of
    the chain is passed back unchanged.
    """
    action_line = f"[State] Action: {action.type}"
    print(action_line)

    outcome = next(action)

    state_line = f"[State] New state: {store.get_state()}"
    print(state_line)
    return outcome
|
|
|
|
|
|
def thunk_middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
    """Let plain callables ("thunks") be dispatched as actions.

    A thunk is any callable that is not an ``ActionBase`` instance. It is
    invoked as ``thunk(store.dispatch, store.get_state)`` and its return
    value is returned; it is never forwarded down the chain. Everything
    else passes straight to ``next``.
    """
    if not callable(action) or isinstance(action, ActionBase):
        return next(action)
    return action(store.dispatch, store.get_state)
|
|
|
|
|
|
def persistence_middleware(
    storage_path: Path,
    debounce_ms: int = 1000
) -> Middleware:
    """Create middleware that persists the store's state to disk.

    After an action has been applied, the state is written with
    ``store.save_to_disk``. The file therefore uses the same
    ``{"state": ...}`` envelope that ``StateStore.load_from_disk`` reads back,
    and missing parent directories are created.

    Saves are throttled. A save happens only when more than ``debounce_ms``
    milliseconds have passed since the last *successful* save. Each save
    writes the full current state, so changes skipped inside the window are
    captured by the next save.

    NOTE: there is no trailing timer. Changes made inside the window after
    the last save are written only when another action arrives after the
    window has elapsed. Call ``store.save_to_disk(storage_path)`` on shutdown
    to flush them.

    Args:
        storage_path: File to write the state to (a ``str`` is accepted too).
        debounce_ms: Minimum interval between saves, in milliseconds.

    Returns:
        Middleware function to pass to ``StateStore(middleware=[...])``.
    """
    path = Path(storage_path)
    logger = logging.getLogger("StateStore")
    # Monotonic clock, so wall-clock adjustments cannot suppress or force saves.
    # -inf guarantees that the first action is persisted.
    last_save_ms = float('-inf')

    def middleware(store: StateStore, next: Callable[[Action], Any], action: Action) -> Any:
        nonlocal last_save_ms

        result = next(action)

        now_ms = time.monotonic() * 1000
        if now_ms - last_save_ms > debounce_ms:
            # Persistence is best-effort: never let a save failure break dispatch.
            try:
                saved = store.save_to_disk(path)
            except Exception as e:
                logger.error(f"Failed to persist state: {e}")
                saved = False
            # Only a successful write restarts the window, so a failed save
            # is retried on the next action.
            if saved:
                last_save_ms = now_ms

        return result

    return middleware
|
|
|
|
|
|
# =============================================================================
|
|
# STATE STORE
|
|
# =============================================================================
|
|
|
|
class StateStore(Generic[State]):
    """Redux-inspired state store with time-travel debugging.

    Features:
        - Immutable state updates: reducers work on a deep copy, and readers
          get deep copies from ``get_state()`` / ``select()``
        - Action history and snapshots for time travel (undo/redo/jump)
        - Selectors (plain and memoized) for derived state
        - Middleware support
        - State persistence (``save_to_disk`` / ``load_from_disk``)

    The store handles the built-in actions itself:
        - ``@@STATE/RESET``: the state is replaced by a copy of the initial
          state.
        - ``@@STATE/RESTORE`` and ``@@STATE/HYDRATE``: the state is replaced
          by a copy of ``action.payload``.
        - ``@@STATE/BATCH``: each action in ``action.payload`` is reduced in
          order, producing one state change and one snapshot.
    For RESET/RESTORE/HYDRATE the root reducer is still called afterwards
    with the replaced state, so it can post-process it. A reducer that
    returns ``state`` unchanged for unknown action types needs no special
    handling.

    A dispatch that leaves the state equal (``==``) to the previous state is
    a no-op: no snapshot is recorded and no subscriber is notified.
    Subscribers run synchronously while the store's lock is held.

    Example:
        store = StateStore(
            reducer=root_reducer,
            initial_state={'count': 0},
            middleware=[logging_middleware]
        )

        # Subscribe to changes
        unsubscribe = store.subscribe(lambda new, old: print(f"Changed: {old} -> {new}"))

        # Dispatch actions
        store.dispatch(IncrementAction())

        # Use selectors
        count = store.select(lambda state: state['count'])

        # Time travel
        store.undo()  # Undo last action
        store.jump_to_snapshot(0)  # Jump to the oldest retained snapshot
    """

    def __init__(
        self,
        reducer: Reducer,
        initial_state: Optional[State] = None,
        middleware: Optional[List[Middleware]] = None,
        max_history: int = 1000,
        enable_time_travel: bool = True
    ):
        """Initialize state store.

        Args:
            reducer: Root reducer ``(state, action) -> new_state``.
            initial_state: Initial state value. The store keeps its own deep
                copies, so later mutation of this object has no effect.
            middleware: Middleware functions. The first entry sees each
                action first.
            max_history: Maximum number of snapshots kept (at least 1).
            enable_time_travel: Record snapshots for undo/redo/jump.
        """
        self._reducer = reducer
        # Two independent copies: the live state and a pristine one for reset().
        self._state: State = copy.deepcopy(initial_state)
        self._initial_state = copy.deepcopy(initial_state)
        self._middleware = middleware or []
        # At least one slot is needed for the current snapshot. A value <= 0
        # would also make history[-0:] return the whole list and never trim.
        self._max_history = max(1, max_history)
        self._enable_time_travel = enable_time_travel

        # Subscribers, keyed by a generated id so unsubscribe can find them.
        self._subscribers: Dict[str, Subscriber] = {}
        self._subscriber_counter = 0

        # History for time travel. _current_index points at the snapshot
        # that matches the live state.
        self._history: List[StateSnapshot] = []
        self._current_index = -1

        # Re-entrant: subscribers run under this lock and may call back into
        # the store (e.g. get_state()).
        self._lock = threading.RLock()
        self._logger = logging.getLogger("StateStore")

        # Create initial snapshot
        if enable_time_travel:
            self._add_snapshot(StateSnapshot(
                state=copy.deepcopy(self._state),
                action=None
            ))

    # ========== Core Methods ==========

    def get_state(self) -> State:
        """Return a deep copy of the current state (safe to mutate)."""
        with self._lock:
            return copy.deepcopy(self._state)

    def dispatch(self, action: Action) -> Action:
        """Dispatch an action through the middleware chain to the reducer.

        Args:
            action: Action to dispatch, or a thunk if ``thunk_middleware`` is
                installed.

        Returns:
            Whatever the middleware chain returns. With no middleware (or
            middleware that passes results through), this is the action.
        """
        # Innermost link: actually reduce the action.
        def dispatch_action(a: Action) -> Action:
            return self._apply_reducer(a)

        # Wrap from the last middleware outwards so the first one runs first.
        # Default arguments bind mw/next per iteration (avoids late binding).
        chain = dispatch_action
        for mw in reversed(self._middleware):
            chain = lambda a, mw=mw, next=chain: mw(self, next, a)

        return chain(action)

    def _apply_reducer(self, action: Action) -> Action:
        """Reduce ``action``, then record a snapshot and notify if the state changed."""
        with self._lock:
            old_state = self._state
            # The reducer gets a private copy, so in-place mutation cannot
            # corrupt the live state or the snapshots in history.
            new_state = self._reduce(copy.deepcopy(old_state), action)

            # Because of the copy, identity cannot detect "unchanged".
            # Compare by value so no-op actions add no snapshot or notification.
            if self._states_equal(new_state, old_state):
                return action

            self._state = new_state

            if self._enable_time_travel:
                # A new action after an undo discards the redo branch.
                del self._history[self._current_index + 1:]
                self._add_snapshot(StateSnapshot(
                    state=copy.deepcopy(new_state),
                    action=action
                ))

            self._notify_subscribers(new_state, old_state)

        return action

    def _reduce(self, state: State, action: Action) -> State:
        """Apply built-in action handling, then the root reducer."""
        action_type = getattr(action, 'type', None)

        if action_type == StateBatchAction.type:
            # Fold every sub-action into one resulting state.
            for sub_action in (getattr(action, 'payload', None) or ()):
                state = self._reduce(state, sub_action)
            return state

        if action_type == StateResetAction.type:
            state = copy.deepcopy(self._initial_state)
        elif action_type in (StateRestoreAction.type, StateHydrateAction.type):
            state = copy.deepcopy(getattr(action, 'payload', None))

        return self._reducer(state, action)

    @staticmethod
    def _states_equal(a: Any, b: Any) -> bool:
        """Compare states by value; treat states that can't be compared as different."""
        if a is b:
            return True
        try:
            return bool(a == b)
        except Exception:
            # Some values (e.g. array-like objects) raise when their
            # comparison result is truth-tested; err on the side of "changed".
            return False

    def _add_snapshot(self, snapshot: StateSnapshot) -> None:
        """Append a snapshot, make it current, and trim history to max_history."""
        self._history.append(snapshot)
        if len(self._history) > self._max_history:
            self._history = self._history[-self._max_history:]
        self._current_index = len(self._history) - 1

    # ========== Subscription ==========

    def subscribe(self, callback: Subscriber, selector: Optional[Selector] = None) -> Callable[[], None]:
        """Subscribe to state changes.

        Args:
            callback: Called as ``callback(new_state, old_state)`` after a
                change. Exceptions it raises are logged, not propagated.
            selector: If given, ``callback`` fires only when
                ``selector(new_state)`` differs (``!=``) from the last
                selected value.

        Returns:
            Function that removes this subscription. Calling it again is a
            no-op.
        """
        with self._lock:
            self._subscriber_counter += 1
            sub_id = f"sub_{self._subscriber_counter}"

            if selector:
                last_value = selector(self._state)

                def wrapped_callback(new_state: State, old_state: State) -> None:
                    nonlocal last_value
                    new_value = selector(new_state)
                    if new_value != last_value:
                        last_value = new_value
                        callback(new_state, old_state)

                self._subscribers[sub_id] = wrapped_callback
            else:
                self._subscribers[sub_id] = callback

        def unsubscribe() -> None:
            with self._lock:
                self._subscribers.pop(sub_id, None)

        return unsubscribe

    def _notify_subscribers(self, new_state: State, old_state: State) -> None:
        """Call every subscriber. A failing subscriber does not stop the rest."""
        # Iterate over a copy, since callbacks may unsubscribe during notification.
        for callback in list(self._subscribers.values()):
            try:
                callback(new_state, old_state)
            except Exception as e:
                # exception() also records the traceback.
                self._logger.exception(f"Error in subscriber: {e}")

    # ========== Selectors ==========

    def select(self, selector: Callable[[State], T]) -> T:
        """Return ``selector`` applied to a deep copy of the current state."""
        with self._lock:
            return selector(copy.deepcopy(self._state))

    def create_selector(self, *input_selectors: Selector, combiner: Callable) -> Selector:
        """Create a memoized selector.

        Args:
            input_selectors: Selectors that provide the combiner's inputs.
            combiner: Function that combines the inputs into the output.

        Returns:
            Selector that re-runs ``combiner`` only when the list of input
            values changes (``!=``) since its previous call.
        """
        # Sentinel: the first call always runs the combiner, even if every
        # input happens to be None.
        unset = object()
        last_inputs: Any = unset
        last_result = None

        def memoized_selector(state: State) -> Any:
            nonlocal last_inputs, last_result

            inputs = [s(state) for s in input_selectors]
            if last_inputs is unset or inputs != last_inputs:
                last_inputs = inputs
                last_result = combiner(*inputs)

            return last_result

        return memoized_selector

    # ========== Time Travel ==========

    def get_history(self) -> List[StateSnapshot]:
        """Return a shallow copy of the snapshot list."""
        with self._lock:
            return self._history.copy()

    def get_current_index(self) -> int:
        """Return the index of the snapshot matching the current state."""
        with self._lock:
            return self._current_index

    def can_undo(self) -> bool:
        """Return True if an earlier snapshot exists."""
        with self._lock:
            return self._current_index > 0

    def can_redo(self) -> bool:
        """Return True if a later snapshot exists."""
        with self._lock:
            return self._current_index < len(self._history) - 1

    def _move_to(self, index: int) -> None:
        """Make ``history[index]`` the current state and notify. Caller holds the lock."""
        old_state = self._state
        self._current_index = index
        self._state = copy.deepcopy(self._history[index].state)
        self._notify_subscribers(self._state, old_state)

    def undo(self) -> bool:
        """Step back one snapshot.

        Returns:
            True if undo was performed.
        """
        # Check and move under one lock acquisition so concurrent calls
        # cannot step past the start of history.
        with self._lock:
            if not self.can_undo():
                return False
            self._move_to(self._current_index - 1)
        return True

    def redo(self) -> bool:
        """Step forward one snapshot.

        Returns:
            True if redo was performed.
        """
        with self._lock:
            if not self.can_redo():
                return False
            self._move_to(self._current_index + 1)
        return True

    def jump_to_snapshot(self, index: int) -> bool:
        """Jump to a specific snapshot in history.

        Args:
            index: Snapshot index; 0 is the oldest retained snapshot.

        Returns:
            True if the jump was successful.
        """
        with self._lock:
            if not 0 <= index < len(self._history):
                return False
            self._move_to(index)
        return True

    def reset(self) -> None:
        """Reset to the initial state and restart history from it.

        The reset is dispatched as ``StateResetAction`` so middleware (e.g.
        logging or persistence) observes it. The state is then forced back to
        the initial value, in case a middleware swallowed the action, and the
        history is replaced by a single initial snapshot.
        """
        self.dispatch(StateResetAction())

        with self._lock:
            old_state = self._state
            self._state = copy.deepcopy(self._initial_state)

            if self._enable_time_travel:
                self._history.clear()
                self._current_index = -1
                self._add_snapshot(StateSnapshot(
                    state=copy.deepcopy(self._state),
                    action=None
                ))

            # The dispatch above normally already reset the state and
            # notified; only notify again if the forced reset changed it.
            if not self._states_equal(self._state, old_state):
                self._notify_subscribers(self._state, old_state)

    # ========== Persistence ==========

    def save_to_disk(self, path: Path) -> bool:
        """Save the current state to disk as JSON.

        The file contains ``{"state": ..., "timestamp": ..., "history_count": ...}``.
        Values that JSON cannot encode are written with ``str()``. The data is
        written to a temporary sibling file and then moved over ``path``, so
        an interrupted write does not leave a truncated state file behind.

        Args:
            path: File path to save to (a ``str`` is accepted too).

        Returns:
            True if saved successfully. Failures are logged and return False.
        """
        try:
            path = Path(path)
            with self._lock:
                state_data = {
                    'state': self._state,
                    'timestamp': datetime.now().isoformat(),
                    'history_count': len(self._history),
                }

            path.parent.mkdir(parents=True, exist_ok=True)
            # Unique temp name so concurrent saves don't share a temp file.
            tmp_path = path.with_name(f"{path.name}.{uuid.uuid4().hex}.tmp")
            try:
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    json.dump(state_data, f, indent=2, default=str)
                tmp_path.replace(path)
            except BaseException:
                # Best-effort cleanup of the partial temp file, then re-raise.
                try:
                    tmp_path.unlink()
                except OSError:
                    pass
                raise

            return True
        except Exception as e:
            self._logger.error(f"Failed to save state: {e}")
            return False

    def load_from_disk(self, path: Path) -> bool:
        """Load state previously written by ``save_to_disk``.

        The ``"state"`` entry is dispatched as a ``StateHydrateAction``, so it
        passes through middleware and replaces the current state.

        Args:
            path: File path to load from.

        Returns:
            True if the file was read and parsed (even if it had no
            ``"state"`` entry). False on any error, which is logged.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            state = data.get('state')
            if state is not None:
                self.dispatch(StateHydrateAction(state))

            return True
        except Exception as e:
            self._logger.error(f"Failed to load state: {e}")
            return False
|
|
|
|
|
# =============================================================================
|
|
# MODULE-LEVEL STORE (Application-wide state)
|
|
# =============================================================================
|
|
|
|
# Registry of named, application-wide stores, managed by create_store(),
# get_store() and remove_store().
_module_stores: Dict[str, StateStore] = dict()
|
|
|
|
|
|
def create_store(
    name: str,
    reducer: Reducer,
    initial_state: Optional[Any] = None,
    **kwargs
) -> StateStore:
    """Return the store registered as ``name``, creating it on first use.

    When ``name`` is already registered, the existing store is returned and
    ``reducer``, ``initial_state`` and ``kwargs`` are ignored.

    Args:
        name: Unique store name.
        reducer: Reducer used if a new store is created.
        initial_state: Initial state used if a new store is created.
        **kwargs: Extra ``StateStore`` constructor arguments.

    Returns:
        The registered StateStore instance.
    """
    try:
        return _module_stores[name]
    except KeyError:
        pass

    store = StateStore(
        reducer=reducer,
        initial_state=initial_state,
        **kwargs
    )
    _module_stores[name] = store
    return store
|
|
|
|
|
|
def get_store(name: str) -> Optional[StateStore]:
    """Look up a store registered with ``create_store``.

    Args:
        name: Store name.

    Returns:
        The StateStore, or None if no store has that name.
    """
    try:
        return _module_stores[name]
    except KeyError:
        return None
|
|
|
|
|
|
def remove_store(name: str) -> bool:
    """Unregister a named store.

    Args:
        name: Store name.

    Returns:
        True if a store was removed, False if ``name`` was not registered.
    """
    try:
        del _module_stores[name]
    except KeyError:
        return False
    return True
|
|
|
|
|
|
# =============================================================================
|
|
# EXPORTS
|
|
# =============================================================================
|
|
|
|
# Public API of this module (controls `from ... import *`).
__all__ = [
    # Types
    'Action',
    'ActionBase',
    'Reducer',
    'Selector',
    'Subscriber',
    'Middleware',
    # Actions
    'StateResetAction',
    'StateRestoreAction',
    'StateBatchAction',
    'StateHydrateAction',
    # Classes
    'StateSnapshot',
    'StoreSlice',
    'StateStore',
    # Functions
    'combine_reducers',
    'create_store',
    'get_store',
    'remove_store',
    # Middleware
    'logging_middleware',
    'thunk_middleware',
    'persistence_middleware',
]
|