How to Fine-Tune Pretrained Models in TensorFlow


Key Insights

  • Fine-tuning a pretrained model can cut training time by one to two orders of magnitude compared with training from scratch, and it usually reaches better accuracy when labeled data is limited
  • A two-phase training strategy (first train only the custom head on a frozen base, then unfreeze the top layers and continue with a learning rate about 10x lower) reduces the risk of catastrophic forgetting and usually gives the best results
  • Make layer-freezing decisions with the architecture in mind: freeze early layers, which learn generic features (edges, textures), and unfreeze later layers, which learn more task-specific patterns

Introduction to Transfer Learning and Fine-Tuning

Transfer learning leverages knowledge from models trained on large datasets to solve related problems with less data and computation. Fine-tuning takes this further by adapting a pretrained model’s weights to your specific task, rather than using them as fixed feature extractors.

You should fine-tune instead of training from scratch when you have limited training data (as a rough rule of thumb, fewer than about 100,000 labeled samples), limited computational resources, or a task similar to the pretrained model’s original task. For example, using an ImageNet-trained model for medical image classification makes sense because both involve visual pattern recognition, even though the specific patterns differ.

The benefits are substantial: models pretrained on ImageNet already detect edges, textures, shapes, and object parts. You’re building on extensive pretraining compute and more than a million labeled images, so your limited resources go toward task-specific adaptation.

Selecting and Loading a Pretrained Model

TensorFlow provides two primary sources for pretrained models: tf.keras.applications and TensorFlow Hub. Keras Applications offers well-established architectures with ImageNet weights, while TensorFlow Hub provides a broader range of models for various tasks.

Choose your architecture based on your constraints:

  • MobileNetV2/V3: Mobile and edge deployment, real-time inference
  • ResNet50/101: Balanced accuracy and speed, good general-purpose choice
  • EfficientNetB0-B7: Strong accuracy-to-computation ratio, scalable across resource budgets
  • InceptionV3: Strong performance on diverse image types; defaults to 299x299 inputs
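Each choice also fixes an input size and a preprocessing function, so it can help to keep them together. The sketch below uses a hypothetical ARCHITECTURES registry and load_frozen_base helper (both names invented for illustration); the constructors and preprocess_input functions are the real Keras Applications ones, and the sizes are each model's default input resolution. Note that EfficientNet's preprocess_input is a pass-through because those models include their own rescaling.

```python
import tensorflow as tf

apps = tf.keras.applications

# Hypothetical registry: constructor, matching preprocess_input, and the
# default input size for each architecture discussed above
ARCHITECTURES = {
    'MobileNetV2': (apps.MobileNetV2, apps.mobilenet_v2.preprocess_input, 224),
    'ResNet50': (apps.ResNet50, apps.resnet50.preprocess_input, 224),
    'EfficientNetB0': (apps.EfficientNetB0, apps.efficientnet.preprocess_input, 224),
    'InceptionV3': (apps.InceptionV3, apps.inception_v3.preprocess_input, 299),
}


def load_frozen_base(name, weights='imagenet'):
    """Build a headless, frozen base plus its preprocessing function and input size."""
    constructor, preprocess, size = ARCHITECTURES[name]
    base = constructor(input_shape=(size, size, 3), include_top=False, weights=weights)
    base.trainable = False  # start with every base layer frozen
    return base, preprocess, size


# Usage: weights=None skips the ImageNet download (random weights, quick check only)
base, preprocess, size = load_frozen_base('MobileNetV2', weights=None)
print(size, len(base.trainable_weights))
```

Swapping architectures then means changing one string, and the pipeline picks up the matching preprocessing automatically.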

Here’s how to load a pretrained model:

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2

# Load MobileNetV2 with ImageNet weights, excluding top classification layer
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,  # Exclude the ImageNet classifier
    weights='imagenet'
)

# Initially freeze all layers in the base model
base_model.trainable = False

print(f"Base model has {len(base_model.layers)} layers")

Setting include_top=False removes the original classification head, allowing you to add your custom layers for your specific task.
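To see what include_top=False leaves you with, the sketch below builds MobileNetV2 with and without the pooling='avg' argument and checks the output shapes. It passes weights=None so it runs without downloading ImageNet weights; the shapes are the same with weights='imagenet'.

```python
import tensorflow as tf

# weights=None builds the architecture with random weights (no download);
# use weights='imagenet' in practice
feature_map_base = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None)
pooled_base = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None, pooling='avg')

dummy = tf.zeros((1, 224, 224, 3))
print(feature_map_base(dummy).shape)  # 7x7 grid of 1280-channel features per image
print(pooled_base(dummy).shape)       # one 1280-dim vector per image
```

With pooling='avg' the base already applies global average pooling, so a head built on it would skip its own GlobalAveragePooling2D layer.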

Preparing Your Custom Dataset

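Pretrained models expect specific input formats. Most ImageNet models in Keras Applications default to 224x224 inputs (InceptionV3 uses 299x299), and each architecture expects pixel values scaled in its own way. MobileNetV2 expects values in [-1, 1], while ResNet50 expects RGB-to-BGR conversion followed by subtraction of the ImageNet per-channel mean. Each architecture’s module provides a matching preprocess_input function, so use that rather than scaling pixels by hand.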
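A quick way to see the difference is to push a few known pixel values through two of these functions. This sketch uses three hand-picked pixels (black, mid-gray, white) arranged as a tiny one-image batch; it only illustrates the scaling, not a real image.

```python
import numpy as np
import tensorflow as tf

# Three RGB pixels (black, mid-gray, white) as a batch of one 1x3 image
pixels = np.array([[[[0, 0, 0], [127.5, 127.5, 127.5], [255, 255, 255]]]],
                  dtype=np.float32)

# preprocess_input can modify NumPy float arrays in place, so pass copies
mobilenet_out = tf.keras.applications.mobilenet_v2.preprocess_input(pixels.copy())
resnet_out = tf.keras.applications.resnet50.preprocess_input(pixels.copy())

print(mobilenet_out[0, 0])  # scaled to [-1, 1]
print(resnet_out[0, 0])     # BGR channel order, ImageNet means subtracted, no scaling
```

Feeding MobileNetV2-scaled inputs to ResNet50 (or the reverse) does not raise an error; the model silently sees inputs on the wrong scale, which is why this mismatch is easy to miss.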

Here’s a complete data pipeline that handles preprocessing:

import tensorflow as tf

def preprocess_image(image, label):
    # Resize to expected dimensions
    image = tf.image.resize(image, [224, 224])
    
    # Apply MobileNetV2-specific preprocessing
    # This scales pixel values to [-1, 1]
    image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
    
    return image, label

# Augmentation layers pick random parameters for each image and are only
# active when called with training=True
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.05),   # small rotations, up to about ±18 degrees
    tf.keras.layers.RandomBrightness(0.1),  # shift brightness by up to ±10% of [0, 255]
])

def augment_image(image, label):
    # Augment raw [0, 255] pixels, before model-specific preprocessing
    image = data_augmentation(image, training=True)
    return image, label

# Create dataset from directory (yields batches of raw [0, 255] float32 images)
train_ds = tf.keras.utils.image_dataset_from_directory(
    'path/to/train',
    image_size=(224, 224),
    batch_size=32,
    label_mode='categorical'
)

# Apply preprocessing and augmentation
train_ds = train_ds.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

# Validation dataset without augmentation
val_ds = tf.keras.utils.image_dataset_from_directory(
    'path/to/val',
    image_size=(224, 224),
    batch_size=32,
    label_mode='categorical'
)
val_ds = val_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)

The prefetch operation ensures data loading doesn’t bottleneck training, and AUTOTUNE lets TensorFlow optimize parallelism automatically.
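Before training, it is worth checking that batches come out with the shape and value range the model expects. The sketch below substitutes a small, hypothetical in-memory dataset of random images for image_dataset_from_directory so it runs without files on disk, and reuses the same preprocess_image logic.

```python
import numpy as np
import tensorflow as tf

def preprocess_image(image, label):
    # Same preprocessing as the pipeline above: resize, then scale to [-1, 1]
    image = tf.image.resize(image, [224, 224])
    image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
    return image, label

# Hypothetical stand-in data: 8 random 256x256 RGB images with raw [0, 255]
# pixels and one-hot labels for 3 classes
rng = np.random.default_rng(0)
images = rng.uniform(0, 255, size=(8, 256, 256, 3)).astype('float32')
labels = tf.one_hot(rng.integers(0, 3, size=8), depth=3)

ds = (tf.data.Dataset.from_tensor_slices((images, labels))
      .batch(4)
      .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
      .prefetch(tf.data.AUTOTUNE))

# Inspect one batch: shapes and the min/max pixel values after preprocessing
batch_images, batch_labels = next(iter(ds))
print(batch_images.shape, batch_labels.shape)
print(float(tf.reduce_min(batch_images)), float(tf.reduce_max(batch_images)))
```

Running the same two print lines on a batch from your real train_ds is a cheap way to catch a missing or doubled preprocessing step.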

Freezing and Unfreezing Layers

Layer freezing prevents updating pretrained weights during initial training. This is crucial because randomly initialized top layers produce large gradients that would corrupt pretrained features if allowed to backpropagate.

The strategy: start with all base layers frozen, train only your custom head until convergence, then unfreeze the top layers of the base model for fine-tuning.

# Phase 1: All base layers frozen (already set during model loading)
print(f"Trainable weights: {len(base_model.trainable_weights)}")  # Should be 0

# Phase 2 (run only after the head has been trained): unfreeze the top layers
# For MobileNetV2 (154 layers with include_top=False), the last 30-50 layers is a reasonable start
base_model.trainable = True

# Freeze all layers except the last 40
for layer in base_model.layers[:-40]:
    layer.trainable = False

print(f"Trainable layers: {sum(layer.trainable for layer in base_model.layers)}")

The number of layers to unfreeze depends on your dataset size and similarity to ImageNet. More data and greater task difference justify unfreezing more layers.
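One way to package this is a small helper. unfreeze_top_layers below is hypothetical: it applies the same last-N rule as above and, as a common precaution on small datasets, also keeps BatchNormalization layers frozen, which complements calling the base with training=False. The usage line builds MobileNetV2 with weights=None only so the check runs without a download.

```python
import tensorflow as tf

def unfreeze_top_layers(base_model, num_layers, keep_batchnorm_frozen=True):
    """Hypothetical helper: unfreeze the last `num_layers` layers of `base_model`.

    Returns the names of the layers that end up trainable.
    """
    base_model.trainable = True
    cutoff = len(base_model.layers) - num_layers
    for i, layer in enumerate(base_model.layers):
        if i < cutoff:
            layer.trainable = False  # early, generic layers stay frozen
        elif keep_batchnorm_frozen and isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False  # keep pretrained BN parameters fixed
        else:
            layer.trainable = True
    return [layer.name for layer in base_model.layers if layer.trainable]


# Usage with random weights (weights=None) for a quick check
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None)
trainable_names = unfreeze_top_layers(base_model, 40)
print(f"Trainable layers: {len(trainable_names)}, "
      f"trainable weights: {len(base_model.trainable_weights)}")
```

The trainable-layer count includes weightless layers such as activations and Add, so len(base_model.trainable_weights) is a better measure of how much of the network you are actually fine-tuning.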

Building the Fine-Tuning Architecture

Your custom head should match your task requirements. For classification, use GlobalAveragePooling2D to reduce spatial dimensions, followed by Dense layers with dropout for regularization.

from tensorflow.keras import layers, models

# Build the complete model
inputs = tf.keras.Input(shape=(224, 224, 3))

# Pretrained base
x = base_model(inputs, training=False)  # training=False for batch norm

# Custom classification head
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(10, activation='softmax')(x)  # 10 classes; set to the number of classes in your dataset

model = models.Model(inputs, outputs)

model.summary()

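Passing training=False keeps the base model’s batch normalization layers in inference mode, so they use their pretrained moving mean and variance instead of statistics computed from your small batches. This matters most in phase 2: because the argument is fixed when the model is built, batch normalization stays in inference mode even after you unfreeze the top layers, which keeps those statistics from being disrupted during fine-tuning.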

Training Strategy and Hyperparameters

The two-phase approach is the core of the method. Phase 1 trains only the custom head with a standard learning rate. Phase 2 unfreezes the top layers of the base and trains them together with the head at a much lower learning rate, so the updates refine the pretrained features instead of overwriting them.

# Phase 1: Train only the custom head
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history_phase1 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
        tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2)
    ]
)

# Phase 2: fine-tune the top 40 base layers together with the head
# Unfreeze the top layers (same code as in the previous section)
base_model.trainable = True
for layer in base_model.layers[:-40]:
    layer.trainable = False

# Recompile with lower learning rate (10x smaller)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history_phase2 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint(
            'best_model.keras',
            save_best_only=True,
            monitor='val_accuracy'
        )
    ]
)

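The 10x learning rate reduction in phase 2 matters. A rate as high as phase 1’s can make large updates to the newly unfrozen layers and overwrite their pretrained features (catastrophic forgetting). Treat 10x as a starting point; if validation loss rises sharply when phase 2 begins, lower the rate further.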
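To see which weights each phase actually updates, the sketch below repeats the two-phase recipe on a toy problem. The three-layer convolutional base is a hypothetical stand-in with random weights, the data is random, and the learning rates mirror the ones above; the point is only to confirm that frozen layers stay unchanged while unfrozen ones move.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
rng = np.random.default_rng(0)

# Hypothetical stand-in for a pretrained base (random weights, tiny size)
base = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, activation='relu', name='block1'),
    tf.keras.layers.Conv2D(8, 3, activation='relu', name='block2'),
    tf.keras.layers.Conv2D(8, 3, activation='relu', name='block3'),
], name='base')

inputs = tf.keras.Input(shape=(32, 32, 3))
x = base(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(3, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)

# Synthetic data: 64 random images with one-hot labels for 3 classes
images = rng.random((64, 32, 32, 3), dtype=np.float32)
labels = tf.keras.utils.to_categorical(rng.integers(0, 3, size=64), num_classes=3)

def snapshot(layer_name):
    return [w.copy() for w in base.get_layer(layer_name).get_weights()]

def changed(before, layer_name):
    after = base.get_layer(layer_name).get_weights()
    return any(not np.array_equal(b, a) for b, a in zip(before, after))

# Phase 1: frozen base, only the head trains
base.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss='categorical_crossentropy')
b1, b3 = snapshot('block1'), snapshot('block3')
model.fit(images, labels, epochs=1, batch_size=16, verbose=0)
phase1 = (changed(b1, 'block1'), changed(b3, 'block3'))

# Phase 2: unfreeze only the last base layer, then recompile with a 10x lower rate
base.trainable = True
for layer in base.layers[:-1]:
    layer.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='categorical_crossentropy')
b1, b3 = snapshot('block1'), snapshot('block3')
model.fit(images, labels, epochs=1, batch_size=16, verbose=0)
phase2 = (changed(b1, 'block1'), changed(b3, 'block3'))

print('Phase 1 (block1 changed, block3 changed):', phase1)
print('Phase 2 (block1 changed, block3 changed):', phase2)
```

Phase 1 should leave both base layers untouched, and phase 2 should change block3 but not block1. If an unfrozen layer never changes in your own runs, a missing recompile after editing trainable is a common cause.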

Evaluation and Best Practices

After training, evaluate thoroughly on a held-out test set. Compare against a baseline (training from scratch or using the pretrained model as a fixed feature extractor) to validate your fine-tuning approach.

# Evaluate on test set
test_ds = tf.keras.utils.image_dataset_from_directory(
    'path/to/test',
    image_size=(224, 224),
    batch_size=32,
    label_mode='categorical'
)
test_ds = test_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)

test_loss, test_accuracy = model.evaluate(test_ds)
print(f"Test accuracy: {test_accuracy:.4f}")

# Save the complete fine-tuned model in the native Keras format
model.save('fine_tuned_model.keras')

# Load and use for inference
loaded_model = tf.keras.models.load_model('fine_tuned_model.keras')

# Single image prediction
img = tf.keras.utils.load_img('test_image.jpg', target_size=(224, 224))
img_array = tf.keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)
img_array = tf.keras.applications.mobilenet_v2.preprocess_input(img_array)

predictions = loaded_model.predict(img_array)
predicted_class = tf.argmax(predictions[0]).numpy()
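A single accuracy number can hide weak classes. Below is a sketch of a per-class check, assuming a dataset of (images, one-hot labels) batches like test_ds; confusion_from_dataset is a hypothetical helper built on tf.math.confusion_matrix, and the per-class figure it reports is recall.

```python
import numpy as np
import tensorflow as tf

def confusion_from_dataset(model, dataset, num_classes):
    """Hypothetical helper: confusion matrix and per-class recall over a dataset."""
    y_true, y_pred = [], []
    for images, labels in dataset:
        probs = model(images, training=False)
        y_true.append(tf.argmax(labels, axis=1))  # one-hot labels -> class indices
        y_pred.append(tf.argmax(probs, axis=1))   # probabilities -> predicted indices
    y_true = tf.concat(y_true, axis=0)
    y_pred = tf.concat(y_pred, axis=0)
    # Rows are true classes, columns are predicted classes
    cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=num_classes).numpy()
    # Recall per class: correct predictions divided by true examples of that class
    per_class_recall = np.diag(cm) / np.maximum(cm.sum(axis=1), 1)
    return cm, per_class_recall
```

Usage would be cm, recall = confusion_from_dataset(model, test_ds, 10) with the 10-class model above. Classes with low recall are candidates for more data or stronger augmentation.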

Common pitfalls to avoid:

  1. Forgetting to recompile after unfreezing layers: changes to trainable only take effect when you call compile() again, so the newly unfrozen weights won’t be updated
  2. Using the same learning rate for both phases: once the base is unfrozen, the larger phase 1 rate can overwrite pretrained features
  3. Insufficient regularization in custom heads: add dropout and consider L2 weight regularization
  4. Ignoring preprocessing requirements: each architecture expects its own input size and pixel scaling
  5. Training too long: use early stopping to prevent overfitting on small datasets
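For pitfall 3, L2 regularization can be added to the head’s Dense layers through kernel_regularizer. build_head below is a hypothetical helper that builds just the head on MobileNetV2’s 7x7x1280 feature map (its output for 224x224 inputs); the 1e-4 strength is an illustrative starting value, not a tuned one.

```python
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers

def build_head(feature_shape=(7, 7, 1280), num_classes=10, l2=1e-4):
    """Hypothetical helper: classification head with dropout and L2 penalties."""
    features = tf.keras.Input(shape=feature_shape)
    x = layers.GlobalAveragePooling2D()(features)
    x = layers.Dense(256, activation='relu',
                     kernel_regularizer=regularizers.l2(l2))(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax',
                           kernel_regularizer=regularizers.l2(l2))(x)
    return models.Model(features, outputs)


head = build_head()
# One L2 penalty term per regularized kernel
print(len(head.losses))
```

During fit(), Keras adds the terms in head.losses to the training loss automatically, so no change to the loss function is needed.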

Fine-tuning is one of the most practical techniques in deep learning. With these strategies, you can often reach strong results on custom tasks with modest data and compute. Start with a frozen base, train conservatively, and monitor validation metrics closely to know when to stop.
