How to Implement Decision Trees in Python


Key Insights

  • Decision trees split data recursively based on feature values that maximize information gain or minimize Gini impurity, creating an interpretable model structure that mirrors human decision-making
  • Building a decision tree from scratch requires implementing node splitting logic, recursive tree construction, and stopping criteria—understanding this foundation makes you better at tuning scikit-learn’s implementation
  • Overfitting is decision trees’ biggest weakness; control it through max_depth, min_samples_split, and pruning, or use ensemble methods like Random Forests for production systems

Introduction to Decision Trees

Decision trees are supervised learning algorithms that make predictions by learning a series of if-then-else decision rules from training data. Think of them as flowcharts where each internal node represents a test on a feature, each branch represents the outcome of that test, and each leaf node represents a class label or continuous value.

The algorithm works by recursively partitioning the feature space. Starting at the root node with all training data, it selects the best feature and threshold to split the data into two subsets. This process repeats for each subset until a stopping criterion is met—typically when nodes become pure (contain only one class), reach a minimum size, or hit a maximum depth.

Decision trees handle both classification (predicting discrete labels) and regression (predicting continuous values). Their key advantage is interpretability: you can trace exactly why a model made a specific prediction. This makes them invaluable in healthcare, finance, and other domains where model transparency matters.
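To illustrate the regression case, here is a minimal sketch using scikit-learn's DecisionTreeRegressor on a synthetic noisy sine wave. The data is invented purely for illustration. With the default squared-error criterion, each leaf predicts the mean target value of the training samples that reach it, so the fitted model is a piecewise-constant function.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic 1-D regression data: a noisy sine wave (illustrative only)
rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 5, size=(80, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=80)

# Each leaf of a regression tree predicts the mean target of its training samples
reg = DecisionTreeRegressor(max_depth=3, random_state=0)
reg.fit(X, y)

X_new = np.array([[1.0], [2.5], [4.0]])
predictions = reg.predict(X_new)
print(predictions)

# max_depth=3 allows at most 2**3 = 8 leaves, hence at most 8 distinct predicted values
print(reg.get_n_leaves(), len(np.unique(reg.predict(X))))
```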

Understanding the Mathematics Behind Decision Trees

Decision trees need a metric to evaluate how good a split is. The two most common metrics are Gini impurity and entropy.

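Gini impurity measures the probability of misclassifying a randomly chosen sample if it were labeled at random according to the node's class distribution. For a node with C classes, it's calculated as: Gini = 1 - Σ(p_i)², where p_i is the proportion of the node's samples that belong to class i and the sum runs over all C classes. A Gini of 0 means perfect purity (all samples belong to one class).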

Entropy measures the disorder, or uncertainty, of the class distribution. It's calculated as: Entropy = -Σ(p_i × log₂(p_i)), with 0 × log₂(0) taken as 0. Like Gini, it is 0 for a pure node, and lower values indicate purer nodes.
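For two classes, both metrics can be written as a function of a single proportion p. The sketch below (the helper names are illustrative) shows that both are 0 for pure nodes and peak when the classes are evenly mixed: Gini reaches 0.5, while entropy reaches 1 bit.

```python
import numpy as np

def gini_binary(p):
    """Gini impurity of a two-class node where one class has proportion p."""
    return 1 - (p ** 2 + (1 - p) ** 2)

def entropy_binary(p):
    """Entropy (in bits) of a two-class node; 0 * log2(0) is treated as 0."""
    if p in (0.0, 1.0):
        return 0.0
    return -(p * np.log2(p) + (1 - p) * np.log2(1 - p))

# Both curves are symmetric around p = 0.5, where impurity is highest
for p in [0.0, 0.1, 0.25, 0.5, 0.75, 1.0]:
    print(f"p={p:.2f}  Gini={gini_binary(p):.3f}  Entropy={entropy_binary(p):.3f}")
```

In practice the two criteria usually pick similar splits; Gini avoids the logarithm, which makes it slightly cheaper to compute.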

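Information gain is the reduction in entropy achieved by a split: the parent node's entropy minus the weighted average entropy of the child nodes, where each child is weighted by its share of the parent's samples. The algorithm selects the split that maximizes information gain (when using Gini instead, it selects the split that minimizes the weighted Gini impurity of the children).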

Let’s calculate these manually:

import numpy as np

def gini_impurity(labels):
    """Calculate Gini impurity for a list of labels."""
    if len(labels) == 0:
        return 0
    
    unique, counts = np.unique(labels, return_counts=True)
    probabilities = counts / len(labels)
    return 1 - np.sum(probabilities ** 2)

def entropy(labels):
    """Calculate entropy for a list of labels."""
    if len(labels) == 0:
        return 0
    
    unique, counts = np.unique(labels, return_counts=True)
    probabilities = counts / len(labels)
    # np.unique only returns classes that are present, so every probability is > 0
    return -np.sum(probabilities * np.log2(probabilities))

# Example dataset: [Yes, Yes, No, No, No]
labels = np.array([1, 1, 0, 0, 0])

print(f"Gini Impurity: {gini_impurity(labels):.4f}")  # 0.4800
print(f"Entropy: {entropy(labels):.4f}")              # 0.9710

# After split: [Yes, Yes] and [No, No, No]
left_labels = np.array([1, 1])
right_labels = np.array([0, 0, 0])

weighted_gini = (len(left_labels) * gini_impurity(left_labels) + 
                 len(right_labels) * gini_impurity(right_labels)) / len(labels)
print(f"Weighted Gini after split: {weighted_gini:.4f}")  # 0.0000 - perfect split!
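The calculation above stops at weighted Gini. Information gain, as defined earlier, follows the same pattern with entropy. This sketch redefines entropy so it runs on its own; `information_gain` is an illustrative helper, and it compares the perfect split above with a deliberately poor one.

```python
import numpy as np

def entropy(labels):
    """Entropy in bits; np.unique only returns observed classes, so every p > 0."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / len(labels)
    return float(-np.sum(p * np.log2(p)))

def information_gain(parent, left, right):
    """Parent entropy minus the size-weighted entropy of the two children."""
    n = len(parent)
    child = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - child

labels = np.array([1, 1, 0, 0, 0])
perfect = information_gain(labels, labels[:2], labels[2:])            # [1, 1] | [0, 0, 0]
mixed = information_gain(labels, labels[[0, 2]], labels[[1, 3, 4]])   # [1, 0] | [1, 0, 0]
print(f"Perfect split gain: {perfect:.4f}")  # 0.9710
print(f"Mixed split gain:   {mixed:.4f}")
```

A perfect split recovers the parent's entire entropy as gain, while a split that leaves both children mixed gains almost nothing.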

Building a Decision Tree from Scratch

Understanding the internals makes you better at using libraries. Here’s a minimal decision tree classifier:

import numpy as np

# This snippet reuses the entropy() function defined in the previous example.

class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature      # Index of feature to split on
        self.threshold = threshold  # Threshold value for split
        self.left = left           # Left subtree
        self.right = right         # Right subtree
        self.value = value         # Class value if leaf node

class DecisionTree:
    def __init__(self, max_depth=10, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root = None
    
    def fit(self, X, y):
        self.root = self._build_tree(X, y, depth=0)
    
    def _build_tree(self, X, y, depth):
        n_samples, n_features = X.shape
        n_classes = len(np.unique(y))
        
        # Stopping criteria
        if (depth >= self.max_depth or 
            n_classes == 1 or 
            n_samples < self.min_samples_split):
            leaf_value = self._most_common_label(y)
            return Node(value=leaf_value)
        
        # Find best split; fall back to a leaf if no split improves purity
        best_feature, best_threshold = self._best_split(X, y, n_features)
        if best_feature is None:
            return Node(value=self._most_common_label(y))
        
        # Create child nodes
        left_idxs = X[:, best_feature] <= best_threshold
        right_idxs = ~left_idxs
        
        left = self._build_tree(X[left_idxs], y[left_idxs], depth + 1)
        right = self._build_tree(X[right_idxs], y[right_idxs], depth + 1)
        
        return Node(best_feature, best_threshold, left, right)
    
    def _best_split(self, X, y, n_features):
        # Only accept splits with strictly positive gain. A zero-gain "split" can
        # send every sample to one side, leaving an empty child that crashes
        # _most_common_label. Returns (None, None) when no useful split exists.
        best_gain = 0
        split_idx, split_threshold = None, None
        
        for feature in range(n_features):
            thresholds = np.unique(X[:, feature])
            
            for threshold in thresholds:
                gain = self._information_gain(X[:, feature], y, threshold)
                
                if gain > best_gain:
                    best_gain = gain
                    split_idx = feature
                    split_threshold = threshold
        
        return split_idx, split_threshold
    
    def _information_gain(self, X_column, y, threshold):
        parent_entropy = entropy(y)
        
        left_idxs = X_column <= threshold
        right_idxs = ~left_idxs
        
        if np.sum(left_idxs) == 0 or np.sum(right_idxs) == 0:
            return 0
        
        n = len(y)
        n_left, n_right = np.sum(left_idxs), np.sum(right_idxs)
        e_left, e_right = entropy(y[left_idxs]), entropy(y[right_idxs])
        child_entropy = (n_left / n) * e_left + (n_right / n) * e_right
        
        return parent_entropy - child_entropy
    
    def _most_common_label(self, y):
        # np.bincount requires non-negative integer class labels (e.g. 0, 1, 2)
        return np.bincount(y).argmax()
    
    def predict(self, X):
        return np.array([self._traverse_tree(x, self.root) for x in X])
    
    def _traverse_tree(self, x, node):
        if node.value is not None:
            return node.value
        
        if x[node.feature] <= node.threshold:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)
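The `_best_split` method above tries every unique feature value as a threshold, including the largest one, which always sends all samples left. A common refinement is to test only the midpoints between consecutive sorted values, so every candidate leaves samples on both sides. The sketch below applies that idea to a single feature using weighted Gini; `gini` and `best_threshold` are illustrative names, and the toy data is made up.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / len(labels)
    return 1.0 - np.sum(p ** 2)

def best_threshold(x, y):
    """Return (threshold, weighted_gini) for the best split x <= threshold on one feature.

    Candidates are midpoints between consecutive distinct values, so each one leaves
    at least one sample on each side. Returns (None, None) if x is constant.
    """
    values = np.unique(x)  # sorted distinct values
    if len(values) < 2:
        return None, None
    candidates = (values[:-1] + values[1:]) / 2
    best_t, best_score = None, np.inf
    for t in candidates:
        left, right = y[x <= t], y[x > t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

x = np.array([2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([1, 1, 0, 0, 0])
t, score = best_threshold(x, y)
print(f"threshold={t}, weighted Gini={score}")  # threshold=6.5, weighted Gini=0.0
```

Midpoints also tend to generalize a little better than raw values, because the decision boundary sits halfway between the observed samples instead of on top of one of them.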

Using Scikit-learn’s DecisionTreeClassifier

For production use, scikit-learn provides an optimized implementation based on the CART algorithm:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# Load data
iris = load_iris()
X, y = iris.data, iris.target

# Split dataset
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train decision tree
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# Make predictions
y_pred = clf.predict(X_test)
accuracy = clf.score(X_test, y_test)
print(f"Accuracy: {accuracy:.3f}")

# Visualize the tree
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names, 
          class_names=iris.target_names, filled=True)
plt.savefig('decision_tree.png', dpi=300, bbox_inches='tight')
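If you want the learned rules without a plot, for logging or code review for example, scikit-learn's `export_text` prints the same tree as indented if/else rules. This sketch retrains the depth-3 iris model with the same split so it runs on its own.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

# export_text renders each split as "feature <= threshold" and each leaf as "class: k"
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```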

Hyperparameter Tuning and Avoiding Overfitting

Decision trees easily overfit: left unconstrained, a tree keeps splitting until its leaves are pure, effectively memorizing the training data. Control this through hyperparameters:

  • max_depth: Limits tree depth (start with 3-10)
  • min_samples_split: Minimum samples required to split a node
  • min_samples_leaf: Minimum samples required at leaf nodes
  • max_features: Number of features considered when searching for the best split

from sklearn.model_selection import cross_val_score

# Compare different max_depth values
depths = [1, 3, 5, 10, None]
train_scores = []
val_scores = []

for depth in depths:
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
    
    # Training score
    clf.fit(X_train, y_train)
    train_scores.append(clf.score(X_train, y_train))
    
    # Validation score via cross-validation
    cv_scores = cross_val_score(clf, X_train, y_train, cv=5)
    val_scores.append(cv_scores.mean())

# Visualize overfitting
plt.figure(figsize=(10, 6))
plt.plot([str(d) for d in depths], train_scores, label='Training Score', marker='o')
plt.plot([str(d) for d in depths], val_scores, label='Validation Score', marker='o')
plt.xlabel('Max Depth')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training vs Validation Accuracy by Max Depth')
plt.grid(True)
plt.show()

Use GridSearchCV for systematic tuning:

from sklearn.model_selection import GridSearchCV

param_grid = {
    'max_depth': [3, 5, 7, 10],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}

grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy'
)

grid_search.fit(X_train, y_train)
print(f"Best parameters: {grid_search.best_params_}")
print(f"Best cross-validation score: {grid_search.best_score_:.3f}")
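The hyperparameters above stop a tree from growing. Pruning takes the opposite approach: grow the full tree, then cut back branches that add little. scikit-learn supports minimal cost-complexity pruning through the `ccp_alpha` parameter, and `cost_complexity_pruning_path` lists the candidate alpha values. The sketch below picks an alpha by 5-fold cross-validation on the training set. It rebuilds the iris split so it runs standalone, and the selection strategy is one reasonable choice, not the only one.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Effective alphas at which subtrees of the fully grown tree get pruned away
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]  # the largest alpha prunes the tree to a single node

# Score each candidate alpha with 5-fold cross-validation on the training data
cv_means = [
    cross_val_score(DecisionTreeClassifier(random_state=42, ccp_alpha=a),
                    X_train, y_train, cv=5).mean()
    for a in ccp_alphas
]
best_alpha = ccp_alphas[int(np.argmax(cv_means))]

pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=best_alpha).fit(X_train, y_train)
print(f"Best ccp_alpha: {best_alpha:.4f}, leaves: {pruned.get_n_leaves()}")
print(f"Test accuracy: {pruned.score(X_test, y_test):.3f}")
```

Larger alphas prune more aggressively, so plotting `cv_means` against `ccp_alphas` shows the same underfitting/overfitting trade-off as the max_depth curve.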

Real-World Application: End-to-End Project

Here’s a complete classification workflow:

import pandas as pd
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

# GridSearchCV (refit=True by default) has already refit the best model on all of X_train
best_clf = grid_search.best_estimator_
y_pred = best_clf.predict(X_test)

# Feature importance
feature_importance = pd.DataFrame({
    'feature': iris.feature_names,
    'importance': best_clf.feature_importances_
}).sort_values('importance', ascending=False)

print("Feature Importance:")
print(feature_importance)

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=iris.target_names,
            yticklabels=iris.target_names)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix')
plt.show()

# Classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred, 
                          target_names=iris.target_names))

Conclusion and Next Steps

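Decision trees excel when you need interpretable models. They need no feature scaling, handle non-linear relationships, and provide clear feature importance rankings. The algorithm can work with categorical or mixed data in principle, but scikit-learn's DecisionTreeClassifier expects numeric input, so encode categorical features (for example with one-hot or ordinal encoding) first.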

However, they have limitations: high variance (small changes in the training data can produce a very different tree), a tendency to overfit, and a bias toward features with many distinct values, which offer more candidate split points. For production systems, prefer ensemble methods:

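  • Random Forests: Build many trees, each on a bootstrap sample of the rows with a random subset of features considered at each split, and average their predictions
  • Gradient Boosting (XGBoost, LightGBM): Build trees sequentially, each one fitted to correct the errors of the ensemble so far
  • Extra Trees: Similar to Random Forests, but split thresholds are drawn at random for each candidate feature and the best of these random splits is used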
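A quick way to see whether an ensemble helps on your data is to cross-validate it alongside a single tree. This sketch uses scikit-learn's RandomForestClassifier and ExtraTreesClassifier on iris. XGBoost and LightGBM are separate packages, so they are left out to keep the example self-contained. On a small, easy dataset like iris the gap may be small; the fold-to-fold standard deviation is often the more telling number.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

models = {
    "Decision tree": DecisionTreeClassifier(random_state=42),
    "Random forest": RandomForestClassifier(n_estimators=100, random_state=42),
    "Extra trees": ExtraTreesClassifier(n_estimators=100, random_state=42),
}

# Same 5-fold CV for every model; std shows how much accuracy varies across folds
results = {}
for name, model in models.items():
    scores = cross_val_score(model, X, y, cv=5)
    results[name] = scores
    print(f"{name:14s} mean={scores.mean():.3f} std={scores.std():.3f}")
```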

Start with decision trees to understand your data and establish baselines. Then graduate to ensembles for better performance while retaining some interpretability through feature importance and SHAP values.

The code examples in this article provide a foundation for implementing decision trees in any project. Experiment with different hyperparameters, visualize your trees, and always validate on held-out data to ensure your model generalizes.
