How to Use Keras Functional API in TensorFlow

Key Insights

  • The Functional API enables complex architectures like multi-input models, shared layers, and residual connections that are impossible with the Sequential API
  • Models are built by explicitly defining the data flow between layers using function call syntax, giving you complete control over the computation graph
  • Use Functional API when you need anything beyond a simple linear stack of layers—it’s more verbose but vastly more flexible than Sequential models

Introduction to Keras Functional API

The Keras Functional API is TensorFlow’s interface for building neural networks with complex topologies. While the Sequential API works well for linear stacks of layers, real-world architectures often require multiple inputs, branching paths, layer sharing, or multiple outputs. The Functional API treats layers as functions that transform tensors, letting you explicitly define how data flows through your network.

You should reach for the Functional API when building multi-input models (like recommendation systems combining user and item features), multi-output models (like networks predicting multiple attributes simultaneously), models with shared layers (like Siamese networks), or any architecture with non-linear topology like ResNets, Inception modules, or encoder-decoder structures.

The core advantage is flexibility. You define the computation graph explicitly, making complex architectures straightforward to implement while maintaining the simplicity of Keras’s high-level API.

Basic Functional API Syntax

The Functional API follows a consistent pattern: define input layers, chain transformations by calling layers on tensors, and create a Model object by specifying inputs and outputs.

Here’s a simple feedforward network comparing both approaches:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Sequential API
sequential_model = keras.Sequential([
    keras.Input(shape=(20,)),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Functional API - equivalent architecture
inputs = keras.Input(shape=(20,))
x = layers.Dense(64, activation='relu')(inputs)
x = layers.Dropout(0.5)(x)
x = layers.Dense(32, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs)

Notice how layers are called as functions on tensor objects. The inputs variable is a symbolic tensor representing the input, and each layer call returns a new tensor. The final Model object connects the input and output tensors, defining the complete computation graph.
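You can see this symbolic-tensor behavior directly. A minimal sketch (shapes chosen arbitrarily for illustration): every intermediate tensor carries a static shape you can inspect while wiring the graph, and the finished Model trains exactly like a Sequential one.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Each symbolic tensor knows its shape before any data flows through
inputs = keras.Input(shape=(20,))
x = layers.Dense(64, activation='relu')(inputs)
print(inputs.shape)  # (None, 20) -- None is the batch dimension
print(x.shape)       # (None, 64)

# The finished model compiles and trains like any other Keras model
outputs = layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
```

Inspecting these shapes as you build is the fastest way to catch wiring mistakes early.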

Multi-Input Models

Multi-input models are essential when combining heterogeneous data sources. Consider a movie recommendation system that uses both user features and movie features:

# User input branch
user_input = keras.Input(shape=(32,), name='user_features')
user_dense = layers.Dense(64, activation='relu')(user_input)
user_dense = layers.Dense(32, activation='relu')(user_dense)

# Movie input branch
movie_input = keras.Input(shape=(16,), name='movie_features')
movie_dense = layers.Dense(64, activation='relu')(movie_input)
movie_dense = layers.Dense(32, activation='relu')(movie_dense)

# Merge branches
merged = layers.concatenate([user_dense, movie_dense])
merged = layers.Dense(64, activation='relu')(merged)
merged = layers.Dropout(0.3)(merged)
output = layers.Dense(1, activation='sigmoid', name='rating')(merged)

# Create model with multiple inputs
recommendation_model = keras.Model(
    inputs=[user_input, movie_input],
    outputs=output
)

# Training with multiple inputs
recommendation_model.compile(optimizer='adam', loss='binary_crossentropy')
# recommendation_model.fit(
#     [user_data, movie_data],  # List of input arrays
#     ratings,
#     epochs=10
# )

The model processes each input through separate dense layers before merging them with concatenation. You can also use layers.Add(), layers.Multiply(), or layers.Average() for different merging strategies.
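The merging strategies differ in their shape requirements. A quick sketch (with arbitrary 32-unit branches for illustration): concatenation stacks features side by side, while the element-wise layers require matching shapes.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

a = keras.Input(shape=(32,))
b = keras.Input(shape=(32,))

# concatenate stacks features along the last axis: (None, 64)
cat = layers.concatenate([a, b])

# Add / Average / Multiply combine element-wise, so input shapes must match: (None, 32)
summed = layers.Add()([a, b])
averaged = layers.Average()([a, b])

print(cat.shape, summed.shape, averaged.shape)
```

Concatenation is the safest default when branches carry different kinds of features; element-wise merges assume the branches live in the same representation space.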

Multi-Output Models

Multi-output models predict multiple targets simultaneously, useful for multi-task learning. Here’s an image classifier that predicts both the main category and an auxiliary attribute:

# Shared feature extractor
inputs = keras.Input(shape=(224, 224, 3))
x = layers.Conv2D(32, 3, activation='relu')(inputs)
x = layers.MaxPooling2D(2)(x)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.MaxPooling2D(2)(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)

# Main classification head
main_output = layers.Dense(10, activation='softmax', name='main_category')(x)

# Auxiliary attribute head
aux_output = layers.Dense(5, activation='softmax', name='color_attribute')(x)

# Model with multiple outputs
multi_output_model = keras.Model(
    inputs=inputs,
    outputs=[main_output, aux_output]
)

# Compile with different losses for each output
multi_output_model.compile(
    optimizer='adam',
    loss={
        'main_category': 'categorical_crossentropy',
        'color_attribute': 'categorical_crossentropy'
    },
    loss_weights={'main_category': 1.0, 'color_attribute': 0.5}
)

# Training expects a list or dict of outputs
# multi_output_model.fit(
#     images,
#     {'main_category': main_labels, 'color_attribute': color_labels},
#     epochs=10
# )

Loss weights let you balance the importance of each task. The auxiliary task often improves the shared representation even if it’s not your primary goal.
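At inference time, a multi-output model returns one array per output, in the order given to keras.Model. A minimal sketch with a scaled-down two-headed model (the 8-feature input and layer sizes are arbitrary, chosen only to show the output structure):

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Toy two-headed model mirroring the structure above
inputs = keras.Input(shape=(8,))
h = layers.Dense(16, activation='relu')(inputs)
main = layers.Dense(10, activation='softmax', name='main_category')(h)
aux = layers.Dense(5, activation='softmax', name='color_attribute')(h)
model = keras.Model(inputs, [main, aux])

model.compile(
    optimizer='adam',
    loss={'main_category': 'categorical_crossentropy',
          'color_attribute': 'categorical_crossentropy'},
    loss_weights={'main_category': 1.0, 'color_attribute': 0.5}
)

# predict returns a list: one array per output, in declaration order
preds = model.predict(np.zeros((4, 8), dtype='float32'), verbose=0)
print(preds[0].shape, preds[1].shape)  # (4, 10) (4, 5)
```

The output names you assign also label the per-output losses and metrics in training logs, which is why naming them is worth the extra keystrokes.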

Complex Architectures with Shared Layers

Layer sharing is powerful for Siamese networks and other architectures that process multiple inputs with identical transformations:

# Define a shared convolutional base
def create_conv_base():
    inputs = keras.Input(shape=(28, 28, 1))
    x = layers.Conv2D(32, 3, activation='relu')(inputs)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(64, 3, activation='relu')(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(128, activation='relu')(x)
    return keras.Model(inputs, outputs)

# Create shared layer
shared_conv = create_conv_base()

# Two input branches using the same weights
input_a = keras.Input(shape=(28, 28, 1), name='input_a')
input_b = keras.Input(shape=(28, 28, 1), name='input_b')

# Process both inputs through shared layers
encoded_a = shared_conv(input_a)
encoded_b = shared_conv(input_b)

# Compute similarity
distance = layers.Lambda(
    lambda tensors: tf.abs(tensors[0] - tensors[1])
)([encoded_a, encoded_b])

output = layers.Dense(1, activation='sigmoid')(distance)

siamese_model = keras.Model(
    inputs=[input_a, input_b],
    outputs=output
)

Both inputs pass through the same convolutional layers with shared weights. This ensures the network learns a consistent embedding space for similarity comparison.
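Here is a hedged sketch of how such a model is used at training and inference time, with a small dense encoder standing in for the conv base above (assuming the TensorFlow backend for the tf.abs call):

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Tiny shared encoder -- a stand-in for create_conv_base() above
enc_in = keras.Input(shape=(28, 28, 1))
enc_out = layers.Dense(16, activation='relu')(layers.Flatten()(enc_in))
encoder = keras.Model(enc_in, enc_out)

input_a = keras.Input(shape=(28, 28, 1))
input_b = keras.Input(shape=(28, 28, 1))
dist = layers.Lambda(
    lambda t: tf.abs(t[0] - t[1])
)([encoder(input_a), encoder(input_b)])
out = layers.Dense(1, activation='sigmoid')(dist)
siamese = keras.Model([input_a, input_b], out)

# Training data arrives as pairs plus a same/different label
pairs_a = np.zeros((2, 28, 28, 1), dtype='float32')
pairs_b = np.zeros((2, 28, 28, 1), dtype='float32')
labels = np.ones((2, 1), dtype='float32')  # 1 = same pair

siamese.compile(optimizer='adam', loss='binary_crossentropy')
siamese.fit([pairs_a, pairs_b], labels, epochs=1, verbose=0)

score = siamese.predict([pairs_a, pairs_b], verbose=0)
print(score.shape)  # (2, 1) -- one similarity score per pair
```

Because both branches reference the same encoder object, there is exactly one set of encoder weights, and every gradient update affects both branches at once.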

Residual Connections and Skip Connections

ResNet-style skip connections are trivial with the Functional API. Here’s a basic residual block:

def residual_block(x, filters):
    # Main path
    shortcut = x
    
    x = layers.Conv2D(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    
    x = layers.Conv2D(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    
    # Adjust shortcut dimensions if needed
    if shortcut.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, padding='same')(shortcut)
    
    # Add skip connection
    x = layers.Add()([x, shortcut])
    x = layers.Activation('relu')(x)
    
    return x

# Build a mini ResNet
inputs = keras.Input(shape=(32, 32, 3))
x = layers.Conv2D(32, 3, padding='same')(inputs)

# Stack residual blocks
x = residual_block(x, 32)
x = residual_block(x, 64)
x = residual_block(x, 64)

x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(10, activation='softmax')(x)

resnet_model = keras.Model(inputs, outputs)

The Add layer performs element-wise addition, enabling gradient flow through skip connections. This pattern is the foundation of modern deep architectures.
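A quick sanity check makes the shape constraint concrete. In this minimal sketch (arbitrary 8x8x16 input), Add() requires identical shapes on both sides, which is exactly why the block above projects the shortcut with a 1x1 convolution when the filter count changes:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# A single skip connection: same filter count, so no shortcut projection needed
inputs = keras.Input(shape=(8, 8, 16))
y = layers.Conv2D(16, 3, padding='same', activation='relu')(inputs)
y = layers.Add()([y, inputs])  # both sides are (None, 8, 8, 16)
model = keras.Model(inputs, y)

out = model.predict(np.zeros((1, 8, 8, 16), dtype='float32'), verbose=0)
print(out.shape)  # (1, 8, 8, 16) -- the skip connection preserves shape
```

Note that padding='same' is essential here: with the default 'valid' padding, the conv output would shrink to 6x6 and the Add would fail at graph-construction time.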

Best Practices and Common Patterns

Always name your inputs and outputs for clarity, especially in multi-input/output models. Use model.summary() to verify your architecture:

model.summary()

For complex models, visualize the graph with plot_model:

keras.utils.plot_model(
    model,
    to_file='model.png',
    show_shapes=True,
    show_layer_names=True,
    rankdir='TB',  # Top to bottom
    expand_nested=True,
    dpi=96
)

When debugging, check tensor shapes at each step. The Functional API makes this explicit—every intermediate tensor has a defined shape you can inspect.

Use the Functional API when you need:

  • Multiple inputs or outputs
  • Shared layers across branches
  • Non-sequential topology (branching, merging, skip connections)
  • Access to intermediate layer outputs
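The last bullet deserves a concrete illustration. Because every intermediate tensor is a node in the graph, you can build a feature extractor from any layer's output without retraining. A minimal sketch (layer sizes arbitrary; the layer name 'hidden' is just for the lookup):

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(20,))
h = layers.Dense(64, activation='relu', name='hidden')(inputs)
outputs = layers.Dense(10, activation='softmax')(h)
model = keras.Model(inputs, outputs)

# Wrap any intermediate output in a new Model to read activations directly
feature_extractor = keras.Model(model.input, model.get_layer('hidden').output)
features = feature_extractor.predict(np.zeros((3, 20), dtype='float32'), verbose=0)
print(features.shape)  # (3, 64)
```

The sub-model shares weights with the original, so any further training of the full model is immediately reflected in the extracted features.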

Stick with Sequential for simple feedforward networks. Consider model subclassing only when you need dynamic behavior like conditional execution or Python loops in the forward pass—the Functional API handles 95% of use cases with better readability.

The Functional API strikes the perfect balance between Keras’s simplicity and the flexibility needed for modern architectures. Master it, and you’ll be able to implement virtually any neural network architecture you encounter in research papers.
