How to Use Transfer Learning in TensorFlow

Key Insights

  • Transfer learning lets you leverage pre-trained models to build accurate classifiers with 10-100x less training data and time than training from scratch
  • The two main approaches are feature extraction (freezing all base layers) and fine-tuning (unfreezing top layers with a low learning rate)
  • Matching your preprocessing pipeline to the pre-trained model’s original training setup is critical—mismatched preprocessing is the most common failure point

Introduction to Transfer Learning

Transfer learning is the practice of taking a model trained on one task and repurposing it for a different but related task. Instead of training a neural network from scratch with randomly initialized weights, you start with weights learned from a massive dataset like ImageNet (14 million images, 1000 categories) and adapt them to your specific problem.

The value proposition is compelling: you can build a production-quality image classifier with just a few hundred examples per class instead of tens of thousands. Training time drops from days to minutes. For small teams and startups, this is the difference between “possible” and “impossible.”

Transfer learning dominates computer vision applications—image classification, object detection, segmentation—but also powers NLP tasks through models like BERT and GPT. This article focuses on the computer vision use case with TensorFlow and Keras.

Understanding Pre-trained Models in TensorFlow

TensorFlow provides pre-trained models through two main channels: tf.keras.applications and TensorFlow Hub. For most use cases, Keras Applications is simpler and sufficient.

Popular architectures include:

  • MobileNetV2/V3: Lightweight models optimized for mobile and edge devices. Good accuracy-to-size ratio.
  • ResNet50/101/152: Deep residual networks. ResNet50 is a solid general-purpose choice.
  • EfficientNetB0-B7: State-of-the-art accuracy with efficient scaling. EfficientNetB0 is excellent for starting out.
  • VGG16/19: Older but simple architecture. Large model size but good for understanding fundamentals.

Choose based on your constraints: use MobileNet for mobile deployment, EfficientNet for maximum accuracy when you have the compute budget, and ResNet50 for quick prototyping—it remains an industry standard.

Here’s how to load a pre-trained model:

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2

# Load MobileNetV2 with ImageNet weights
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,  # Exclude final classification layer
    weights='imagenet'
)

print(f"Model has {len(base_model.layers)} layers")

The include_top=False parameter is crucial—it removes the final 1000-class ImageNet classification layer, letting you add your own custom classifier.
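To see what include_top=False leaves behind, you can inspect the truncated model's output shape. (This sketch uses weights=None purely to examine the architecture without downloading anything; use weights='imagenet' when you want the pre-trained features.)

```python
from tensorflow.keras.applications import MobileNetV2

# weights=None just to inspect the architecture without a download;
# use weights='imagenet' for actual transfer learning
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights=None
)
print(base_model.output_shape)  # (None, 7, 7, 1280): feature maps, no classifier
```

Instead of 1000 class probabilities, the model now emits spatial feature maps you can build your own head on.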

Feature Extraction Approach

Feature extraction treats the pre-trained model as a fixed feature extractor. You freeze all base model layers so their weights don’t update during training, then train only your custom classification layers on top.

This approach works well when:

  • You have limited training data (hundreds, not thousands of examples)
  • Your new task is similar to the original task (e.g., both are general image classification)
  • You need fast training

Here’s a complete implementation:

from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2

# Load base model
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet'
)

# Freeze all base model layers
base_model.trainable = False

# Build new model
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')  # 10 classes
])

# Compile
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model.summary()

The GlobalAveragePooling2D layer converts the base model’s output (a 7×7×1280 tensor for MobileNetV2) into a 1280-dimensional vector. This is more parameter-efficient than flattening.
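You can verify that shape reduction with a tiny self-contained check (the zero tensor here just stands in for the base model's output):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Dummy tensor with MobileNetV2's feature-map shape: (batch, 7, 7, 1280)
features = tf.zeros((1, 7, 7, 1280))
pooled = layers.GlobalAveragePooling2D()(features)
print(pooled.shape)  # (1, 1280): each 7x7 feature map averaged to one number
```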

Train this model normally on your dataset. The base model acts as a sophisticated feature extractor, and you’re only learning the final classification mapping.
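As a rough, end-to-end sketch of that training step—downsized to 96×96 inputs, weights=None, and random stand-in data so it runs anywhere; swap in weights='imagenet', 224×224, and your real dataset in practice:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2

# Downsized sketch: weights=None skips the ImageNet download, 96x96 keeps it
# fast. Real training would use weights='imagenet' and 224x224 inputs.
base_model = MobileNetV2(input_shape=(96, 96, 3), include_top=False, weights=None)
base_model.trainable = False  # frozen: only the head below learns

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy'])

# Random stand-in data: 8 images, 10 one-hot classes
x = np.random.rand(8, 96, 96, 3).astype('float32')
y = tf.one_hot(np.random.randint(0, 10, size=8), depth=10)
history = model.fit(x, y, batch_size=4, epochs=1, verbose=0)
```

Because the base is frozen, each epoch only updates the small Dense head—which is why this phase trains so quickly.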

Fine-tuning Approach

Fine-tuning goes one step further. After training the custom top layers, you unfreeze some of the base model’s top layers and continue training with a very low learning rate. This allows the pre-trained features to adapt slightly to your specific domain.

Fine-tuning works best when:

  • You have more training data (thousands of examples)
  • Your task differs somewhat from ImageNet (e.g., medical images, satellite imagery)
  • You’ve already done feature extraction training

Critical rule: always use a low learning rate (10-100x lower than initial training) to avoid destroying the pre-trained weights.

# First, train with frozen base (feature extraction)
# ... training code here ...

# Now unfreeze the base model
base_model.trainable = True

# Freeze all layers except the last 20
for layer in base_model.layers[:-20]:
    layer.trainable = False

print(f"Trainable layers: {len([l for l in base_model.layers if l.trainable])}")

# Recompile with low learning rate
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # Very low LR
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Continue training
history_fine = model.fit(
    train_dataset,
    epochs=10,
    validation_data=val_dataset
)

The choice of how many layers to unfreeze is empirical. Start with the top 10-20% of layers. Earlier layers learn general features (edges, textures) that transfer well; later layers learn task-specific features that benefit from adaptation.

Practical Implementation: Image Classification Example

Let’s build a complete flower classifier using transfer learning. This example uses the tf_flowers dataset from TensorFlow Datasets.

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
import tensorflow_datasets as tfds

# Load dataset
(train_ds, val_ds), info = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
    with_info=True
)

num_classes = info.features['label'].num_classes
IMG_SIZE = 224

# Preprocessing function
def preprocess(image, label):
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.keras.applications.efficientnet.preprocess_input(image)
    return image, tf.one_hot(label, num_classes)

# Augmentation for training
def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=32.0)  # pixel-value delta; EfficientNet inputs stay in [0, 255]
    return image, label

# Prepare datasets
train_ds = (train_ds
    .map(preprocess)
    .map(augment)
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE))

val_ds = (val_ds
    .map(preprocess)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE))

# Build model with transfer learning
base_model = EfficientNetB0(
    include_top=False,
    weights='imagenet',
    input_shape=(IMG_SIZE, IMG_SIZE, 3)
)
base_model.trainable = False

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Train
history = model.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds
)

Notice the preprocessing function uses efficientnet.preprocess_input. Each model architecture expects specific preprocessing—this is non-negotiable, and using the wrong preprocessing function will tank your accuracy. (EfficientNet happens to be a special case: the Keras implementation normalizes inputs inside the model, so its preprocess_input is a pass-through—but calling it anyway keeps your pipeline correct if you later swap architectures.)

With this approach, you’ll typically achieve 85-90% validation accuracy on the flowers dataset in under 10 minutes on a GPU. Training from scratch would require hours and likely achieve worse results.

Best Practices and Common Pitfalls

Use the correct preprocessing function. Every pre-trained model expects inputs in a specific range and format. MobileNet expects [-1, 1], ResNet expects ImageNet mean subtraction. Always use the model-specific preprocessing function from tf.keras.applications.<model_name>.preprocess_input.
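A quick sanity check makes the MobileNet case concrete:

```python
import numpy as np
from tensorflow.keras.applications import mobilenet_v2

# MobileNet's preprocess_input maps raw [0, 255] pixels into [-1, 1]
pixels = np.array([[0.0, 127.5, 255.0]])
scaled = mobilenet_v2.preprocess_input(pixels)
print(scaled)  # [[-1.  0.  1.]]
```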

Start with feature extraction, then fine-tune. Don’t skip the initial frozen training phase. The custom top layers need to learn reasonable weights before you start modifying the base model, or you’ll get unstable gradients.

Use very low learning rates for fine-tuning. I typically use 1e-5 or 1e-6. Higher learning rates destroy pre-trained features.

Match input sizes. Most models were trained on 224×224 images. While they can handle other sizes, you’ll get best results matching the original training resolution.

Watch for overfitting. Transfer learning models can overfit quickly on small datasets. Use dropout, data augmentation, and early stopping. Monitor the gap between training and validation accuracy.
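A minimal early-stopping setup might look like this (a patience of 3 epochs is just an illustrative choice—tune it to your training dynamics):

```python
import tensorflow as tf

# Stop when validation loss hasn't improved for 3 straight epochs,
# and roll back to the best weights seen during training.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Usage with the datasets from the example above:
# model.fit(train_ds, validation_data=val_ds, epochs=50, callbacks=[early_stop])
```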

Don’t unfreeze batch normalization layers. If you’re selectively unfreezing layers, keep batch normalization layers frozen. Updating batch norm statistics with a small batch size causes instability.
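One way to apply this selective unfreezing (sketched with weights=None to skip the download; the 20-layer cutoff mirrors the fine-tuning example earlier):

```python
from tensorflow.keras import layers
from tensorflow.keras.applications import MobileNetV2

# weights=None keeps this sketch download-free; use weights='imagenet' in practice
base_model = MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights=None)
base_model.trainable = True

# Unfreeze only the top 20 layers...
for layer in base_model.layers[:-20]:
    layer.trainable = False

# ...but re-freeze every batch normalization layer, wherever it sits
for layer in base_model.layers:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False
```

Remember to recompile the model after changing trainable flags, or the change won't take effect.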

Conclusion

Transfer learning is the default approach for computer vision tasks. Unless you have millions of labeled examples and unlimited compute, you should start with a pre-trained model.

The typical workflow: start with feature extraction using a frozen base model, train until validation loss plateaus, then optionally fine-tune the top layers with a low learning rate for an additional 1-3% accuracy boost.

Expected gains: 10-30% higher accuracy compared to training from scratch, 10-100x less training data required, and 10-50x faster training time. For a small team, this often means the difference between shipping a feature and abandoning it.

Next steps: experiment with different base models on your specific dataset. Try EfficientNetB0 as your starting point, then scale up to B1-B3 if you need more accuracy and have the compute budget. Build your data pipeline carefully—good data augmentation matters more than model architecture for small datasets.
