State Pattern in Python: State Machine Implementation
Key Insights
- The State pattern eliminates sprawling conditional logic by encapsulating state-specific behavior in dedicated classes, making your code easier to extend and maintain
- Python’s dynamic nature enables a lightweight implementation technique, __class__ switching, which folds the context and its state objects into a single object
- State machines shine when you have well-defined states with clear transitions and state-dependent behavior, but for simple cases an enum with a transition table is often sufficient
Introduction to the State Pattern
The State pattern lets an object alter its behavior when its internal state changes. Instead of scattering conditional logic throughout your code, you encapsulate state-specific behavior in dedicated state classes. The object appears to change its class.
Consider a document workflow system. Without the State pattern, you end up with code like this:
class Document:
    def __init__(self):
        self.state = "draft"
    def publish(self):
        if self.state == "draft":
            self.state = "pending_review"
        elif self.state == "pending_review":
            raise ValueError("Already pending review")
        elif self.state == "published":
            raise ValueError("Already published")
        elif self.state == "archived":
            raise ValueError("Cannot publish archived document")
    def approve(self):
        if self.state == "draft":
            raise ValueError("Cannot approve draft directly")
        elif self.state == "pending_review":
            self.state = "published"
        elif self.state == "published":
            raise ValueError("Already published")
        # ... more conditions
This approach has problems. Every method needs to check every possible state. Adding a new state means modifying every method. The logic for each state is scattered across the class.
The State pattern inverts this structure. Each state becomes a class that knows how to handle operations relevant to that state:
class Document:
    def __init__(self):
        self.state = DraftState()
    def publish(self):
        self.state.publish(self)
    def approve(self):
        self.state.approve(self)
Now adding a new state means adding a new class; apart from the states that transition into it, existing code remains untouched.
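The Document snippet above delegates to a DraftState it never defines. A minimal sketch of the missing pieces follows; DocumentState, PendingReviewState, and PublishedState are hypothetical names, and the archived state is omitted for brevity. The base class rejects every operation, so each concrete state overrides only what is legal in it:

```python
class DocumentState:
    """Hypothetical base state: every operation is rejected unless overridden."""

    def publish(self, doc: "Document") -> None:
        raise ValueError(f"Cannot publish in {type(self).__name__}")

    def approve(self, doc: "Document") -> None:
        raise ValueError(f"Cannot approve in {type(self).__name__}")


class DraftState(DocumentState):
    def publish(self, doc: "Document") -> None:
        # Submitting a draft sends it to review
        doc.state = PendingReviewState()


class PendingReviewState(DocumentState):
    def approve(self, doc: "Document") -> None:
        doc.state = PublishedState()


class PublishedState(DocumentState):
    pass  # Inherits the default rejection for both operations


class Document:
    def __init__(self) -> None:
        self.state: DocumentState = DraftState()

    def publish(self) -> None:
        self.state.publish(self)

    def approve(self) -> None:
        self.state.approve(self)


doc = Document()
doc.publish()
doc.approve()
print(type(doc.state).__name__)  # PublishedState
```

The "already published" and "cannot approve a draft" branches from the elif version collapse into the inherited default, and the only transitions spelled out are the legal ones.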
Core Components and UML Structure
The State pattern consists of three participants:
Context holds a reference to the current state object and delegates state-specific behavior to it. The context is what clients interact with.
State Interface defines the methods that all concrete states must implement. In Python, this is typically an abstract base class or Protocol.
Concrete States implement state-specific behavior. Each state class handles operations appropriate for that state and manages transitions to other states.
┌─────────────────┐         ┌─────────────────┐
│     Context     │────────▶│   State (ABC)   │
├─────────────────┤         ├─────────────────┤
│ - state: State  │         │ + handle()      │
│ + request()     │         │ + next_state()  │
└─────────────────┘         └─────────────────┘
                                     △
                                     │
                  ┌──────────────────┼──────────────────┐
                  │                  │                  │
         ┌────────┴───────┐ ┌────────┴───────┐ ┌────────┴───────┐
         │ ConcreteStateA │ │ ConcreteStateB │ │ ConcreteStateC │
         └────────────────┘ └────────────────┘ └────────────────┘
This structure largely follows the Open/Closed Principle: new states arrive as new classes, and neither the context nor unrelated state classes need to change.
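As a rough sketch of how the diagram's participants translate to Python, the version below expresses the State interface with typing.Protocol, so concrete states only need matching methods rather than a shared base class. The handle()/next_state()/request() names follow the diagram; the two concrete states and their return strings are illustrative placeholders:

```python
from typing import Protocol


class State(Protocol):
    """The State interface from the diagram, expressed structurally."""

    def handle(self) -> str: ...

    def next_state(self) -> "State": ...


class ConcreteStateA:
    # No inheritance needed: matching method signatures satisfy the Protocol
    def handle(self) -> str:
        return "A handled the request"

    def next_state(self) -> State:
        return ConcreteStateB()


class ConcreteStateB:
    def handle(self) -> str:
        return "B handled the request"

    def next_state(self) -> State:
        return ConcreteStateA()


class Context:
    def __init__(self, state: State) -> None:
        self.state = state

    def request(self) -> str:
        # Delegate the behavior, then let the current state choose its successor
        result = self.state.handle()
        self.state = self.state.next_state()
        return result


ctx = Context(ConcreteStateA())
print(ctx.request())  # A handled the request
print(ctx.request())  # B handled the request
```

A Protocol is enforced only by static type checkers, while an ABC refuses to instantiate a subclass that is missing an abstract method; this article uses both styles.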
Basic Implementation in Python
Let’s build a traffic light state machine. Each light color is a state with specific behavior and a defined next state.
from abc import ABC, abstractmethod
class TrafficLightState(ABC):
    @abstractmethod
    def get_color(self) -> str:
        pass
    @abstractmethod
    def get_duration(self) -> int:
        """Duration in seconds before transitioning."""
        pass
    @abstractmethod
    def next_state(self) -> "TrafficLightState":
        pass
    def can_pedestrians_cross(self) -> bool:
        return False
class RedState(TrafficLightState):
    def get_color(self) -> str:
        return "RED"
    def get_duration(self) -> int:
        return 30
    def next_state(self) -> TrafficLightState:
        return GreenState()
    def can_pedestrians_cross(self) -> bool:
        return True
class YellowState(TrafficLightState):
    def get_color(self) -> str:
        return "YELLOW"
    def get_duration(self) -> int:
        return 5
    def next_state(self) -> TrafficLightState:
        return RedState()
class GreenState(TrafficLightState):
    def get_color(self) -> str:
        return "GREEN"
    def get_duration(self) -> int:
        return 25
    def next_state(self) -> TrafficLightState:
        return YellowState()
class TrafficLight:
    def __init__(self):
        self._state = RedState()
    @property
    def color(self) -> str:
        return self._state.get_color()
    @property
    def pedestrians_can_cross(self) -> bool:
        return self._state.can_pedestrians_cross()
    def advance(self) -> None:
        self._state = self._state.next_state()
# Usage
light = TrafficLight()
print(f"{light.color}, pedestrians: {light.pedestrians_can_cross}") # RED, True
light.advance()
print(f"{light.color}, pedestrians: {light.pedestrians_can_cross}") # GREEN, False
Each state encapsulates its own behavior. The TrafficLight context simply delegates to the current state.
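The TrafficLight context above never calls get_duration(). One way it might, sketched here with a compact copy of the three states (can_pedestrians_cross omitted), is a hypothetical schedule() helper that walks the cycle and reports when each color begins, without actually sleeping:

```python
from abc import ABC, abstractmethod


class TrafficLightState(ABC):
    @abstractmethod
    def get_color(self) -> str: ...

    @abstractmethod
    def get_duration(self) -> int:
        """Duration in seconds before transitioning."""

    @abstractmethod
    def next_state(self) -> "TrafficLightState": ...


class RedState(TrafficLightState):
    def get_color(self) -> str:
        return "RED"

    def get_duration(self) -> int:
        return 30

    def next_state(self) -> TrafficLightState:
        return GreenState()


class GreenState(TrafficLightState):
    def get_color(self) -> str:
        return "GREEN"

    def get_duration(self) -> int:
        return 25

    def next_state(self) -> TrafficLightState:
        return YellowState()


class YellowState(TrafficLightState):
    def get_color(self) -> str:
        return "YELLOW"

    def get_duration(self) -> int:
        return 5

    def next_state(self) -> TrafficLightState:
        return RedState()


class TrafficLight:
    def __init__(self) -> None:
        self._state: TrafficLightState = RedState()

    def advance(self) -> None:
        self._state = self._state.next_state()

    def schedule(self, phases: int) -> list[tuple[int, str]]:
        """Hypothetical helper: (start_second, color) for the next phases.

        Advances the light as it goes; no real waiting happens.
        """
        timeline: list[tuple[int, str]] = []
        elapsed = 0
        for _ in range(phases):
            timeline.append((elapsed, self._state.get_color()))
            elapsed += self._state.get_duration()
            self.advance()
        return timeline


print(TrafficLight().schedule(4))
# [(0, 'RED'), (30, 'GREEN'), (55, 'YELLOW'), (60, 'RED')]
```

Because each duration lives in its state class, changing a phase length touches only that class; the context's loop stays the same.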
Advanced Techniques: Pythonic Improvements
Python’s dynamic nature allows a shortcut: instead of maintaining a separate state object, you can switch the class of the context itself.
from typing import Protocol
class InvalidTransitionError(Exception):
    pass
class OrderState(Protocol):
    # The interface every Order subclass exposes (documentation only; not enforced at runtime)
    def pay(self) -> None: ...
    def ship(self) -> None: ...
    def cancel(self) -> None: ...
class Order:
    """Order that changes its own class to represent state."""
    def __init__(self, order_id: str):
        self.order_id = order_id
        self.items: list[str] = []
    # Default behavior: every action is illegal until a subclass overrides it
    def pay(self) -> None:
        raise InvalidTransitionError(f"Cannot pay in {self.__class__.__name__}")
    def ship(self) -> None:
        raise InvalidTransitionError(f"Cannot ship in {self.__class__.__name__}")
    def cancel(self) -> None:
        raise InvalidTransitionError(f"Cannot cancel in {self.__class__.__name__}")
class PendingOrder(Order):
    def pay(self) -> None:
        print(f"Order {self.order_id} paid")
        self.__class__ = PaidOrder
    def cancel(self) -> None:
        print(f"Order {self.order_id} cancelled")
        self.__class__ = CancelledOrder
class PaidOrder(Order):
    def ship(self) -> None:
        print(f"Order {self.order_id} shipped")
        self.__class__ = ShippedOrder
    def cancel(self) -> None:
        print(f"Order {self.order_id} cancelled, initiating refund")
        self.__class__ = CancelledOrder
class ShippedOrder(Order):
    pass  # Cannot pay, ship, or cancel once shipped
class CancelledOrder(Order):
    pass  # Terminal state
# Usage
order = PendingOrder("ORD-001")
order.pay() # Works, order is now PaidOrder
order.ship() # Works, order is now ShippedOrder
order.pay() # Raises InvalidTransitionError
The __class__ assignment changes the object’s type at runtime. The object keeps its instance attributes, but method lookups now resolve against the new class. This technique reduces boilerplate, but it can confuse developers unfamiliar with it, so use it judiciously.
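The technique has a limit worth knowing: Python only allows __class__ assignment between classes with compatible instance layouts. Two ordinary subclasses of the same base are fine, but a class that declares different __slots__ is rejected with TypeError. The class names below are illustrative:

```python
class Base:
    pass


class Plain(Base):
    pass  # Same instance layout as Base (instances have a __dict__)


class Slotted:
    __slots__ = ("value",)  # No __dict__: a different instance layout


obj = Base()
obj.__class__ = Plain  # Compatible layouts: allowed
print(isinstance(obj, Plain))  # True

try:
    obj.__class__ = Slotted  # Incompatible layouts: rejected
except TypeError as exc:
    print(f"TypeError: {exc}")
```

The Order subclasses above declare no __slots__, so every switch between them is legal. Note also that the assignment does not run __init__, so any attribute a new state relies on must already exist on the object.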
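For better type safety, combine enums with state classes, so each state object reports an OrderStatus member and callers compare enum values rather than strings or class names: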
from enum import Enum, auto
class OrderStatus(Enum):
    PENDING = auto()
    PAID = auto()
    SHIPPED = auto()
    DELIVERED = auto()
    CANCELLED = auto()
class Order:
    def __init__(self, order_id: str):
        self.order_id = order_id
        # OrderState (an ABC) and PendingState are defined in the full example below
        self._state: OrderState = PendingState()
    @property
    def status(self) -> OrderStatus:
        # Each state object exposes its own OrderStatus member
        return self._state.status
Real-World Example: Order Processing System
Here’s a complete e-commerce order workflow with proper error handling:
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum, auto
class OrderStatus(Enum):
    PENDING = auto()
    PAID = auto()
    SHIPPED = auto()
    DELIVERED = auto()
    CANCELLED = auto()
class InvalidTransitionError(Exception):
    def __init__(self, current: OrderStatus, action: str):
        super().__init__(f"Cannot {action} order in {current.name} state")
class OrderState(ABC):
    @property
    @abstractmethod
    def status(self) -> OrderStatus:
        pass
    # Defaults reject every action; concrete states override only the legal ones
    def pay(self, context: "OrderContext") -> None:
        raise InvalidTransitionError(self.status, "pay")
    def ship(self, context: "OrderContext", tracking: str) -> None:
        raise InvalidTransitionError(self.status, "ship")
    def deliver(self, context: "OrderContext") -> None:
        raise InvalidTransitionError(self.status, "deliver")
    def cancel(self, context: "OrderContext", reason: str) -> None:
        raise InvalidTransitionError(self.status, "cancel")
class PendingState(OrderState):
    @property
    def status(self) -> OrderStatus:
        return OrderStatus.PENDING
    def pay(self, context: "OrderContext") -> None:
        context.paid_at = datetime.now()
        context.transition_to(PaidState())
    def cancel(self, context: "OrderContext", reason: str) -> None:
        context.cancellation_reason = reason
        context.transition_to(CancelledState())
class PaidState(OrderState):
    @property
    def status(self) -> OrderStatus:
        return OrderStatus.PAID
    def ship(self, context: "OrderContext", tracking: str) -> None:
        context.tracking_number = tracking
        context.shipped_at = datetime.now()
        context.transition_to(ShippedState())
    def cancel(self, context: "OrderContext", reason: str) -> None:
        context.cancellation_reason = reason
        context.requires_refund = True
        context.transition_to(CancelledState())
class ShippedState(OrderState):
    @property
    def status(self) -> OrderStatus:
        return OrderStatus.SHIPPED
    def deliver(self, context: "OrderContext") -> None:
        context.delivered_at = datetime.now()
        context.transition_to(DeliveredState())
class DeliveredState(OrderState):
    @property
    def status(self) -> OrderStatus:
        return OrderStatus.DELIVERED
class CancelledState(OrderState):
    @property
    def status(self) -> OrderStatus:
        return OrderStatus.CANCELLED
@dataclass
class OrderContext:
    order_id: str
    items: list[str] = field(default_factory=list)
    _state: OrderState = field(default_factory=PendingState)  # Every order starts pending
    # State-specific data (the `X | None` annotations need Python 3.10+)
    paid_at: datetime | None = None
    shipped_at: datetime | None = None
    delivered_at: datetime | None = None
    tracking_number: str | None = None
    cancellation_reason: str | None = None
    requires_refund: bool = False
    @property
    def status(self) -> OrderStatus:
        return self._state.status
    def transition_to(self, state: OrderState) -> None:
        # Single choke point for every state change, handy for logging or auditing
        print(f"Order {self.order_id}: {self._state.status.name} -> {state.status.name}")
        self._state = state
    def pay(self) -> None:
        self._state.pay(self)
    def ship(self, tracking: str) -> None:
        self._state.ship(self, tracking)
    def deliver(self) -> None:
        self._state.deliver(self)
    def cancel(self, reason: str) -> None:
        self._state.cancel(self, reason)
Testing State Machines
State machines are straightforward to test. Test each state’s behavior in isolation, then test transitions:
import pytest
# OrderContext, OrderStatus and InvalidTransitionError come from the order-processing example above
@pytest.fixture
def pending_order() -> OrderContext:
    return OrderContext(order_id="TEST-001", items=["Widget"])
@pytest.fixture
def paid_order(pending_order: OrderContext) -> OrderContext:
    pending_order.pay()
    return pending_order
class TestPendingState:
    def test_can_pay(self, pending_order: OrderContext):
        pending_order.pay()
        assert pending_order.status == OrderStatus.PAID
        assert pending_order.paid_at is not None
    def test_can_cancel(self, pending_order: OrderContext):
        pending_order.cancel("Changed mind")
        assert pending_order.status == OrderStatus.CANCELLED
        assert pending_order.requires_refund is False
    def test_cannot_ship(self, pending_order: OrderContext):
        with pytest.raises(InvalidTransitionError):
            pending_order.ship("TRACK123")
class TestPaidState:
    def test_can_ship(self, paid_order: OrderContext):
        paid_order.ship("TRACK123")
        assert paid_order.status == OrderStatus.SHIPPED
        assert paid_order.tracking_number == "TRACK123"
    def test_cancel_requires_refund(self, paid_order: OrderContext):
        paid_order.cancel("Out of stock")
        assert paid_order.status == OrderStatus.CANCELLED
        assert paid_order.requires_refund is True
class TestFullWorkflow:
    def test_happy_path(self, pending_order: OrderContext):
        pending_order.pay()
        pending_order.ship("TRACK123")
        pending_order.deliver()
        assert pending_order.status == OrderStatus.DELIVERED
State Pattern vs. Alternatives
The State pattern isn’t always the right choice.
Simple enum with transition table: For straightforward state machines without complex per-state behavior, a dictionary mapping (current_state, action) to next_state is simpler and more explicit.
TRANSITIONS = {
    (OrderStatus.PENDING, "pay"): OrderStatus.PAID,
    (OrderStatus.PAID, "ship"): OrderStatus.SHIPPED,
    # ...
}
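The table above needs only a lookup function to become a working state machine. A minimal sketch, assuming a hypothetical fire() helper and a table of its own whose entries mirror the OrderContext transitions from the previous section:

```python
from enum import Enum, auto


class OrderStatus(Enum):
    PENDING = auto()
    PAID = auto()
    SHIPPED = auto()
    DELIVERED = auto()
    CANCELLED = auto()


# (current_status, action) -> next_status; any pair not listed is illegal
TRANSITIONS: dict[tuple[OrderStatus, str], OrderStatus] = {
    (OrderStatus.PENDING, "pay"): OrderStatus.PAID,
    (OrderStatus.PENDING, "cancel"): OrderStatus.CANCELLED,
    (OrderStatus.PAID, "ship"): OrderStatus.SHIPPED,
    (OrderStatus.PAID, "cancel"): OrderStatus.CANCELLED,
    (OrderStatus.SHIPPED, "deliver"): OrderStatus.DELIVERED,
}


def fire(current: OrderStatus, action: str) -> OrderStatus:
    """Return the next status, or raise ValueError if the action is not allowed."""
    try:
        return TRANSITIONS[(current, action)]
    except KeyError:
        raise ValueError(f"Cannot {action} order in {current.name} state") from None


status = OrderStatus.PENDING
for action in ("pay", "ship", "deliver"):
    status = fire(status, action)
print(status.name)  # DELIVERED
```

The whole machine is visible in one data structure, which makes it easy to audit; what it cannot do without extra code is attach different behavior (timestamps, refunds) to each state, which is where the State pattern earns its keep.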
The transitions library: For complex state machines with guards, callbacks, and nested states, use a dedicated library rather than hand-rolling everything.
from transitions import Machine
machine = Machine(
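    states=["pending", "paid", "shipped"],
    transitions=[
        {"trigger": "pay", "source": "pending", "dest": "paid"},
        {"trigger": "ship", "source": "paid", "dest": "shipped"},
    ],
    initial="pending",
)
machine.pay()  # With no model argument, the machine itself acts as the model
print(machine.state)  # paid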
Use the State pattern when:
- States have significantly different behavior, not just different allowed transitions
- You expect to add new states over time
- State-specific logic is complex enough to warrant its own class
Skip it when:
- You have fewer than four states with simple transitions
- State behavior is mostly identical across states
- You need features like hierarchical states or history, which a dedicated library is better suited to
The State pattern trades simplicity for extensibility. Make sure you need that extensibility before paying the complexity cost.