Design a Key-Value Store: Distributed NoSQL Database

Key Insights

  • Consistent hashing with virtual nodes distributes data evenly across nodes while minimizing data movement during cluster scaling—aim for 100-200 virtual nodes per physical node for balanced distribution.
  • Quorum-based replication (W + R > N) lets you tune consistency vs. availability per operation; use W=1, R=1 for speed or W=N, R=1 for strong consistency.
  • LSM-trees optimize for write-heavy workloads by buffering writes in memory and flushing sorted data to disk, trading read latency for write throughput.

Introduction & Requirements

A distributed key-value store is the backbone of modern infrastructure. From caching layers to session storage to configuration management, these systems handle billions of operations daily at companies like Amazon (DynamoDB), Meta (RocksDB), and countless startups.

Building one teaches you fundamental distributed systems concepts: partitioning, replication, consistency, and failure handling. Let’s design a production-grade key-value store from scratch.

Functional Requirements:

  • CRUD operations: get(key), put(key, value), delete(key)
  • TTL support for automatic expiration
  • Range queries on keys (optional, adds complexity)

Non-Functional Requirements:

  • High availability: survive node failures without downtime
  • Horizontal scalability: add nodes to handle more data/traffic
  • Low latency: sub-millisecond reads, single-digit millisecond writes
  • Tunable consistency: eventual or strong, per-operation

We’re targeting a system that handles 100K+ operations per second across a cluster of commodity machines.
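
Before distributing anything, it helps to pin down the single-node contract. A minimal sketch of the CRUD surface with lazy TTL expiration (the `TTLStore` name and method shapes are illustrative, not from a specific system):

```python
import time
from typing import Any, Optional

class TTLStore:
    """Single-node sketch of the CRUD surface; names are illustrative."""
    def __init__(self) -> None:
        self.data: dict = {}  # key -> (value, expires_at or None)

    def put(self, key: str, value: Any, ttl_seconds: Optional[float] = None) -> None:
        expires_at = time.time() + ttl_seconds if ttl_seconds else None
        self.data[key] = (value, expires_at)

    def get(self, key: str) -> Optional[Any]:
        entry = self.data.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if expires_at is not None and time.time() >= expires_at:
            del self.data[key]  # lazy expiration: reclaim on read
            return None
        return value

    def delete(self, key: str) -> None:
        self.data.pop(key, None)

store = TTLStore()
store.put("session:1", "abc", ttl_seconds=0.05)
print(store.get("session:1"))  # abc
time.sleep(0.06)
print(store.get("session:1"))  # None (expired)
```

Lazy expiration keeps writes cheap; a background sweep can reclaim keys that are never read again.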

Data Partitioning with Consistent Hashing

Naive modulo hashing (hash(key) % num_nodes) breaks catastrophically when nodes change—nearly all keys get remapped. Consistent hashing solves this by mapping both keys and nodes onto a ring, where each key belongs to the first node clockwise from its position.

Virtual nodes (vnodes) solve the uneven distribution problem. Instead of one position per physical node, we create 100-200 virtual positions per node, spreading the load evenly.

import hashlib
from bisect import bisect_right
from typing import Dict, List, Optional

class ConsistentHashRing:
    def __init__(self, virtual_nodes: int = 150):
        self.virtual_nodes = virtual_nodes
        self.ring: List[int] = []  # Sorted hash positions
        self.ring_to_node: Dict[int, str] = {}  # Hash -> physical node
        self.nodes: set = set()
    
    def _hash(self, key: str) -> int:
        return int(hashlib.sha256(key.encode()).hexdigest(), 16)
    
    def add_node(self, node: str) -> None:
        """Add a node's virtual positions to the ring."""
        self.nodes.add(node)
        for i in range(self.virtual_nodes):
            vnode_key = f"{node}:vnode:{i}"
            hash_val = self._hash(vnode_key)
            self.ring_to_node[hash_val] = node
            self.ring.append(hash_val)
        self.ring.sort()
        # In practice, also compute the key ranges that must migrate
        # from the clockwise successor of each new vnode.
    
    def remove_node(self, node: str) -> None:
        self.nodes.discard(node)
        to_remove = [h for h, n in self.ring_to_node.items() if n == node]
        for h in to_remove:
            del self.ring_to_node[h]
            self.ring.remove(h)
    
    def get_node(self, key: str) -> Optional[str]:
        if not self.ring:
            return None
        hash_val = self._hash(key)
        idx = bisect_right(self.ring, hash_val) % len(self.ring)
        return self.ring_to_node[self.ring[idx]]
    
    def get_replica_nodes(self, key: str, n: int = 3) -> List[str]:
        """Get N distinct physical nodes for replication."""
        if len(self.nodes) < n:
            return list(self.nodes)
        
        hash_val = self._hash(key)
        idx = bisect_right(self.ring, hash_val) % len(self.ring)
        
        replicas = []
        seen = set()
        while len(replicas) < n:
            node = self.ring_to_node[self.ring[idx]]
            if node not in seen:
                replicas.append(node)
                seen.add(node)
            idx = (idx + 1) % len(self.ring)
        return replicas

When a node joins or leaves, only keys in adjacent ring segments migrate—roughly 1/N of total keys instead of nearly all.
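
That claim is easy to check empirically. The following standalone sketch (separate from the class above) scales a cluster from 4 to 5 nodes and counts remapped keys under both schemes:

```python
import hashlib
from bisect import bisect_right

def h(s: str) -> int:
    return int(hashlib.sha256(s.encode()).hexdigest(), 16)

keys = [f"key-{i}" for i in range(10_000)]

# Naive modulo placement: hash(key) % num_nodes, scaling 4 -> 5 nodes
moved_modulo = sum(h(k) % 4 != h(k) % 5 for k in keys)

# Consistent hashing with 150 vnodes per node, same scaling event
def build_ring(node_names, vnodes=150):
    positions = {h(f"{name}:vnode:{i}"): name
                 for name in node_names for i in range(vnodes)}
    return sorted(positions), positions

def lookup(ring, positions, key):
    idx = bisect_right(ring, h(key)) % len(ring)
    return positions[ring[idx]]

ring4, pos4 = build_ring([f"n{i}" for i in range(4)])
ring5, pos5 = build_ring([f"n{i}" for i in range(5)])
moved_ring = sum(lookup(ring4, pos4, k) != lookup(ring5, pos5, k) for k in keys)

print(f"modulo: {moved_modulo / len(keys):.0%} of keys moved")      # ~80%
print(f"consistent: {moved_ring / len(keys):.0%} of keys moved")    # ~20%, i.e. about 1/N
```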

Replication & Consistency Models

We replicate each key to N nodes (typically 3) for durability and availability. The consistency model determines how reads and writes coordinate across replicas.

Quorum-based consistency uses configurable W (write) and R (read) values:

  • W + R > N: guarantees reading the latest write (strong consistency)
  • W = 1, R = 1: fastest but eventually consistent
  • W = N, R = 1: strong consistency, slower writes, fast reads

from dataclasses import dataclass
from enum import Enum
from typing import List, Tuple, Any
import time
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError

class ConsistencyLevel(Enum):
    ONE = 1          # W=1, R=1 - fastest, eventual consistency
    QUORUM = 2       # W=2, R=2 - balanced (for N=3)
    ALL = 3          # W=N, R=N - strongest, slowest

@dataclass
class VersionedValue:
    value: Any
    timestamp: int  # Lamport timestamp or wall clock
    node_id: str

class QuorumCoordinator:
    def __init__(self, replication_factor: int = 3):
        self.n = replication_factor
        self.executor = ThreadPoolExecutor(max_workers=10)
    
    def _get_quorum_size(self, level: ConsistencyLevel) -> int:
        if level == ConsistencyLevel.ONE:
            return 1
        elif level == ConsistencyLevel.QUORUM:
            return (self.n // 2) + 1
        return self.n
    
    def write(self, key: str, value: Any, replicas: List[str], 
              level: ConsistencyLevel) -> bool:
        w = self._get_quorum_size(level)
        timestamp = int(time.time() * 1000000)  # Microseconds
        
        futures = [
            self.executor.submit(self._write_to_replica, node, key, value, timestamp)
            for node in replicas
        ]
        
        successes = 0
        try:
            for future in as_completed(futures, timeout=5.0):
                if future.result():
                    successes += 1
                    if successes >= w:
                        return True  # Quorum achieved
        except TimeoutError:
            pass  # slow replicas count as failures
        return False
    
    def read(self, key: str, replicas: List[str], 
             level: ConsistencyLevel) -> Tuple[Any, bool]:
        r = self._get_quorum_size(level)
        
        futures = [
            self.executor.submit(self._read_from_replica, node, key)
            for node in replicas
        ]
        
        responses: List[VersionedValue] = []
        try:
            for future in as_completed(futures, timeout=5.0):
                result = future.result()
                if result:
                    responses.append(result)
                    if len(responses) >= r:
                        break
        except TimeoutError:
            pass  # slow replicas count as failures
        
        if len(responses) < r:
            return None, False
        
        # Last-write-wins conflict resolution
        latest = max(responses, key=lambda v: v.timestamp)
        return latest.value, True
    
    def _write_to_replica(self, node: str, key: str, 
                          value: Any, ts: int) -> bool:
        # Actual RPC call to replica node
        pass
    
    def _read_from_replica(self, node: str, key: str) -> VersionedValue:
        # Actual RPC call to replica node
        pass

For conflict resolution, last-write-wins (LWW) using timestamps is simple but can lose updates. Vector clocks preserve causality but add complexity. Choose based on your consistency requirements.
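
To make the vector-clock option concrete, here is a minimal, illustrative sketch: each clock maps a node ID to the number of updates seen from that node, and two versions conflict only when neither clock dominates the other.

```python
from typing import Dict

VClock = Dict[str, int]  # node_id -> number of updates seen from that node

def vc_merge(a: VClock, b: VClock) -> VClock:
    """Element-wise max: the clock after observing both histories."""
    return {n: max(a.get(n, 0), b.get(n, 0)) for n in set(a) | set(b)}

def vc_compare(a: VClock, b: VClock) -> str:
    nodes = set(a) | set(b)
    a_le_b = all(a.get(n, 0) <= b.get(n, 0) for n in nodes)
    b_le_a = all(b.get(n, 0) <= a.get(n, 0) for n in nodes)
    if a_le_b and b_le_a:
        return "equal"
    if a_le_b:
        return "before"      # a causally precedes b: b supersedes a
    if b_le_a:
        return "after"
    return "concurrent"      # neither dominates: a true conflict

# Replicas A and B each accepted a write independently:
print(vc_compare({"A": 2, "B": 1}, {"A": 1, "B": 2}))  # concurrent
# One version strictly extends the other's history:
print(vc_compare({"A": 1}, {"A": 1, "B": 1}))          # before
```

Unlike LWW, a "concurrent" result lets the store surface both versions to the client instead of silently dropping one.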

Storage Engine Design

LSM-trees (Log-Structured Merge-trees) dominate write-heavy key-value stores. The architecture: buffer writes in an in-memory sorted structure (memtable), flush to immutable sorted files (SSTables) when full, and periodically merge SSTables to reclaim space.

import os
import json
import threading
import time
from typing import Any, Dict, List, Optional

class MemTable:
    def __init__(self, max_size: int = 4 * 1024 * 1024):  # 4MB default
        self.data: Dict[str, Any] = {}
        self.size = 0
        self.max_size = max_size
        self.lock = threading.RLock()
    
    def put(self, key: str, value: Any) -> bool:
        with self.lock:
            if key in self.data:  # don't double-count overwrites
                self.size -= len(key) + len(str(self.data[key]))
            self.data[key] = value
            self.size += len(key) + len(str(value))
            return self.size >= self.max_size
    
    def get(self, key: str) -> Optional[Any]:
        with self.lock:
            return self.data.get(key)
    
    def to_sorted_items(self):
        with self.lock:
            return sorted(self.data.items())

class SSTable:
    def __init__(self, filepath: str):
        self.filepath = filepath
        self.index: Dict[str, int] = {}  # Key -> file offset
        self._load_index()
    
    def _load_index(self):
        index_path = f"{self.filepath}.idx"
        if os.path.exists(index_path):
            with open(index_path, 'r') as f:
                self.index = json.load(f)
    
    def get(self, key: str) -> Optional[Any]:
        if key not in self.index:
            return None
        with open(self.filepath, 'r') as f:
            f.seek(self.index[key])
            line = f.readline()
            k, v = json.loads(line)
            return v if k == key else None

class LSMStorage:
    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        self.wal_path = os.path.join(data_dir, "wal.log")
        self.memtable = MemTable()
        self.immutable_memtables: list = []
        self.sstables: List[SSTable] = []
        self.lock = threading.RLock()
        os.makedirs(data_dir, exist_ok=True)
    
    def put(self, key: str, value: Any) -> None:
        # Write-ahead log for durability
        self._append_wal(key, value)
        
        with self.lock:
            should_flush = self.memtable.put(key, value)
            if should_flush:
                self._flush_memtable()
    
    def get(self, key: str) -> Optional[Any]:
        # Check memtable first (most recent)
        result = self.memtable.get(key)
        if result is not None:
            return result
        
        # Check immutable memtables
        for imm in reversed(self.immutable_memtables):
            result = imm.get(key)
            if result is not None:
                return result
        
        # Check SSTables (newest first)
        for sstable in reversed(self.sstables):
            result = sstable.get(key)
            if result is not None:
                return result
        
        return None
    
    def _append_wal(self, key: str, value: Any) -> None:
        with open(self.wal_path, 'a') as f:
            f.write(json.dumps([key, value]) + '\n')
            f.flush()
            os.fsync(f.fileno())
    
    def _flush_memtable(self) -> None:
        old_memtable = self.memtable
        self.memtable = MemTable()
        self.immutable_memtables.append(old_memtable)
        
        # Flush to SSTable in background
        threading.Thread(target=self._write_sstable, 
                        args=(old_memtable,)).start()
    
    def _write_sstable(self, memtable: MemTable) -> None:
        timestamp = int(time.time() * 1000)
        filepath = os.path.join(self.data_dir, f"sst_{timestamp}.dat")
        index = {}
        
        with open(filepath, 'w') as f:
            for key, value in memtable.to_sorted_items():
                index[key] = f.tell()
                f.write(json.dumps([key, value]) + '\n')
        
        with open(f"{filepath}.idx", 'w') as f:
            json.dump(index, f)
        
        with self.lock:
            self.sstables.append(SSTable(filepath))
            self.immutable_memtables.remove(memtable)

The WAL ensures durability—if the node crashes, replaying the log recovers writes that were acknowledged but not yet flushed to an SSTable.
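
Recovery is a short loop, assuming the JSON-lines WAL format used above: replay records in order and stop at the first torn record left by a mid-write crash. A sketch:

```python
import json
import os

def replay_wal(wal_path: str, memtable: dict) -> int:
    """Replay the JSON-lines WAL in order; stop at the first torn record."""
    if not os.path.exists(wal_path):
        return 0
    replayed = 0
    with open(wal_path) as f:
        for line in f:
            try:
                key, value = json.loads(line)
            except (json.JSONDecodeError, ValueError):
                break  # partial final record from a mid-write crash
            memtable[key] = value
            replayed += 1
    return replayed

# Simulate a crash that tore the last append:
wal_path = "/tmp/demo_wal.log"
with open(wal_path, "w") as f:
    f.write(json.dumps(["a", 1]) + "\n")
    f.write(json.dumps(["a", 2]) + "\n")
    f.write('["b", ')  # crash mid-write

memtable: dict = {}
replayed = replay_wal(wal_path, memtable)
print(replayed)  # 2
print(memtable)  # {'a': 2}
```

Replaying in order preserves last-write-wins semantics within the log; the torn record is safe to drop because it was never acknowledged.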

Failure Detection & Cluster Membership

Gossip protocols detect failures without a single point of failure. Each node periodically exchanges state with random peers, propagating membership changes exponentially fast.

import random
import time
from dataclasses import dataclass, field
from typing import Dict, Set

@dataclass
class NodeState:
    node_id: str
    status: str  # "alive", "suspect", "dead"
    heartbeat: int = 0
    last_update: float = field(default_factory=time.time)

class GossipProtocol:
    def __init__(self, node_id: str, gossip_interval: float = 1.0):
        self.node_id = node_id
        self.gossip_interval = gossip_interval
        self.members: Dict[str, NodeState] = {
            node_id: NodeState(node_id, "alive")
        }
        self.suspect_timeout = 5.0
        self.dead_timeout = 15.0
    
    def heartbeat(self) -> None:
        """Increment local heartbeat counter."""
        self.members[self.node_id].heartbeat += 1
        self.members[self.node_id].last_update = time.time()
    
    def get_gossip_targets(self, count: int = 3) -> list:
        """Select random nodes to gossip with."""
        candidates = [n for n in self.members.keys() if n != self.node_id]
        return random.sample(candidates, min(count, len(candidates)))
    
    def prepare_gossip_message(self) -> Dict[str, NodeState]:
        """Prepare state to send to peers."""
        return dict(self.members)
    
    def merge_gossip(self, remote_state: Dict[str, NodeState]) -> None:
        """Merge received gossip with local state."""
        for node_id, remote in remote_state.items():
            local = self.members.get(node_id)
            
            if local is None:
                # New node discovered
                self.members[node_id] = remote
            elif remote.heartbeat > local.heartbeat:
                # Remote has newer info
                self.members[node_id] = remote
                self.members[node_id].last_update = time.time()
    
    def detect_failures(self) -> Set[str]:
        """Check for suspected/dead nodes."""
        now = time.time()
        failed = set()
        
        for node_id, state in self.members.items():
            if node_id == self.node_id:
                continue
            
            age = now - state.last_update
            if age > self.dead_timeout:
                state.status = "dead"
                failed.add(node_id)
            elif age > self.suspect_timeout:
                state.status = "suspect"
        
        return failed

When a node fails, hinted handoff temporarily stores its writes on another node. When the failed node recovers, hints are replayed. For permanent failures, anti-entropy repair using Merkle trees efficiently synchronizes data between replicas.
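
A Merkle tree makes that synchronization cheap: replicas hash their sorted key ranges bottom-up and compare roots, drilling into a subtree only when its hashes differ. A minimal sketch of the root computation (the hashing scheme is illustrative):

```python
import hashlib
from typing import List, Tuple

def merkle_root(items: List[Tuple[str, str]]) -> bytes:
    """Root hash over a sorted (key, value) range, built bottom-up."""
    if not items:
        return hashlib.sha256(b"empty").digest()
    level = [hashlib.sha256(f"{k}={v}".encode()).digest() for k, v in items]
    while len(level) > 1:
        if len(level) % 2:  # odd count: carry the last hash up a level
            level.append(level[-1])
        level = [hashlib.sha256(level[i] + level[i + 1]).digest()
                 for i in range(0, len(level), 2)]
    return level[0]

replica_a = sorted({"k1": "v1", "k2": "v2", "k3": "v3"}.items())
replica_b = sorted({"k1": "v1", "k2": "stale", "k3": "v3"}.items())
replica_c = sorted({"k3": "v3", "k2": "v2", "k1": "v1"}.items())

print(merkle_root(replica_a) == merkle_root(replica_c))  # True: in sync
print(merkle_root(replica_a) == merkle_root(replica_b))  # False: repair k2's subtree
```

Equal roots prove two replicas agree on an entire range without shipping any data; only divergent subtrees get exchanged.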

Client Interface & Request Routing

Clients need to find the right nodes without becoming a bottleneck. Two approaches: client-side discovery (client maintains cluster topology) or server-side routing (any node can coordinate).

import random
from typing import Any, List, Optional

class KVClient:
    def __init__(self, seed_nodes: List[str], retries: int = 3):
        self.seed_nodes = seed_nodes
        self.known_nodes: List[str] = list(seed_nodes)
        self.retries = retries
        self.hash_ring: Optional[ConsistentHashRing] = None
        self._refresh_topology()
    
    def _refresh_topology(self) -> None:
        """Fetch cluster topology from any available node."""
        for node in self.known_nodes:
            try:
                topology = self._fetch_topology(node)
                self.hash_ring = ConsistentHashRing()
                for n in topology['nodes']:
                    self.hash_ring.add_node(n)
                self.known_nodes = topology['nodes']
                return
            except Exception:
                continue
        raise Exception("Cannot connect to cluster")
    
    def put(self, key: str, value: Any,
            consistency: ConsistencyLevel = ConsistencyLevel.QUORUM) -> bool:
        replicas = self.hash_ring.get_replica_nodes(key, n=3)
        
        for attempt in range(self.retries):
            try:
                coordinator = replicas[attempt % len(replicas)]
                return self._send_write(coordinator, key, value, consistency)
            except Exception:
                if attempt == self.retries - 1:
                    self._refresh_topology()
                    raise
        return False
    
    def get(self, key: str,
            consistency: ConsistencyLevel = ConsistencyLevel.QUORUM) -> Any:
        replicas = self.hash_ring.get_replica_nodes(key, n=3)
        
        for attempt in range(self.retries):
            try:
                coordinator = random.choice(replicas)
                return self._send_read(coordinator, key, consistency)
            except Exception:
                continue
        
        self._refresh_topology()
        raise Exception(f"Failed to read key: {key}")
    
    def _fetch_topology(self, node: str) -> dict:
        # Actual RPC call: returns {"nodes": [...]} from the contacted node
        pass
    
    def _send_write(self, node: str, key: str, value: Any,
                    consistency: ConsistencyLevel) -> bool:
        # Actual RPC call to the coordinator node
        pass
    
    def _send_read(self, node: str, key: str,
                   consistency: ConsistencyLevel) -> Any:
        # Actual RPC call to the coordinator node
        pass
Performance Optimizations & Trade-offs

Bloom filters eliminate unnecessary disk reads. Before checking SSTables, query a probabilistic filter that tells you if a key might exist (false positives possible) or definitely doesn’t exist.
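
A minimal Bloom filter sketch using double hashing to derive the k bit positions (the sizing here is illustrative, not tuned for a real false-positive target):

```python
import hashlib

class BloomFilter:
    def __init__(self, size_bits: int = 1 << 16, num_hashes: int = 5):
        self.size = size_bits
        self.k = num_hashes
        self.bits = bytearray(size_bits // 8)

    def _positions(self, key: str):
        # Double hashing: derive k positions from two halves of one digest
        d = hashlib.sha256(key.encode()).digest()
        h1 = int.from_bytes(d[:8], "big")
        h2 = int.from_bytes(d[8:16], "big")
        return [(h1 + i * h2) % self.size for i in range(self.k)]

    def add(self, key: str) -> None:
        for p in self._positions(key):
            self.bits[p // 8] |= 1 << (p % 8)

    def might_contain(self, key: str) -> bool:
        # False => definitely absent; True => maybe present
        return all(self.bits[p // 8] & (1 << (p % 8)) for p in self._positions(key))

bf = BloomFilter()
bf.add("user:42")
print(bf.might_contain("user:42"))   # True
print(bf.might_contain("user:999"))  # False (no disk read needed)
```

A common rule of thumb: about 10 bits per key gives roughly a 1% false-positive rate.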

Caching at multiple levels: OS page cache, application-level block cache, and row cache for hot keys.
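
The row cache for hot keys is typically LRU. A minimal sketch on top of `OrderedDict`:

```python
from collections import OrderedDict

class LRUCache:
    """Row cache for hot keys: evict the least recently used entry when full."""
    def __init__(self, capacity: int = 1024):
        self.capacity = capacity
        self.data: OrderedDict = OrderedDict()

    def get(self, key):
        if key not in self.data:
            return None
        self.data.move_to_end(key)  # mark as most recently used
        return self.data[key]

    def put(self, key, value) -> None:
        self.data[key] = value
        self.data.move_to_end(key)
        if len(self.data) > self.capacity:
            self.data.popitem(last=False)  # evict the LRU entry

cache = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")         # touch "a" so "b" becomes least recently used
cache.put("c", 3)      # evicts "b"
print(cache.get("b"))  # None
print(cache.get("a"))  # 1
```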

Compression (LZ4, Snappy, Zstd) reduces storage and I/O at the cost of CPU. LZ4 offers the best speed/ratio trade-off for most workloads.
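
LZ4, Snappy, and Zstd all require third-party packages, but the stdlib's zlib shows the same dial: higher levels trade CPU time for a smaller output. A quick sketch:

```python
import time
import zlib

# Repetitive payload, similar to serialized KV entries
data = b"user:12345:session:abcdef:" * 2000

for level in (1, 6, 9):
    t0 = time.perf_counter()
    compressed = zlib.compress(data, level)
    elapsed_ms = (time.perf_counter() - t0) * 1000
    print(f"level {level}: ratio {len(compressed) / len(data):.3f}, {elapsed_ms:.2f} ms")
```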

CAP theorem trade-offs in practice:

  • Choose CP (consistency + partition tolerance): Use quorum writes, accept higher latency and reduced availability during partitions
  • Choose AP (availability + partition tolerance): Accept eventual consistency, use hinted handoff and read repair

Most systems let you choose per-operation. Use strong consistency for financial transactions, eventual consistency for view counters.

The architecture we’ve built mirrors real systems like DynamoDB, Cassandra, and Riak. The key insight: there’s no perfect distributed system, only trade-offs appropriate for your use case. Start with clear requirements, then make informed decisions at each layer.

Liked this? There's more.

Every week: one practical technique, explained simply, with code you can use immediately.