Design a Key-Value Store: Distributed NoSQL Database
Key Insights
- Consistent hashing with virtual nodes distributes data evenly across nodes while minimizing data movement during cluster scaling—aim for 100-200 virtual nodes per physical node for balanced distribution.
- Quorum-based replication (W + R > N) lets you tune consistency vs. availability per operation; use W=1, R=1 for speed or W=N, R=1 for strong consistency.
- LSM-trees optimize for write-heavy workloads by buffering writes in memory and flushing sorted data to disk, trading read latency for write throughput.
Introduction & Requirements
A distributed key-value store is the backbone of modern infrastructure. From caching layers to session storage to configuration management, these systems handle billions of operations daily at companies like Amazon (DynamoDB), Meta (RocksDB), and countless startups.
Building one teaches you fundamental distributed systems concepts: partitioning, replication, consistency, and failure handling. Let’s design a production-grade key-value store from scratch.
Functional Requirements:
- CRUD operations: get(key), put(key, value), delete(key)
- TTL support for automatic expiration
- Range queries on keys (optional, adds complexity)
Non-Functional Requirements:
- High availability: survive node failures without downtime
- Horizontal scalability: add nodes to handle more data/traffic
- Low latency: sub-millisecond reads, single-digit millisecond writes
- Tunable consistency: eventual or strong, per-operation
We’re targeting a system that handles 100K+ operations per second across a cluster of commodity machines.
Data Partitioning with Consistent Hashing
Naive modulo hashing (hash(key) % num_nodes) breaks catastrophically when nodes change—nearly all keys get remapped. Consistent hashing solves this by mapping both keys and nodes onto a ring, where each key belongs to the first node clockwise from its position.
Virtual nodes (vnodes) solve the uneven distribution problem. Instead of one position per physical node, we create 100-200 virtual positions per node, spreading the load evenly.
```python
import hashlib
from bisect import bisect_right
from typing import Dict, List, Optional

class ConsistentHashRing:
    def __init__(self, virtual_nodes: int = 150):
        self.virtual_nodes = virtual_nodes
        self.ring: List[int] = []               # Sorted vnode hash positions
        self.ring_to_node: Dict[int, str] = {}  # Hash -> physical node
        self.nodes: set = set()

    def _hash(self, key: str) -> int:
        return int(hashlib.sha256(key.encode()).hexdigest(), 16)

    def add_node(self, node: str) -> List[str]:
        """Add a node; returns keys that need migration."""
        self.nodes.add(node)
        for i in range(self.virtual_nodes):
            vnode_key = f"{node}:vnode:{i}"
            hash_val = self._hash(vnode_key)
            self.ring_to_node[hash_val] = node
            self.ring.append(hash_val)
        self.ring.sort()
        return []  # In practice, return affected key ranges

    def remove_node(self, node: str) -> None:
        self.nodes.discard(node)
        to_remove = [h for h, n in self.ring_to_node.items() if n == node]
        for h in to_remove:
            del self.ring_to_node[h]
            self.ring.remove(h)

    def get_node(self, key: str) -> Optional[str]:
        if not self.ring:
            return None
        hash_val = self._hash(key)
        idx = bisect_right(self.ring, hash_val) % len(self.ring)
        return self.ring_to_node[self.ring[idx]]

    def get_replica_nodes(self, key: str, n: int = 3) -> List[str]:
        """Get N distinct physical nodes for replication."""
        if len(self.nodes) < n:
            return list(self.nodes)
        hash_val = self._hash(key)
        idx = bisect_right(self.ring, hash_val) % len(self.ring)
        replicas = []
        seen = set()
        while len(replicas) < n:
            node = self.ring_to_node[self.ring[idx]]
            if node not in seen:
                replicas.append(node)
                seen.add(node)
            idx = (idx + 1) % len(self.ring)
        return replicas
```
When a node joins or leaves, only keys in adjacent ring segments migrate—roughly 1/N of total keys instead of nearly all.
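That migration property is easy to check empirically. A small self-contained sketch (separate from the class above; node and key names are illustrative) builds rings of three and four nodes and counts how many keys change owners:

```python
import hashlib
from bisect import bisect_right

def h(s: str) -> int:
    return int(hashlib.sha256(s.encode()).hexdigest(), 16)

def build_ring(nodes, vnodes=150):
    """Sorted (hash, node) pairs: vnodes ring positions per physical node."""
    return sorted((h(f"{n}:vnode:{i}"), n) for n in nodes for i in range(vnodes))

def owner(ring, key):
    positions = [pos for pos, _ in ring]
    idx = bisect_right(positions, h(key)) % len(ring)
    return ring[idx][1]

keys = [f"user:{i}" for i in range(5000)]
ring3 = build_ring(["node-a", "node-b", "node-c"])
ring4 = build_ring(["node-a", "node-b", "node-c", "node-d"])
moved = sum(1 for k in keys if owner(ring3, k) != owner(ring4, k))
# With four nodes, roughly a quarter of the keys change owner -- and every
# relocated key lands on the new node, which is what makes expansion cheap.
```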
Replication & Consistency Models
We replicate each key to N nodes (typically 3) for durability and availability. The consistency model determines how reads and writes coordinate across replicas.
Quorum-based consistency uses configurable W (write) and R (read) values:
- W + R > N: guarantees reading the latest write (strong consistency)
- W = 1, R = 1: fastest, but only eventually consistent
- W = N, R = 1: strong consistency, slower writes, fast reads
```python
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Tuple

class ConsistencyLevel(Enum):
    ONE = 1     # W=1, R=1 - fastest, eventual consistency
    QUORUM = 2  # W=2, R=2 - balanced (for N=3)
    ALL = 3     # W=N, R=N - strongest, slowest

@dataclass
class VersionedValue:
    value: Any
    timestamp: int  # Lamport timestamp or wall clock
    node_id: str

class QuorumCoordinator:
    def __init__(self, replication_factor: int = 3):
        self.n = replication_factor
        self.executor = ThreadPoolExecutor(max_workers=10)

    def _get_quorum_size(self, level: ConsistencyLevel) -> int:
        if level == ConsistencyLevel.ONE:
            return 1
        elif level == ConsistencyLevel.QUORUM:
            return (self.n // 2) + 1
        return self.n

    def write(self, key: str, value: Any, replicas: List[str],
              level: ConsistencyLevel) -> bool:
        w = self._get_quorum_size(level)
        timestamp = int(time.time() * 1_000_000)  # Microseconds
        futures = [
            self.executor.submit(self._write_to_replica, node, key, value, timestamp)
            for node in replicas
        ]
        successes = 0
        try:
            for future in as_completed(futures, timeout=5.0):
                if future.result():
                    successes += 1
                if successes >= w:
                    return True  # Quorum achieved
        except TimeoutError:
            pass  # Replicas that miss the deadline count as failures
        return False

    def read(self, key: str, replicas: List[str],
             level: ConsistencyLevel) -> Tuple[Any, bool]:
        r = self._get_quorum_size(level)
        futures = [
            self.executor.submit(self._read_from_replica, node, key)
            for node in replicas
        ]
        responses: List[VersionedValue] = []
        try:
            for future in as_completed(futures, timeout=5.0):
                result = future.result()
                if result:
                    responses.append(result)
                if len(responses) >= r:
                    break
        except TimeoutError:
            pass
        if len(responses) < r:
            return None, False
        # Last-write-wins conflict resolution
        latest = max(responses, key=lambda v: v.timestamp)
        return latest.value, True

    def _write_to_replica(self, node: str, key: str,
                          value: Any, ts: int) -> bool:
        # Actual RPC call to replica node
        pass

    def _read_from_replica(self, node: str, key: str) -> VersionedValue:
        # Actual RPC call to replica node
        pass
```
For conflict resolution, last-write-wins (LWW) using timestamps is simple but can lose updates. Vector clocks preserve causality but add complexity. Choose based on your consistency requirements.
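For comparison, a minimal vector clock sketch (the function names are illustrative, not part of the store above) shows how concurrent updates become detectable conflicts instead of silently lost writes:

```python
from typing import Dict

def vc_increment(clock: Dict[str, int], node_id: str) -> Dict[str, int]:
    """Return a copy of the clock with node_id's counter bumped."""
    out = dict(clock)
    out[node_id] = out.get(node_id, 0) + 1
    return out

def vc_compare(a: Dict[str, int], b: Dict[str, int]) -> str:
    """'before', 'after', 'equal', or 'concurrent' (a conflict to resolve)."""
    nodes = set(a) | set(b)
    a_le_b = all(a.get(n, 0) <= b.get(n, 0) for n in nodes)
    b_le_a = all(b.get(n, 0) <= a.get(n, 0) for n in nodes)
    if a_le_b and b_le_a:
        return "equal"
    if a_le_b:
        return "before"
    if b_le_a:
        return "after"
    return "concurrent"

v1 = vc_increment({}, "node-a")  # write accepted at node-a
v2 = vc_increment(v1, "node-b")  # causally after v1
v3 = vc_increment(v1, "node-c")  # concurrent with v2: LWW would drop one
```

Where LWW would keep whichever of v2 and v3 has the later timestamp, a vector-clock store surfaces both siblings to the application (or a merge function) for resolution.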
Storage Engine Design
LSM-trees (Log-Structured Merge-trees) dominate write-heavy key-value stores. The architecture: buffer writes in an in-memory sorted structure (memtable), flush to immutable sorted files (SSTables) when full, and periodically merge SSTables to reclaim space.
```python
import json
import os
import threading
import time
from typing import Any, Dict, List, Optional

class MemTable:
    def __init__(self, max_size: int = 4 * 1024 * 1024):  # 4MB default
        self.data: Dict[str, Any] = {}
        self.size = 0
        self.max_size = max_size
        self.lock = threading.RLock()

    def put(self, key: str, value: Any) -> bool:
        with self.lock:
            entry_size = len(key) + len(str(value))
            self.data[key] = value
            self.size += entry_size
            return self.size >= self.max_size  # True signals "flush me"

    def get(self, key: str) -> Optional[Any]:
        with self.lock:
            return self.data.get(key)

    def to_sorted_items(self):
        with self.lock:
            return sorted(self.data.items())

class SSTable:
    def __init__(self, filepath: str):
        self.filepath = filepath
        self.index: Dict[str, int] = {}  # Key -> file offset
        self._load_index()

    def _load_index(self):
        index_path = f"{self.filepath}.idx"
        if os.path.exists(index_path):
            with open(index_path, 'r') as f:
                self.index = json.load(f)

    def get(self, key: str) -> Optional[Any]:
        if key not in self.index:
            return None
        with open(self.filepath, 'r') as f:
            f.seek(self.index[key])
            line = f.readline()
            k, v = json.loads(line)
            return v if k == key else None

class LSMStorage:
    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        self.wal_path = os.path.join(data_dir, "wal.log")
        self.memtable = MemTable()
        self.immutable_memtables: list = []
        self.sstables: List[SSTable] = []
        self.lock = threading.RLock()
        os.makedirs(data_dir, exist_ok=True)

    def put(self, key: str, value: Any) -> None:
        # Write-ahead log for durability
        self._append_wal(key, value)
        with self.lock:
            should_flush = self.memtable.put(key, value)
            if should_flush:
                self._flush_memtable()

    def get(self, key: str) -> Optional[Any]:
        # Check memtable first (most recent)
        result = self.memtable.get(key)
        if result is not None:
            return result
        # Check immutable memtables (newest first)
        for imm in reversed(self.immutable_memtables):
            result = imm.get(key)
            if result is not None:
                return result
        # Check SSTables (newest first)
        for sstable in reversed(self.sstables):
            result = sstable.get(key)
            if result is not None:
                return result
        return None

    def _append_wal(self, key: str, value: Any) -> None:
        with open(self.wal_path, 'a') as f:
            f.write(json.dumps([key, value]) + '\n')
            f.flush()
            os.fsync(f.fileno())

    def _flush_memtable(self) -> None:
        old_memtable = self.memtable
        self.memtable = MemTable()
        self.immutable_memtables.append(old_memtable)
        # Flush to SSTable in background
        threading.Thread(target=self._write_sstable,
                         args=(old_memtable,)).start()

    def _write_sstable(self, memtable: MemTable) -> None:
        timestamp = int(time.time() * 1000)
        filepath = os.path.join(self.data_dir, f"sst_{timestamp}.dat")
        index = {}
        with open(filepath, 'w') as f:
            for key, value in memtable.to_sorted_items():
                index[key] = f.tell()
                f.write(json.dumps([key, value]) + '\n')
        with open(f"{filepath}.idx", 'w') as f:
            json.dump(index, f)
        with self.lock:
            self.sstables.append(SSTable(filepath))
            self.immutable_memtables.remove(memtable)
```
The WAL ensures durability—if the node crashes, replay the log to recover uncommitted writes.
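Recovery itself is a linear scan. A sketch of that replay, assuming the one-JSON-array-per-line WAL format used above (the function name and torn-write handling are illustrative):

```python
import json
import os

def replay_wal(wal_path: str, memtable) -> int:
    """Rebuild a memtable from the WAL after a crash.

    Assumes the append-only format above: one JSON [key, value] pair per
    line. Returns the number of entries replayed; a torn final line
    (partial write during the crash) is skipped rather than treated as
    corruption of the whole log.
    """
    if not os.path.exists(wal_path):
        return 0
    replayed = 0
    with open(wal_path, "r") as f:
        for line in f:
            try:
                key, value = json.loads(line)
            except (json.JSONDecodeError, ValueError):
                break  # torn tail write; everything before it is intact
            memtable.put(key, value)
            replayed += 1
    return replayed
```

Once the memtable is rebuilt (and flushed to an SSTable if oversized), the WAL can be truncated and normal operation resumes.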
Failure Detection & Cluster Membership
Gossip protocols detect failures without a single point of failure. Each node periodically exchanges state with random peers, propagating membership changes exponentially fast.
```python
import random
import time
from dataclasses import dataclass, field
from typing import Dict, Set

@dataclass
class NodeState:
    node_id: str
    status: str  # "alive", "suspect", "dead"
    heartbeat: int = 0
    last_update: float = field(default_factory=time.time)

class GossipProtocol:
    def __init__(self, node_id: str, gossip_interval: float = 1.0):
        self.node_id = node_id
        self.gossip_interval = gossip_interval
        self.members: Dict[str, NodeState] = {
            node_id: NodeState(node_id, "alive")
        }
        self.suspect_timeout = 5.0
        self.dead_timeout = 15.0

    def heartbeat(self) -> None:
        """Increment local heartbeat counter."""
        self.members[self.node_id].heartbeat += 1
        self.members[self.node_id].last_update = time.time()

    def get_gossip_targets(self, count: int = 3) -> list:
        """Select random nodes to gossip with."""
        candidates = [n for n in self.members.keys() if n != self.node_id]
        return random.sample(candidates, min(count, len(candidates)))

    def prepare_gossip_message(self) -> Dict[str, NodeState]:
        """Prepare state to send to peers."""
        return dict(self.members)

    def merge_gossip(self, remote_state: Dict[str, NodeState]) -> None:
        """Merge received gossip with local state."""
        for node_id, remote in remote_state.items():
            local = self.members.get(node_id)
            if local is None:
                # New node discovered
                self.members[node_id] = remote
            elif remote.heartbeat > local.heartbeat:
                # Remote has newer info; refresh the liveness timer
                self.members[node_id] = remote
                self.members[node_id].last_update = time.time()

    def detect_failures(self) -> Set[str]:
        """Check for suspected/dead nodes."""
        now = time.time()
        failed = set()
        for node_id, state in self.members.items():
            if node_id == self.node_id:
                continue
            age = now - state.last_update
            if age > self.dead_timeout:
                state.status = "dead"
                failed.add(node_id)
            elif age > self.suspect_timeout:
                state.status = "suspect"
        return failed
```
When a node fails, hinted handoff temporarily stores its writes on another node. When the failed node recovers, hints are replayed. For permanent failures, anti-entropy repair using Merkle trees efficiently synchronizes data between replicas.
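A flat sketch of the range-hashing idea behind Merkle-tree repair (bucket count and hashing scheme here are assumptions; a real Merkle tree hashes these buckets pairwise up to a root, so two replicas exchange only O(log n) hashes before drilling into the mismatched range):

```python
import hashlib
from typing import Dict, List

def bucket_hashes(data: Dict[str, str], num_buckets: int = 16) -> List[str]:
    """Summarize each key range (bucket) as one hash for cheap comparison."""
    buckets: List[List[str]] = [[] for _ in range(num_buckets)]
    for key in sorted(data):
        idx = int(hashlib.md5(key.encode()).hexdigest(), 16) % num_buckets
        buckets[idx].append(f"{key}={data[key]}")
    return [hashlib.sha256("|".join(b).encode()).hexdigest() for b in buckets]

def diff_ranges(a: Dict[str, str], b: Dict[str, str]) -> List[int]:
    """Return bucket indices where two replicas disagree.

    Only these ranges need full key-by-key synchronization; identical
    buckets are skipped without transferring any data.
    """
    return [i for i, (x, y) in enumerate(zip(bucket_hashes(a), bucket_hashes(b)))
            if x != y]
```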
Client Interface & Request Routing
Clients need to find the right nodes without becoming a bottleneck. Two approaches: client-side discovery (client maintains cluster topology) or server-side routing (any node can coordinate).
```python
import random
from typing import Any, List, Optional

class KVClient:
    def __init__(self, seed_nodes: List[str], retries: int = 3):
        self.seed_nodes = seed_nodes
        self.known_nodes: List[str] = list(seed_nodes)
        self.retries = retries
        self.hash_ring: Optional[ConsistentHashRing] = None
        self._refresh_topology()

    def _refresh_topology(self) -> None:
        """Fetch cluster topology from any available node."""
        for node in self.known_nodes:
            try:
                topology = self._fetch_topology(node)
                self.hash_ring = ConsistentHashRing()
                for n in topology['nodes']:
                    self.hash_ring.add_node(n)
                self.known_nodes = topology['nodes']
                return
            except Exception:
                continue
        raise Exception("Cannot connect to cluster")

    def put(self, key: str, value: Any,
            consistency: ConsistencyLevel = ConsistencyLevel.QUORUM) -> bool:
        replicas = self.hash_ring.get_replica_nodes(key, n=3)
        for attempt in range(self.retries):
            try:
                coordinator = replicas[attempt % len(replicas)]
                return self._send_write(coordinator, key, value, consistency)
            except Exception:
                if attempt == self.retries - 1:
                    self._refresh_topology()
                    raise
        return False

    def get(self, key: str,
            consistency: ConsistencyLevel = ConsistencyLevel.QUORUM) -> Any:
        replicas = self.hash_ring.get_replica_nodes(key, n=3)
        for attempt in range(self.retries):
            try:
                coordinator = random.choice(replicas)
                return self._send_read(coordinator, key, consistency)
            except Exception:
                continue
        self._refresh_topology()
        raise Exception(f"Failed to read key: {key}")

    def _fetch_topology(self, node: str) -> dict:
        # Actual RPC call to a cluster node
        pass

    def _send_write(self, node: str, key: str, value: Any,
                    consistency: ConsistencyLevel) -> bool:
        # Actual RPC call to coordinator node
        pass

    def _send_read(self, node: str, key: str,
                   consistency: ConsistencyLevel) -> Any:
        # Actual RPC call to coordinator node
        pass
```
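The fixed retry loops above are usually paired with exponential backoff plus jitter, so clients don't hammer a recovering node in lockstep. A self-contained sketch (the helper name is illustrative):

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")

def with_backoff(fn: Callable[[], T], retries: int = 3,
                 base_delay: float = 0.05) -> T:
    """Call fn, retrying failures with exponential backoff plus jitter.

    The random jitter spreads retries out in time so many clients don't
    stampede the same recovering node at the same instant.
    """
    for attempt in range(retries):
        try:
            return fn()
        except Exception:
            if attempt == retries - 1:
                raise
            delay = base_delay * (2 ** attempt)
            time.sleep(delay + random.uniform(0, delay))
    raise RuntimeError("retries must be >= 1")
```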
Performance Optimizations & Trade-offs
Bloom filters eliminate unnecessary disk reads. Before checking SSTables, query a probabilistic filter that tells you if a key might exist (false positives possible) or definitely doesn’t exist.
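A minimal Bloom filter sketch (the bit-array size and hash count are illustrative; real engines derive them from the expected key count and target false-positive rate):

```python
import hashlib

class BloomFilter:
    """k hash positions per key in a fixed bit array."""

    def __init__(self, num_bits: int = 8192, num_hashes: int = 4):
        self.num_bits = num_bits
        self.num_hashes = num_hashes
        self.bits = bytearray(num_bits // 8)

    def _positions(self, key: str):
        # Derive k independent positions by salting the key with the index
        for i in range(self.num_hashes):
            digest = hashlib.sha256(f"{i}:{key}".encode()).hexdigest()
            yield int(digest, 16) % self.num_bits

    def add(self, key: str) -> None:
        for pos in self._positions(key):
            self.bits[pos // 8] |= 1 << (pos % 8)

    def might_contain(self, key: str) -> bool:
        """False means definitely absent; True means maybe present."""
        return all(self.bits[pos // 8] & (1 << (pos % 8))
                   for pos in self._positions(key))
```

On the read path, each SSTable's filter is checked before its index: a negative answer skips the file entirely, so point lookups for missing keys avoid almost all disk I/O.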
Caching at multiple levels: OS page cache, application-level block cache, and row cache for hot keys.
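The row cache for hot keys can be as simple as an LRU map in front of the storage engine; a sketch (capacity is illustrative):

```python
from collections import OrderedDict
from typing import Any, Optional

class RowCache:
    """Tiny LRU row cache: hot keys served without touching the LSM tree."""

    def __init__(self, capacity: int = 1024):
        self.capacity = capacity
        self.entries: OrderedDict[str, Any] = OrderedDict()

    def get(self, key: str) -> Optional[Any]:
        if key not in self.entries:
            return None
        self.entries.move_to_end(key)  # mark as most recently used
        return self.entries[key]

    def put(self, key: str, value: Any) -> None:
        self.entries[key] = value
        self.entries.move_to_end(key)
        if len(self.entries) > self.capacity:
            self.entries.popitem(last=False)  # evict least recently used
```

Writes must invalidate or update cached rows, or reads can serve stale values even under quorum consistency.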
Compression (LZ4, Snappy, Zstd) reduces storage and I/O at the cost of CPU. LZ4 offers the best speed/ratio trade-off for most workloads.
CAP theorem trade-offs in practice:
- Choose CP (consistency + partition tolerance): Use quorum writes, accept higher latency and reduced availability during partitions
- Choose AP (availability + partition tolerance): Accept eventual consistency, use hinted handoff and read repair
Most systems let you choose per-operation. Use strong consistency for financial transactions, eventual consistency for view counters.
The architecture we’ve built mirrors real systems like DynamoDB, Cassandra, and Riak. The key insight: there’s no perfect distributed system, only trade-offs appropriate for your use case. Start with clear requirements, then make informed decisions at each layer.