System Design: Consistent Hashing for Distributed Caches

Key Insights

  • Consistent hashing limits key remapping to K/N keys on average when nodes change, compared to traditional modulo hashing which remaps nearly all keys—this is the difference between a minor cache miss spike and a complete cache stampede.
  • Virtual nodes are non-negotiable in production: without them, one node can easily own several times as much of the key space as another, and there is no clean way to give a more powerful server a proportionally larger share without manual rebalancing.
  • The hash function matters less than you think (MurmurHash3, xxHash, and even MD5 all work), but your data structure choice directly impacts lookup latency: a sorted array with binary search is simpler and more cache-friendly than a tree.

The Problem with Traditional Hashing

When engineers first build a distributed cache, they reach for the obvious solution: hash the key and modulo by the number of nodes. It’s simple, it’s fast, and it works—until you need to add or remove a cache server.

import hashlib

def simple_hash(key: str, num_nodes: int) -> int:
    # Built-in hash() on str is randomized per process (PYTHONHASHSEED),
    # so use a stable hash that every client and every restart agrees on.
    digest = hashlib.md5(key.encode()).digest()
    return int.from_bytes(digest[:4], 'big') % num_nodes

keys = [f"user:{i}" for i in range(10_000)]

# With 3 nodes
nodes_3 = [simple_hash(k, 3) for k in keys]

# Add one node - watch the chaos
nodes_4 = [simple_hash(k, 4) for k in keys]

# Count how many keys moved
moved = sum(1 for a, b in zip(nodes_3, nodes_4) if a != b)
print(f"Keys remapped: {moved}/{len(keys)}")  # Roughly 75% of keys

This is the rehashing catastrophe. Adding a fourth node to a three-node cluster remaps roughly 75% of all keys. In production, this means 75% of your cache becomes instantly useless. Every request that would have been a cache hit now becomes a cache miss, hammering your database with a thundering herd of requests.
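The 75% figure falls out of modular arithmetic: a key stays put only when h % 3 == h % 4, and the two residues cycle together every lcm(3, 4) = 12 hash values. This short sketch counts the survivors over one full cycle. The helper name fraction_remapped is ours, and it assumes hash values are uniformly distributed.

```python
from math import lcm


def fraction_remapped(old_nodes: int, new_nodes: int) -> float:
    """Exact fraction of uniformly distributed hashes whose node index changes
    when hash % old_nodes becomes hash % new_nodes."""
    # h % old_nodes and h % new_nodes repeat jointly every lcm(old, new) values,
    # so one full period gives the exact long-run fraction.
    period = lcm(old_nodes, new_nodes)
    moved = sum(1 for h in range(period) if h % old_nodes != h % new_nodes)
    return moved / period


print(fraction_remapped(3, 4))    # 0.75
print(fraction_remapped(10, 11))  # 100/110, about 0.909
```

More generally, going from N to N+1 nodes only N of every N(N+1) residues line up, so the remapped fraction is N/(N+1): the larger the cluster, the closer a single-node change gets to remapping everything.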

I’ve seen this take down production systems. A well-intentioned ops engineer adds a cache node during peak traffic to handle load, and instead of helping, it triggers a cascade of database timeouts. The fix is consistent hashing.

Consistent Hashing Fundamentals

Consistent hashing reimagines key distribution as a ring. Instead of mapping keys to node indices, we map both keys and nodes onto a circular hash space (typically 0 to 2^32-1). To find which node owns a key, we hash the key and walk clockwise until we hit a node.

import hashlib
from bisect import bisect_right, insort
from typing import Optional

class BasicHashRing:
    def __init__(self):
        self.ring: list[int] = []  # Sorted positions on ring
        self.nodes: dict[int, str] = {}  # Position -> node name
    
    def _hash(self, key: str) -> int:
        """Hash a string to a position on the ring (0 to 2^32-1)."""
        digest = hashlib.md5(key.encode()).digest()
        return int.from_bytes(digest[:4], 'big')
    
    def add_node(self, node: str) -> None:
        pos = self._hash(node)
        if pos not in self.nodes:
            insort(self.ring, pos)  # Keeps the list sorted without a full re-sort
            self.nodes[pos] = node
    
    def remove_node(self, node: str) -> None:
        pos = self._hash(node)
        if self.nodes.get(pos) == node:
            self.ring.remove(pos)
            del self.nodes[pos]
    
    def get_node(self, key: str) -> Optional[str]:
        if not self.ring:
            return None
        
        pos = self._hash(key)
        # Clockwise walk: first node position strictly after the key's position
        idx = bisect_right(self.ring, pos)
        # Wrap around if we've gone past the last node
        if idx == len(self.ring):
            idx = 0
        return self.nodes[self.ring[idx]]

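The magic happens during topology changes. When you add a node, it lands at one position on the ring and takes over only the keys between its counter-clockwise neighbor and itself, all of which previously belonged to the next node clockwise. When you remove a node, only its keys move, and all of them go to the next node clockwise. On average about K/N keys move (where K is the total number of keys and N is the node count), compared with about K·N/(N+1) keys when modulo hashing grows from N to N+1 nodes. K/N is just an average: with a single position per node, the arc a particular node claims can be much larger or smaller.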
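To see that property directly, the sketch below builds two single-position rings with the same MD5 placement as BasicHashRing, adds a fourth node, and checks that every key that changed owner landed on the new node. The helper functions, node names, and key format are illustrative, not part of the class above.

```python
import hashlib
from bisect import bisect_right


def position(label: str) -> int:
    """Same placement as BasicHashRing: first 4 bytes of MD5 as a 32-bit position."""
    return int.from_bytes(hashlib.md5(label.encode()).digest()[:4], "big")


def build_ring(nodes: list[str]) -> tuple[list[int], list[str]]:
    """Sorted node positions plus the node name at each position."""
    pairs = sorted((position(n), n) for n in nodes)
    return [p for p, _ in pairs], [n for _, n in pairs]


def owner(ring: tuple[list[int], list[str]], key: str) -> str:
    """Clockwise walk: first node position after the key's position, wrapping."""
    positions, names = ring
    idx = bisect_right(positions, position(key)) % len(positions)
    return names[idx]


keys = [f"user:{i}" for i in range(10_000)]
before = build_ring(["cache-a", "cache-b", "cache-c"])
after = build_ring(["cache-a", "cache-b", "cache-c", "cache-d"])

moved = [k for k in keys if owner(before, k) != owner(after, k)]
# Keys only move onto the new node; nothing shuffles between existing nodes.
assert all(owner(after, k) == "cache-d" for k in moved)
print(f"Keys remapped: {len(moved)}/{len(keys)}")
```

How many keys move depends entirely on the size of the arc cache-d happens to claim, which is the unevenness virtual nodes address.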

Virtual Nodes for Load Balancing

The basic ring has a fatal flaw: with only a few nodes, distribution is wildly uneven. Hash functions don’t guarantee uniform spacing—you might have three nodes clustered in one quadrant of the ring, leaving another quadrant empty.

Virtual nodes solve this by giving each physical node multiple positions on the ring. Instead of hashing “cache-server-1” once, we hash “cache-server-1-vnode-0”, “cache-server-1-vnode-1”, through “cache-server-1-vnode-149” to get 150 positions.

import hashlib
from bisect import bisect_right
from typing import Optional

class VNodeHashRing:
    def __init__(self, default_vnodes: int = 150):
        self.default_vnodes = default_vnodes
        self.ring: list[int] = []  # Sorted vnode positions
        self.positions: dict[int, str] = {}  # Position -> physical node
        self.node_vnodes: dict[str, int] = {}  # Node -> vnode count
    
    def _hash(self, key: str) -> int:
        digest = hashlib.md5(key.encode()).digest()
        return int.from_bytes(digest[:4], 'big')
    
    def add_node(self, node: str, vnodes: Optional[int] = None) -> None:
        """Add a node with configurable virtual node count."""
        if node in self.node_vnodes:
            return  # Already on the ring
        # Explicit None check so vnodes=0 isn't silently replaced by the default
        vnode_count = self.default_vnodes if vnodes is None else vnodes
        self.node_vnodes[node] = vnode_count
        
        for i in range(vnode_count):
            vnode_key = f"{node}-vnode-{i}"
            pos = self._hash(vnode_key)
            if pos not in self.positions:  # On collision, the existing owner keeps it
                self.ring.append(pos)
                self.positions[pos] = node
        
        self.ring.sort()
    
    def remove_node(self, node: str) -> None:
        """Remove all virtual nodes for a physical node."""
        if node not in self.node_vnodes:
            return
        
        vnode_count = self.node_vnodes.pop(node)
        for i in range(vnode_count):
            vnode_key = f"{node}-vnode-{i}"
            pos = self._hash(vnode_key)
            # Only delete positions this node owns, never another node's collision winner
            if self.positions.get(pos) == node:
                del self.positions[pos]
        # Rebuild once instead of an O(n) list.remove() per vnode
        self.ring = sorted(self.positions)
    
    def get_node(self, key: str) -> Optional[str]:
        if not self.ring:
            return None
        
        pos = self._hash(key)
        idx = bisect_right(self.ring, pos) % len(self.ring)  # Wrap past the last position
        return self.positions[self.ring[idx]]

Virtual nodes also enable heterogeneous clusters. Got a new server with twice the RAM? Give it twice the virtual nodes. Decommissioning an old machine gradually? Reduce its vnode count over time instead of removing it all at once.
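Both ideas fall out of deterministic vnode labels: because vnode i of a node always hashes to the same position, growing a node adds positions and shrinking it removes only its highest-numbered ones, leaving its other arcs untouched. The WeightedRing class and its set_vnodes method are a hypothetical sketch (not part of VNodeHashRing above) using the same MD5 placement and "-vnode-" labels.

```python
import hashlib
from bisect import bisect_right
from collections import Counter


def position(label: str) -> int:
    return int.from_bytes(hashlib.md5(label.encode()).digest()[:4], "big")


class WeightedRing:
    """Hypothetical ring whose per-node vnode count can be changed in place."""

    def __init__(self) -> None:
        self.owner: dict[int, str] = {}   # position -> physical node
        self.counts: dict[str, int] = {}  # physical node -> current vnode count
        self.ring: list[int] = []         # sorted positions

    def set_vnodes(self, node: str, count: int) -> None:
        current = self.counts.get(node, 0)
        # Shrink: drop vnodes count..current-1, but only positions this node owns.
        for i in range(count, current):
            pos = position(f"{node}-vnode-{i}")
            if self.owner.get(pos) == node:
                del self.owner[pos]
        # Grow: add vnodes current..count-1; on collision the existing owner wins.
        for i in range(current, count):
            self.owner.setdefault(position(f"{node}-vnode-{i}"), node)
        if count > 0:
            self.counts[node] = count
        else:
            self.counts.pop(node, None)
        self.ring = sorted(self.owner)

    def get_node(self, key: str) -> str:
        # Assumes at least one vnode is on the ring.
        idx = bisect_right(self.ring, position(key)) % len(self.ring)
        return self.owner[self.ring[idx]]


ring = WeightedRing()
ring.set_vnodes("small-1", 150)
ring.set_vnodes("small-2", 150)
ring.set_vnodes("big-1", 300)  # twice the RAM, twice the vnodes
keys = [f"user:{i}" for i in range(20_000)]
load = Counter(ring.get_node(k) for k in keys)
print(load)

# Gradual decommission: halve small-1's vnodes instead of removing it outright.
before = {k: ring.get_node(k) for k in keys}
ring.set_vnodes("small-1", 75)
moved = [k for k in keys if ring.get_node(k) != before[k]]
assert all(before[k] == "small-1" for k in moved)  # only small-1's keys move
```

Repeating the halving step over a few maintenance windows drains a machine in stages, each causing only a fraction of the miss spike a full removal would.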

The vnode count is a tuning parameter. Too few (under 50) and distribution remains uneven. Too many (over 500) and you’re wasting memory on ring metadata. I’ve found 100-200 vnodes per physical node hits the sweet spot for most workloads.
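One way to choose a value for your own cluster is to measure it. The sketch below computes the exact fraction of the 32-bit hash space each node owns (the sum of the arcs ending at its positions) and reports how far the busiest node sits above its fair share for several vnode counts. The helper names are ours, and it assumes the same MD5 placement and "-vnode-" labels as VNodeHashRing. It measures key-space ownership, not request load, and its output depends on your node names.

```python
import hashlib
from collections import defaultdict

RING_SIZE = 2**32


def ring_position(label: str) -> int:
    """First 4 bytes of MD5 as a 32-bit ring position."""
    return int.from_bytes(hashlib.md5(label.encode()).digest()[:4], "big")


def ring_shares(nodes: list[str], vnodes: int) -> dict[str, float]:
    """Fraction of the hash space owned by each node for a given vnode count."""
    owner: dict[int, str] = {}
    for node in nodes:
        for i in range(vnodes):
            owner.setdefault(ring_position(f"{node}-vnode-{i}"), node)
    positions = sorted(owner)
    shares: dict[str, float] = defaultdict(float)
    for i, pos in enumerate(positions):
        # Each position owns the arc back to its predecessor, wrapping at 0.
        prev = positions[i - 1] if i > 0 else positions[-1] - RING_SIZE
        shares[owner[pos]] += (pos - prev) / RING_SIZE
    return dict(shares)


def imbalance(nodes: list[str], vnodes: int) -> float:
    """Busiest node's share divided by the ideal share 1/len(nodes)."""
    return max(ring_shares(nodes, vnodes).values()) * len(nodes)


nodes = [f"cache-server-{i}" for i in range(5)]
for v in (1, 10, 50, 150, 500):
    print(f"{v:>4} vnodes/node: busiest node at {imbalance(nodes, v):.2f}x its fair share")
```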

Implementation Walkthrough

Let’s build a production-ready implementation with proper abstractions and efficient operations.

from bisect import bisect_right
from dataclasses import dataclass, field
from typing import Optional
import mmh3  # MurmurHash3 - pip install mmh3

@dataclass
class ConsistentHashRing:
    """Production-ready consistent hash ring with virtual nodes."""
    
    vnodes_per_node: int = 150
    # Internal state; init=False keeps it out of the generated __init__
    _ring: list[int] = field(default_factory=list, init=False, repr=False)
    _position_to_node: dict[int, str] = field(default_factory=dict, init=False, repr=False)
    _nodes: set[str] = field(default_factory=set, init=False)
    
    def _hash(self, key: str) -> int:
        """32-bit unsigned MurmurHash3: non-cryptographic, fast, good distribution."""
        return mmh3.hash(key, signed=False)
    
    def add_node(self, node: str, weight: int = 1) -> int:
        """Add node with optional weight multiplier. Returns vnodes placed."""
        if node in self._nodes:
            return 0
        if weight < 1:
            raise ValueError("weight must be at least 1")
        
        self._nodes.add(node)
        vnodes = self.vnodes_per_node * weight
        
        for i in range(vnodes):
            pos = self._hash(f"{node}:{i}")
            # Handle hash collisions (rare but possible) by probing forward
            while pos in self._position_to_node:
                pos = (pos + 1) % (2**32)
            self._position_to_node[pos] = node
        
        self._ring = sorted(self._position_to_node.keys())
        return vnodes
    
    def remove_node(self, node: str) -> int:
        """Remove node and all its virtual nodes. Returns vnodes removed."""
        if node not in self._nodes:
            return 0
        
        self._nodes.remove(node)
        positions_to_remove = [
            pos for pos, n in self._position_to_node.items() if n == node
        ]
        
        for pos in positions_to_remove:
            del self._position_to_node[pos]
        
        self._ring = sorted(self._position_to_node.keys())
        return len(positions_to_remove)
    
    def get_node(self, key: str) -> Optional[str]:
        """O(log n) lookup using binary search."""
        if not self._ring:
            return None
        
        pos = self._hash(key)
        idx = bisect_right(self._ring, pos)
        if idx == len(self._ring):
            idx = 0
        return self._position_to_node[self._ring[idx]]
    
    @property
    def node_count(self) -> int:
        return len(self._nodes)

I chose MurmurHash3 over MD5 because it is a non-cryptographic hash that runs considerably faster while distributing keys just as evenly. For the ring structure, a sorted list with binary search gives O(log n) lookups and an O(n log n) rebuild on each topology change, which is acceptable since node additions are rare compared to key lookups. Skip lists or balanced trees work too, but the sorted array is simpler and cache-friendly.

Replication and Fault Tolerance

A cache without replication is a cache that will lose data. The hash ring naturally supports replica placement: instead of returning one node, walk clockwise and return the first R distinct physical nodes, where R is the replication factor.

# Add to ConsistentHashRing as a method
def get_replica_nodes(self, key: str, replica_count: int = 3) -> list[str]:
    """Get up to replica_count distinct physical nodes, in clockwise order."""
    if not self._ring or replica_count < 1:
        return []
    # Can't place more replicas than there are physical nodes
    replica_count = min(replica_count, len(self._nodes))
    
    replicas: list[str] = []
    seen_nodes: set[str] = set()
    
    pos = self._hash(key)
    start_idx = bisect_right(self._ring, pos) % len(self._ring)
    
    idx = start_idx
    while len(replicas) < replica_count:
        node = self._position_to_node[self._ring[idx]]
        if node not in seen_nodes:
            replicas.append(node)
            seen_nodes.add(node)
        idx = (idx + 1) % len(self._ring)
        
        # Safety check: we've walked the whole ring
        if idx == start_idx:
            break
    
    return replicas

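When a node fails and is removed from the ring (by a health checker or membership service; the ring itself doesn’t detect failures), each of its keys falls to the next distinct node clockwise, which already holds a replica if you’re using this pattern. Because the failed node’s virtual nodes are scattered around the ring, its load spreads across many surviving nodes instead of landing on a single neighbor.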
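The failover behavior can be checked mechanically. This sketch uses standalone helpers with our own names (same MD5 placement, 150 vnodes per node). It records each key's three replicas, rebuilds the ring without a failed node as a membership service would, and asserts that each key's new primary is either its unchanged primary or, if the primary was the failed node, the replica that came next in its old list.

```python
import hashlib
from bisect import bisect_right


def position(label: str) -> int:
    return int.from_bytes(hashlib.md5(label.encode()).digest()[:4], "big")


def build_ring(nodes: list[str], vnodes: int = 150) -> tuple[list[int], dict[int, str]]:
    owner: dict[int, str] = {}
    for node in nodes:
        for i in range(vnodes):
            owner.setdefault(position(f"{node}-vnode-{i}"), node)
    return sorted(owner), owner


def replicas(ring: tuple[list[int], dict[int, str]], key: str, count: int) -> list[str]:
    """First `count` distinct physical nodes clockwise from the key."""
    positions, owner = ring
    count = min(count, len(set(owner.values())))
    if count < 1:
        return []
    start = bisect_right(positions, position(key))
    result: list[str] = []
    for step in range(len(positions)):
        node = owner[positions[(start + step) % len(positions)]]
        if node not in result:
            result.append(node)
            if len(result) == count:
                break
    return result


nodes = [f"cache-{i}" for i in range(5)]
keys = [f"user:{i}" for i in range(5_000)]
healthy = build_ring(nodes)
before = {k: replicas(healthy, k, 3) for k in keys}

# Simulate a health checker evicting cache-2 from the ring.
degraded = build_ring([n for n in nodes if n != "cache-2"])
for k in keys:
    old = before[k]
    expected = old[1] if old[0] == "cache-2" else old[0]
    assert replicas(degraded, k, 1)[0] == expected
```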

Real-World Applications

Cassandra uses consistent hashing as its core partitioning strategy, with vnodes enabled by default (earlier releases defaulted to 256 tokens per node; Cassandra 4.0 lowered the default num_tokens to 16). Amazon’s Dynamo, the design that inspired DynamoDB, uses a variant where the ring is divided into fixed, equal-sized partitions that get assigned to nodes, simplifying rebalancing. Memcached clients built on the ketama algorithm (originally Last.fm’s libketama library) implement client-side consistent hashing to distribute keys across a cluster.

Jump consistent hash is worth knowing: it spreads keys almost perfectly evenly and needs no memory beyond the bucket count, but it maps keys to buckets numbered 0 to N-1, so you can only add or remove buckets at the end. Rendezvous hashing (highest random weight) is another alternative that avoids the ring entirely by computing a score for each node and picking the highest; the trade-off is an O(N) scan of every node on each lookup.
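Both alternatives are short enough to sketch. jump_consistent_hash follows the algorithm published by John Lamping and Eric Veach (a 64-bit linear congruential step followed by a floating-point jump), and rendezvous_node is a plain highest-random-weight lookup. Deriving the 64-bit key from MD5 and scoring candidates with MD5 of "node:key" are our own illustrative choices.

```python
import hashlib


def key64(key: str) -> int:
    """Illustrative 64-bit integer key derived from MD5."""
    return int.from_bytes(hashlib.md5(key.encode()).digest()[:8], "big")


def jump_consistent_hash(key: int, num_buckets: int) -> int:
    """Jump consistent hash: maps a 64-bit key to a bucket in [0, num_buckets).
    Requires num_buckets >= 1."""
    b, j = -1, 0
    while j < num_buckets:
        b = j
        key = (key * 2862933555777941757 + 1) & 0xFFFFFFFFFFFFFFFF  # 64-bit LCG step
        j = int((b + 1) * ((1 << 31) / ((key >> 33) + 1)))
    return b


def rendezvous_node(key: str, nodes: list[str]) -> str:
    """Highest random weight: score every node for this key and keep the best."""
    def score(node: str) -> int:
        return int.from_bytes(hashlib.md5(f"{node}:{key}".encode()).digest()[:8], "big")
    return max(nodes, key=score)


k = key64("user:42")
print(jump_consistent_hash(k, 10), jump_consistent_hash(k, 11))
print(rendezvous_node("user:42", ["cache-a", "cache-b", "cache-c"]))
```

Going from 10 to 11 buckets, jump hash either leaves a key where it was or moves it to the new bucket 10, the same minimal-movement property the ring provides. With rendezvous hashing, removing a node that isn't a key's winner never changes that key's placement.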

Testing and Operational Considerations

Before deploying, validate that your implementation actually distributes keys evenly.

from collections import Counter

def test_distribution_uniformity(ring: ConsistentHashRing, num_keys: int = 100_000):
    """Statistical test for even key distribution."""
    distribution = Counter(
        ring.get_node(f"key-{i}") for i in range(num_keys)
    )
    
    expected = num_keys / ring.node_count
    # Iterate over every node so one that received zero keys still counts
    max_deviation = max(
        abs(distribution[node] - expected) / expected
        for node in ring._nodes
    )
    
    print(f"Keys per node: {dict(distribution)}")
    print(f"Expected per node: {expected:.0f}")
    print(f"Max deviation: {max_deviation:.1%}")
    
    # Rule-of-thumb threshold for 150 vnodes per node; adjust it with the vnode count
    assert max_deviation < 0.15, f"Distribution too uneven: {max_deviation:.1%}"

# Test it
ring = ConsistentHashRing(vnodes_per_node=150)
for i in range(5):
    ring.add_node(f"cache-{i}")
test_distribution_uniformity(ring)

In production, monitor the actual request distribution across nodes, not just theoretical key distribution. Hot keys (that one celebrity’s profile everyone’s viewing) will create hotspots regardless of your hashing strategy. Consider adding a small in-memory LRU in front of your distributed cache for these cases.
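A minimal sketch of that local layer: an OrderedDict kept in recency order, with a short TTL so hot entries can’t serve stale data for long. LocalLRU and fetch_from_cluster are hypothetical names; the fetch callable stands in for a real read from the distributed cache. For values that never change, functools.lru_cache would do the same job with less code.

```python
import time
from collections import OrderedDict
from typing import Any, Callable


class LocalLRU:
    """Hypothetical per-process LRU with a TTL in front of a distributed cache."""

    def __init__(self, capacity: int, ttl_seconds: float, fetch: Callable[[str], Any]) -> None:
        self.capacity = capacity
        self.ttl = ttl_seconds
        self.fetch = fetch  # reads from the cache cluster on a local miss
        self._data: OrderedDict[str, tuple[float, Any]] = OrderedDict()  # key -> (expires_at, value)

    def get(self, key: str) -> Any:
        now = time.monotonic()
        entry = self._data.get(key)
        if entry is not None and entry[0] > now:
            self._data.move_to_end(key)  # hit: mark as most recently used
            return entry[1]
        value = self.fetch(key)  # miss or expired: go to the distributed cache
        self._data[key] = (now + self.ttl, value)
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict the least recently used entry
        return value


calls: list[str] = []


def fetch_from_cluster(key: str) -> str:
    """Stand-in for a real distributed-cache read (hypothetical)."""
    calls.append(key)
    return f"value-of-{key}"


local = LocalLRU(capacity=2, ttl_seconds=5.0, fetch=fetch_from_cluster)
local.get("celebrity:1")
local.get("celebrity:1")  # served locally, no second fetch
print(calls)  # ['celebrity:1']
```

The TTL bounds staleness: a locally cached value can lag the distributed cache by up to ttl_seconds.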

When adding nodes, do it during low-traffic periods and monitor your database load. Even with consistent hashing, you’ll see a temporary spike in cache misses. Some teams pre-warm new nodes by copying data from neighbors before adding them to the ring—a worthwhile optimization for large clusters.
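With vnodes, a joining node takes over many small arcs, each from its own clockwise successor, so copying from neighbors means copying several hash ranges from several donors. A hypothetical planning helper can list those ranges before the node goes live. The name prewarm_plan is ours, and the sketch assumes the MD5 placement, "-vnode-" labels, and bisect_right lookups used earlier.

```python
import hashlib
from bisect import bisect_left, bisect_right


def position(label: str) -> int:
    return int.from_bytes(hashlib.md5(label.encode()).digest()[:4], "big")


def vnode_positions(node: str, vnodes: int = 150) -> list[int]:
    return [position(f"{node}-vnode-{i}") for i in range(vnodes)]


def prewarm_plan(old_owner: dict[int, str], new_node: str, vnodes: int = 150) -> list[tuple[int, int, str]]:
    """(start, end, donor) triples: once new_node joins, hashes in [start, end)
    move to it from donor. A range with start > end wraps past 2**32 - 1."""
    old_positions = sorted(old_owner)
    # Positions that collide with existing vnodes stay with their current owner.
    new_positions = sorted(set(vnode_positions(new_node, vnodes)) - set(old_owner))
    combined = sorted(set(old_positions) | set(new_positions))
    plan = []
    for end in new_positions:
        start = combined[bisect_left(combined, end) - 1]  # predecessor; index -1 wraps
        successor = old_positions[bisect_right(old_positions, end) % len(old_positions)]
        plan.append((start, end, old_owner[successor]))
    return plan


old_owner: dict[int, str] = {}
for node in ["cache-0", "cache-1", "cache-2"]:
    for pos in vnode_positions(node):
        old_owner.setdefault(pos, node)

plan = prewarm_plan(old_owner, "cache-3")
donors = {donor for _, _, donor in plan}
print(f"{len(plan)} ranges to copy from {sorted(donors)}")
```

Adjacent ranges with the same donor can be merged before copying; the plan lists one range per new vnode only for clarity.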
