How to Implement Decision Trees in Python
Key Insights
- Decision trees split data recursively based on feature values that maximize information gain or minimize Gini impurity, creating an interpretable model structure that mirrors human decision-making
- Building a decision tree from scratch requires implementing node splitting logic, recursive tree construction, and stopping criteria—understanding this foundation makes you better at tuning scikit-learn’s implementation
- Overfitting is decision trees’ biggest weakness; control it through max_depth, min_samples_split, and pruning, or use ensemble methods like Random Forests for production systems
Introduction to Decision Trees
Decision trees are supervised learning algorithms that make predictions by learning a series of if-then-else decision rules from training data. Think of them as flowcharts where each internal node represents a test on a feature, each branch represents the outcome of that test, and each leaf node represents a class label or continuous value.
The algorithm works by recursively partitioning the feature space. Starting at the root node with all training data, it selects the best feature and threshold to split the data into two subsets. This process repeats for each subset until a stopping criterion is met—typically when nodes become pure (contain only one class), reach a minimum size, or hit a maximum depth.
Decision trees handle both classification (predicting discrete labels) and regression (predicting continuous values). Their key advantage is interpretability: you can trace exactly why a model made a specific prediction. This makes them invaluable in healthcare, finance, and other domains where model transparency matters.
Understanding the Mathematics Behind Decision Trees
Decision trees need a metric to evaluate how good a split is. The two most common metrics are Gini impurity and entropy.
Gini impurity measures the probability of incorrectly classifying a randomly chosen element. For a node containing C classes, it's calculated as: Gini = 1 - Σ(p_i)², where p_i is the proportion of samples belonging to class i. A Gini of 0 means perfect purity (all samples belong to one class).
Entropy measures disorder or uncertainty. It’s calculated as: Entropy = -Σ(p_i × log₂(p_i)). Like Gini, lower values indicate purer nodes.
Information gain is the reduction in entropy achieved by a split: the parent node's entropy minus the size-weighted average entropy of the child nodes. The algorithm selects the split that maximizes information gain.
Let’s calculate these manually:
import numpy as np

def gini_impurity(labels):
    """Calculate Gini impurity for a list of labels."""
    if len(labels) == 0:
        return 0
    _, counts = np.unique(labels, return_counts=True)
    probabilities = counts / len(labels)
    return 1 - np.sum(probabilities ** 2)

def entropy(labels):
    """Calculate entropy for a list of labels."""
    if len(labels) == 0:
        return 0
    _, counts = np.unique(labels, return_counts=True)
    probabilities = counts / len(labels)
    # The small constant avoids log2(0) for zero-probability classes
    return -np.sum(probabilities * np.log2(probabilities + 1e-10))

# Example dataset: [Yes, Yes, No, No, No]
labels = np.array([1, 1, 0, 0, 0])
print(f"Gini Impurity: {gini_impurity(labels):.4f}")  # 0.4800
print(f"Entropy: {entropy(labels):.4f}")              # 0.9710

# After split: [Yes, Yes] and [No, No, No]
left_labels = np.array([1, 1])
right_labels = np.array([0, 0, 0])
weighted_gini = (len(left_labels) * gini_impurity(left_labels) +
                 len(right_labels) * gini_impurity(right_labels)) / len(labels)
print(f"Weighted Gini after split: {weighted_gini:.4f}")  # 0.0000 - a perfect split!
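The same split can also be scored with information gain, tying the entropy formula to split selection. This is a small self-contained sketch that redefines the entropy helper from above:

```python
import numpy as np

def entropy(labels):
    """Entropy of a label array, in bits."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    probabilities = counts / len(labels)
    # Small constant avoids log2(0) for zero-probability classes
    return -np.sum(probabilities * np.log2(probabilities + 1e-10))

# Same dataset and split as above: [1, 1, 0, 0, 0] -> [1, 1] and [0, 0, 0]
labels = np.array([1, 1, 0, 0, 0])
left, right = np.array([1, 1]), np.array([0, 0, 0])

# Information gain = parent entropy - size-weighted average child entropy
weighted_child = (len(left) * entropy(left) +
                  len(right) * entropy(right)) / len(labels)
info_gain = entropy(labels) - weighted_child
print(f"Information gain: {info_gain:.4f}")  # 0.9710 - the split removes all uncertainty
```

Because both children are pure, the weighted child entropy is zero and the gain equals the parent entropy: the best possible split.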
Building a Decision Tree from Scratch
Understanding the internals makes you better at using libraries. Here’s a minimal decision tree classifier:
import numpy as np

class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature      # Index of feature to split on
        self.threshold = threshold  # Threshold value for split
        self.left = left            # Left subtree
        self.right = right          # Right subtree
        self.value = value          # Class value if leaf node

class DecisionTree:
    def __init__(self, max_depth=10, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root = None

    def fit(self, X, y):
        self.root = self._build_tree(X, y, depth=0)

    def _build_tree(self, X, y, depth):
        n_samples, n_features = X.shape
        n_classes = len(np.unique(y))

        # Stopping criteria
        if (depth >= self.max_depth or
                n_classes == 1 or
                n_samples < self.min_samples_split):
            return Node(value=self._most_common_label(y))

        # Find best split
        best_feature, best_threshold = self._best_split(X, y, n_features)

        left_idxs = X[:, best_feature] <= best_threshold
        right_idxs = ~left_idxs

        # Guard against degenerate splits that leave one side empty
        # (e.g. the best threshold is the feature's maximum value)
        if left_idxs.all() or right_idxs.all():
            return Node(value=self._most_common_label(y))

        # Create child nodes
        left = self._build_tree(X[left_idxs], y[left_idxs], depth + 1)
        right = self._build_tree(X[right_idxs], y[right_idxs], depth + 1)
        return Node(best_feature, best_threshold, left, right)

    def _best_split(self, X, y, n_features):
        best_gain = -1
        split_idx, split_threshold = None, None
        for feature in range(n_features):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                gain = self._information_gain(X[:, feature], y, threshold)
                if gain > best_gain:
                    best_gain = gain
                    split_idx = feature
                    split_threshold = threshold
        return split_idx, split_threshold

    def _information_gain(self, X_column, y, threshold):
        # Uses the entropy() function defined in the previous section
        parent_entropy = entropy(y)
        left_idxs = X_column <= threshold
        right_idxs = ~left_idxs
        if np.sum(left_idxs) == 0 or np.sum(right_idxs) == 0:
            return 0
        n = len(y)
        n_left, n_right = np.sum(left_idxs), np.sum(right_idxs)
        e_left, e_right = entropy(y[left_idxs]), entropy(y[right_idxs])
        child_entropy = (n_left / n) * e_left + (n_right / n) * e_right
        return parent_entropy - child_entropy

    def _most_common_label(self, y):
        # Assumes y holds non-negative integer class labels
        return np.bincount(y).argmax()

    def predict(self, X):
        return np.array([self._traverse_tree(x, self.root) for x in X])

    def _traverse_tree(self, x, node):
        if node.value is not None:
            return node.value
        if x[node.feature] <= node.threshold:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)
Using Scikit-learn’s DecisionTreeClassifier
For production use, scikit-learn provides an optimized implementation:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# Load data
iris = load_iris()
X, y = iris.data, iris.target

# Split dataset
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Train decision tree
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# Make predictions
y_pred = clf.predict(X_test)
accuracy = clf.score(X_test, y_test)
print(f"Accuracy: {accuracy:.3f}")

# Visualize the tree
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names,
          class_names=iris.target_names, filled=True)
plt.savefig('decision_tree.png', dpi=300, bbox_inches='tight')
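If a rendered image is more than you need, the learned rules can also be dumped as indented if/else text with scikit-learn's export_text. A self-contained sketch on the same iris data:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(iris.data, iris.target)

# Print the tree as nested threshold rules, one line per node
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

This text view is handy for logging a model's structure or pasting it into a report where images are inconvenient.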
Hyperparameter Tuning and Avoiding Overfitting
Decision trees easily overfit by memorizing training data. Control this through hyperparameters:
- max_depth: Limits tree depth (start with 3-10)
- min_samples_split: Minimum samples required to split a node
- min_samples_leaf: Minimum samples required at leaf nodes
- max_features: Number of features to consider for best split
from sklearn.model_selection import cross_val_score

# Compare different max_depth values
depths = [1, 3, 5, 10, None]
train_scores = []
val_scores = []

for depth in depths:
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)

    # Training score
    clf.fit(X_train, y_train)
    train_scores.append(clf.score(X_train, y_train))

    # Validation score via cross-validation
    cv_scores = cross_val_score(clf, X_train, y_train, cv=5)
    val_scores.append(cv_scores.mean())

# Visualize overfitting
plt.figure(figsize=(10, 6))
plt.plot([str(d) for d in depths], train_scores, label='Training Score', marker='o')
plt.plot([str(d) for d in depths], val_scores, label='Validation Score', marker='o')
plt.xlabel('Max Depth')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training vs Validation Accuracy by Max Depth')
plt.grid(True)
plt.show()
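Besides the pre-pruning parameters above, scikit-learn supports post-pruning via cost-complexity pruning (the ccp_alpha parameter), which the Key Insights mention but which is easy to miss. A minimal self-contained sketch, choosing alpha by cross-validation on an iris split:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Compute the effective alphas along the pruning path of a full tree
full_tree = DecisionTreeClassifier(random_state=42)
path = full_tree.cost_complexity_pruning_path(X_train, y_train)

# Cross-validate one pruned tree per alpha and keep the best
best_alpha, best_score = 0.0, 0.0
for alpha in path.ccp_alphas:
    alpha = max(float(alpha), 0.0)  # guard against tiny negative float error
    pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    score = cross_val_score(pruned, X_train, y_train, cv=5).mean()
    if score > best_score:
        best_alpha, best_score = alpha, score

print(f"Best ccp_alpha: {best_alpha:.4f}, CV accuracy: {best_score:.3f}")
```

Larger alphas prune more aggressively (the last alpha collapses the tree to its root), so scanning the path trades training fit for generalization explicitly rather than guessing depth limits.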
Use GridSearchCV for systematic tuning:
from sklearn.model_selection import GridSearchCV

param_grid = {
    'max_depth': [3, 5, 7, 10],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}

grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy'
)
grid_search.fit(X_train, y_train)

print(f"Best parameters: {grid_search.best_params_}")
print(f"Best cross-validation score: {grid_search.best_score_:.3f}")
Real-World Application: End-to-End Project
Here’s a complete classification workflow:
import pandas as pd
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix

# Train final model with best parameters
best_clf = grid_search.best_estimator_
best_clf.fit(X_train, y_train)
y_pred = best_clf.predict(X_test)

# Feature importance
feature_importance = pd.DataFrame({
    'feature': iris.feature_names,
    'importance': best_clf.feature_importances_
}).sort_values('importance', ascending=False)
print("Feature Importance:")
print(feature_importance)

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=iris.target_names,
            yticklabels=iris.target_names)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix')
plt.show()

# Classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred,
                            target_names=iris.target_names))
Conclusion and Next Steps
Decision trees excel when you need interpretable models and have categorical or mixed data types. They require minimal data preprocessing, handle non-linear relationships, and provide clear feature importance rankings.
However, they have limitations: high variance (small data changes cause different trees), tendency to overfit, and bias toward features with more levels. For production systems, prefer ensemble methods:
- Random Forests: Build multiple trees on random subsets and average predictions
- Gradient Boosting (XGBoost, LightGBM): Build trees sequentially, each correcting previous errors
- Extra Trees: Similar to Random Forests but with random thresholds
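The ensemble payoff is easy to see because scikit-learn's estimators share one interface: a Random Forest drops in with almost no code changes. A self-contained sketch on iris (a small dataset, so the gap between a single tree and the forest may be modest):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()

# Cross-validated accuracy: one tree vs. a 100-tree forest
tree_score = cross_val_score(
    DecisionTreeClassifier(random_state=42),
    iris.data, iris.target, cv=5
).mean()
forest_score = cross_val_score(
    RandomForestClassifier(n_estimators=100, random_state=42),
    iris.data, iris.target, cv=5
).mean()

print(f"Single tree CV accuracy:   {tree_score:.3f}")
print(f"Random forest CV accuracy: {forest_score:.3f}")
```

On larger, noisier datasets the averaging over many decorrelated trees typically reduces the variance that makes a single tree unstable.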
Start with decision trees to understand your data and establish baselines. Then graduate to ensembles for better performance while retaining some interpretability through feature importance and SHAP values.
The code examples in this article provide a foundation for implementing decision trees in any project. Experiment with different hyperparameters, visualize your trees, and always validate on held-out data to ensure your model generalizes.