How to Save and Load Models in Python

Key Insights

  • Use joblib for scikit-learn models, SavedModel format for TensorFlow/Keras, and state dictionaries for PyTorch—each framework has specific best practices that prevent compatibility issues and reduce file sizes.
  • Never trust pickle files from untrusted sources; they can execute arbitrary code during deserialization, making them a serious security vulnerability in production environments.
  • Always save preprocessing pipelines alongside your models and track metadata like hyperparameters and training dates—a model without its preprocessing context is essentially useless in production.

Introduction to Model Persistence

Training machine learning models takes time and computational resources. Once you’ve invested hours or days training a model, you need to save it for later use. Model persistence is the bridge between experimentation and production deployment.

Python offers several serialization formats for saving models. The pickle module provides general-purpose object serialization, while joblib optimizes for large numpy arrays common in ML workflows. Framework-specific formats like TensorFlow’s SavedModel and PyTorch’s state dictionaries offer additional benefits like cross-language compatibility and version safety.

The right choice depends on your framework, deployment requirements, and security considerations. Let’s explore the practical implementation for each major framework.

Saving and Loading with Scikit-learn

For scikit-learn models, joblib is the recommended approach. It handles large numpy arrays more efficiently than pickle and provides better compression for model files.

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import joblib

# Generate sample data and train model
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Save model
joblib.dump(model, 'random_forest_model.joblib')

# Load model and make predictions
loaded_model = joblib.load('random_forest_model.joblib')
predictions = loaded_model.predict(X_test)
print(f"Model accuracy: {loaded_model.score(X_test, y_test):.3f}")

While you can use Python’s built-in pickle module, joblib offers better performance for scikit-learn models:

import pickle

# Alternative using pickle (less efficient)
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

with open('model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)

Joblib typically produces smaller files and loads faster, especially for models with large coefficient arrays or tree structures.
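You can verify the difference yourself. The sketch below serializes the same array with both libraries and compares file sizes; the exact savings depend on how repetitive the arrays are (tree ensembles and coefficient matrices usually compress well, random noise does not):

```python
import os
import pickle

import joblib
import numpy as np

# A structured float array stands in for model coefficients;
# repeated values compress well, like many real model arrays
data = np.tile(np.arange(100.0), (1000, 1))

with open('data.pkl', 'wb') as f:
    pickle.dump(data, f)

joblib.dump(data, 'data.joblib', compress=3)

pkl_size = os.path.getsize('data.pkl')
jl_size = os.path.getsize('data.joblib')
print(f"pickle: {pkl_size} bytes, joblib (compress=3): {jl_size} bytes")
```

Uncompressed pickle stores the raw array bytes; joblib's advantage here comes from built-in compression plus memory-mapping support when loading very large arrays.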

TensorFlow/Keras Model Persistence

TensorFlow offers multiple saving formats. The SavedModel format is the recommended approach for production deployments, while HDF5 remains popular for its simplicity. (If you are on Keras 3, note that `model.save()` now expects a `.keras` or `.h5` extension for its native formats, and SavedModel export has moved to `model.export()`.)

import tensorflow as tf
from tensorflow import keras
import numpy as np

# Create a simple neural network
model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(20,)),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Generate sample data
X_train = np.random.randn(1000, 20)
y_train = np.random.randint(0, 2, 1000)

model.fit(X_train, y_train, epochs=5, verbose=0)

# Save in SavedModel format (recommended)
model.save('my_model')

# Save in HDF5 format
model.save('my_model.h5')

# Load SavedModel
loaded_model = keras.models.load_model('my_model')

# Load HDF5 model
loaded_h5_model = keras.models.load_model('my_model.h5')

# Make predictions
test_data = np.random.randn(10, 20)
predictions = loaded_model.predict(test_data)

You can also save only the weights if you want to reconstruct the model architecture separately (Keras 3 requires the `.weights.h5` suffix, which older versions also accept):

# Save only weights
model.save_weights('model.weights.h5')

# Reconstruct model architecture and load weights
new_model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(20,)),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])
new_model.load_weights('model.weights.h5')

SavedModel format is preferred because it includes the computation graph, making it compatible with TensorFlow Serving and TensorFlow Lite.

PyTorch Model Saving Strategies

PyTorch recommends saving state dictionaries rather than entire model objects. This approach provides better flexibility and reduces compatibility issues across PyTorch versions.

import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

# Train model
model = SimpleNet()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters())

# Save state dictionary (recommended)
torch.save(model.state_dict(), 'model_state_dict.pth')

# Load state dictionary (weights_only=True restricts unpickling to tensors)
loaded_model = SimpleNet()
loaded_model.load_state_dict(torch.load('model_state_dict.pth', weights_only=True))
loaded_model.eval()

# Alternative: Save entire model (not recommended; pickles the full class)
torch.save(model, 'entire_model.pth')
loaded_entire_model = torch.load('entire_model.pth', weights_only=False)

For production scenarios, save additional training information:

# Save checkpoint with training state
checkpoint = {
    'epoch': 10,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.25,
}
torch.save(checkpoint, 'checkpoint.pth')

# Resume training
checkpoint = torch.load('checkpoint.pth', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

Handling Model Metadata and Versioning

Models rarely work in isolation. You need to save preprocessing pipelines, feature engineering steps, and metadata for reproducibility.

from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.ensemble import GradientBoostingClassifier
import joblib
import json
import numpy as np
from datetime import datetime

# Create a complete pipeline
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', GradientBoostingClassifier(n_estimators=100, random_state=42))
])

# Train pipeline
X_train = np.random.randn(1000, 20)
y_train = np.random.randint(0, 2, 1000)
pipeline.fit(X_train, y_train)

# Save pipeline
joblib.dump(pipeline, 'complete_pipeline.joblib')

# Save metadata separately
metadata = {
    'model_type': 'GradientBoostingClassifier',
    'training_date': datetime.now().isoformat(),
    'n_features': X_train.shape[1],
    'n_samples': X_train.shape[0],
    'hyperparameters': {
        'n_estimators': 100,
        'random_state': 42
    },
    'preprocessing': 'StandardScaler',
    'accuracy': pipeline.score(X_train, y_train)
}

with open('model_metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2)

# Load everything back
loaded_pipeline = joblib.load('complete_pipeline.joblib')
with open('model_metadata.json', 'r') as f:
    loaded_metadata = json.load(f)

print(f"Model trained on: {loaded_metadata['training_date']}")
print(f"Training accuracy: {loaded_metadata['accuracy']:.3f}")
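That metadata can also guard the inference path: a model trained on 20 features should reject input with any other width. A minimal sketch (the `validate_input` helper is illustrative, not part of any library):

```python
import json

def validate_input(metadata_path, X):
    """Refuse inference when the feature count disagrees with training metadata."""
    with open(metadata_path) as f:
        expected = json.load(f)['n_features']
    actual = len(X[0])
    if actual != expected:
        raise ValueError(f"expected {expected} features, got {actual}")

# Recreate a minimal metadata file so the sketch is self-contained;
# in practice you would reuse the model_metadata.json saved above
with open('model_metadata.json', 'w') as f:
    json.dump({'n_features': 20}, f)

validate_input('model_metadata.json', [[0.0] * 20])  # passes silently
```

Running this check before every `predict` call catches schema drift at the API boundary instead of deep inside the model.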

Common Pitfalls and Best Practices

Security risks with pickle: Pickle can execute arbitrary code during deserialization. Never load pickle files from untrusted sources. In production, use framework-specific formats or validate file integrity with checksums.
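To see why, here's a harmless demonstration: any object can override `__reduce__`, and `pickle.loads` will call whatever it returns.

```python
import pickle

class NotAModel:
    """Any class can hijack unpickling via __reduce__."""
    def __reduce__(self):
        # A real exploit would return (os.system, ("...",));
        # eval on a constant expression is a harmless stand-in
        return (eval, ("40 + 2",))

payload = pickle.dumps(NotAModel())

# Loading the "model" silently runs eval -- no model object comes back
result = pickle.loads(payload)
print(result)  # → 42
```

The file never needs to contain a model at all; the code runs the moment the bytes are deserialized, before you can inspect what you loaded.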

Version compatibility: Models saved with one library version may not load with another. Always document your environment:

import json

import sklearn
import joblib

# Save version information
versions = {
    'sklearn': sklearn.__version__,
    'joblib': joblib.__version__,
}

with open('versions.json', 'w') as f:
    json.dump(versions, f)
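A matching check at load time turns a silent mismatch into a visible warning. The `check_versions` helper below is a sketch (not a standard API) that reads a `versions.json` like the one above and compares it against the installed packages via `importlib.metadata`:

```python
import json
import warnings
from importlib import metadata

def check_versions(path='versions.json'):
    """Return {library: (saved, installed)} for every version mismatch."""
    with open(path) as f:
        saved = json.load(f)
    mismatches = {}
    for lib, saved_version in saved.items():
        try:
            installed = metadata.version(lib)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != saved_version:
            mismatches[lib] = (saved_version, installed)
            warnings.warn(f"{lib}: saved with {saved_version}, environment has {installed}")
    return mismatches
```

Note that `importlib.metadata` looks up distribution names as pip knows them, so record 'scikit-learn' rather than the import name 'sklearn' if you use this approach.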

File size optimization: Large models can be compressed:

# Compress when saving
joblib.dump(model, 'model.joblib', compress=3)  # compression level 0-9

Checkpointing during training: Save models periodically to avoid losing progress:

# Keras callback for checkpointing
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath='model_checkpoint_{epoch:02d}.h5',
    save_freq='epoch',
    save_best_only=True,
    monitor='val_loss'
)

model.fit(X_train, y_train, validation_split=0.2, epochs=50, callbacks=[checkpoint_callback])

Production Considerations

For production deployment, consider cross-platform formats like ONNX (Open Neural Network Exchange):

import torch
import torch.onnx

# Convert PyTorch model to ONNX
model = SimpleNet()
dummy_input = torch.randn(1, 20)

torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['output']
)

# Load and run inference with ONNX Runtime
import onnxruntime as ort

session = ort.InferenceSession('model.onnx')
input_name = session.get_inputs()[0].name
result = session.run(None, {input_name: dummy_input.numpy()})

For cloud deployment, integrate with object storage:

import boto3

# Save to S3
s3_client = boto3.client('s3')
s3_client.upload_file('model.joblib', 'my-bucket', 'models/model.joblib')

# Load from S3
s3_client.download_file('my-bucket', 'models/model.joblib', 'downloaded_model.joblib')
model = joblib.load('downloaded_model.joblib')

Model persistence is fundamental to machine learning workflows. Choose the right format for your framework, always save preprocessing pipelines alongside models, and document your environment for reproducibility. With these practices, your models will transition smoothly from development to production.
