Decision Trees: Complete Guide with Examples
Key Insights
- Decision trees split data recursively based on feature values, using metrics like Gini impurity or entropy to determine the best splits at each node, making them highly interpretable models that require no feature scaling.
- Overfitting is the primary weakness of decision trees—they easily memorize training data—but can be controlled through pruning, limiting tree depth, and requiring minimum samples per split.
- While single decision trees are unstable and sensitive to small data changes, they form the foundation of powerful ensemble methods like Random Forests and Gradient Boosting that dominate many real-world applications.
Introduction to Decision Trees
Decision trees are supervised learning algorithms that work for both classification and regression tasks. They make predictions by learning simple decision rules from data features, creating a tree-like model of decisions.
The structure consists of:
- Root node: The starting point representing the entire dataset
- Internal nodes: Decision points that split data based on feature conditions
- Branches: Outcomes of decisions connecting nodes
- Leaf nodes: Terminal nodes containing final predictions
Use decision trees when you need interpretability, have mixed data types, or want a model that requires minimal data preprocessing. They excel in scenarios where decision logic needs explanation to stakeholders.
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
# Simple tennis dataset
data = {
    'Outlook': ['Sunny', 'Sunny', 'Overcast', 'Rain', 'Rain', 'Rain', 'Overcast', 'Sunny'],
    'Temperature': ['Hot', 'Hot', 'Hot', 'Mild', 'Cool', 'Cool', 'Cool', 'Mild'],
    'Humidity': ['High', 'High', 'High', 'High', 'Normal', 'Normal', 'Normal', 'High'],
    'Wind': ['Weak', 'Strong', 'Weak', 'Weak', 'Weak', 'Strong', 'Strong', 'Weak'],
    'PlayTennis': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No']
}
df = pd.DataFrame(data)
X = pd.get_dummies(df.drop('PlayTennis', axis=1))
y = (df['PlayTennis'] == 'Yes').astype(int)
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X, y)
plt.figure(figsize=(12, 8))
plot_tree(tree, feature_names=X.columns, class_names=['No', 'Yes'], filled=True)
plt.show()
How Decision Trees Work
Decision trees split data by asking yes/no questions about features. The algorithm selects splits that best separate classes or reduce variance in regression targets.
Entropy measures impurity or disorder in a dataset:
Entropy = -Σ(p_i * log2(p_i))
Where p_i is the proportion of samples belonging to class i. Entropy is 0 for pure nodes and maximum (1 for binary classification) for evenly split nodes.
Gini Impurity is an alternative metric:
Gini = 1 - Σ(p_i²)
Information Gain measures the reduction in entropy after a split:
Information Gain = Entropy(parent) - Weighted_Average(Entropy(children))
Here’s a manual calculation:
import numpy as np
def entropy(y):
    """Calculate entropy of a label array"""
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    return -np.sum(probabilities * np.log2(probabilities + 1e-10))

def gini_impurity(y):
    """Calculate Gini impurity"""
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    return 1 - np.sum(probabilities ** 2)

def information_gain(parent, left_child, right_child, metric='entropy'):
    """Calculate information gain (or Gini gain) from a split"""
    func = entropy if metric == 'entropy' else gini_impurity
    n = len(parent)
    n_left, n_right = len(left_child), len(right_child)
    parent_impurity = func(parent)
    children_impurity = (n_left / n * func(left_child) +
                         n_right / n * func(right_child))
    return parent_impurity - children_impurity
# Example calculation
parent = np.array([1, 1, 1, 0, 0, 0, 0, 0])
left = np.array([1, 1, 1])
right = np.array([0, 0, 0, 0, 0])
print(f"Parent entropy: {entropy(parent):.4f}")
print(f"Information gain: {information_gain(parent, left, right):.4f}")
print(f"Parent Gini: {gini_impurity(parent):.4f}")
print(f"Gini gain: {information_gain(parent, left, right, 'gini'):.4f}")
Building Decision Trees from Scratch
Understanding the internals helps you debug issues and optimize hyperparameters. Here’s a minimal implementation:
import numpy as np
class SimpleDecisionTree:
    def __init__(self, max_depth=5, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None

    def _gini(self, y):
        """Calculate Gini impurity"""
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return 1 - np.sum(probabilities ** 2)

    def _split(self, X, y, feature, threshold):
        """Split dataset based on feature and threshold"""
        left_mask = X[:, feature] <= threshold
        right_mask = ~left_mask
        return (X[left_mask], y[left_mask],
                X[right_mask], y[right_mask])

    def _best_split(self, X, y):
        """Find the best feature and threshold to split on"""
        best_gain = -1
        best_feature = None
        best_threshold = None
        n_features = X.shape[1]
        parent_gini = self._gini(y)
        for feature in range(n_features):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                X_left, y_left, X_right, y_right = self._split(
                    X, y, feature, threshold)
                if len(y_left) == 0 or len(y_right) == 0:
                    continue
                n = len(y)
                gini_left = self._gini(y_left)
                gini_right = self._gini(y_right)
                weighted_gini = (len(y_left) / n * gini_left +
                                 len(y_right) / n * gini_right)
                gain = parent_gini - weighted_gini
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold
        return best_feature, best_threshold

    def _build_tree(self, X, y, depth=0):
        """Recursively build the decision tree"""
        n_samples = len(y)
        n_classes = len(np.unique(y))
        # Stopping criteria
        if (depth >= self.max_depth or
                n_samples < self.min_samples_split or
                n_classes == 1):
            leaf_value = np.argmax(np.bincount(y))
            return {'leaf': True, 'value': leaf_value}
        # Find best split
        feature, threshold = self._best_split(X, y)
        if feature is None:
            leaf_value = np.argmax(np.bincount(y))
            return {'leaf': True, 'value': leaf_value}
        # Split and recurse
        X_left, y_left, X_right, y_right = self._split(
            X, y, feature, threshold)
        left_subtree = self._build_tree(X_left, y_left, depth + 1)
        right_subtree = self._build_tree(X_right, y_right, depth + 1)
        return {
            'leaf': False,
            'feature': feature,
            'threshold': threshold,
            'left': left_subtree,
            'right': right_subtree
        }

    def fit(self, X, y):
        """Build decision tree from training data"""
        self.tree = self._build_tree(X, y)
        return self

    def _predict_sample(self, x, tree):
        """Predict class for a single sample"""
        if tree['leaf']:
            return tree['value']
        if x[tree['feature']] <= tree['threshold']:
            return self._predict_sample(x, tree['left'])
        return self._predict_sample(x, tree['right'])

    def predict(self, X):
        """Predict classes for samples"""
        return np.array([self._predict_sample(x, self.tree) for x in X])

# Test the implementation
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)
tree = SimpleDecisionTree(max_depth=3)
tree.fit(X_train, y_train)
predictions = tree.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, predictions):.3f}")
Using Scikit-learn for Decision Trees
For production use, leverage scikit-learn’s optimized implementation with more features and better performance.
Classification Example:
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)

# Train classifier
clf = DecisionTreeClassifier(
    criterion='gini',  # or 'entropy'
    max_depth=4,
    min_samples_split=5,
    min_samples_leaf=2,
    random_state=42
)
clf.fit(X_train, y_train)

# Evaluate
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred, target_names=iris.target_names))

# Visualize
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names,
          class_names=iris.target_names, filled=True)
plt.show()
Regression Example:
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np

# Load housing data
housing = fetch_california_housing()
X_train, X_test, y_train, y_test = train_test_split(
    housing.data, housing.target, test_size=0.2, random_state=42)

# Train regressor
reg = DecisionTreeRegressor(
    max_depth=5,
    min_samples_split=20,
    min_samples_leaf=10,
    random_state=42
)
reg.fit(X_train, y_train)

# Evaluate
y_pred = reg.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"MSE: {mse:.3f}")
print(f"RMSE: {np.sqrt(mse):.3f}")
print(f"R² Score: {r2:.3f}")
Preventing Overfitting
Decision trees overfit easily by creating overly complex trees that memorize training data. Control this through:
- Pre-pruning: Limit growth during training
- Post-pruning: Build full tree, then remove branches
- Hyperparameter tuning: Find optimal complexity
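Post-pruning can be sketched with scikit-learn's cost-complexity pruning API (`cost_complexity_pruning_path` and the `ccp_alpha` parameter). Picking alpha by test-set accuracy, as below, is only for illustration; in practice you would cross-validate over `ccp_alpha`:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_train, X_test, y_train, y_test = train_test_split(
    *load_breast_cancer(return_X_y=True), test_size=0.2, random_state=42)

# Effective alphas at which subtrees of the full tree get pruned away
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train)

# Train one tree per alpha; larger alpha -> smaller tree
best_alpha, best_score = 0.0, 0.0
for alpha in path.ccp_alphas[:-1]:  # last alpha prunes down to a single node
    clf = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    clf.fit(X_train, y_train)
    score = clf.score(X_test, y_test)
    if score > best_score:
        best_alpha, best_score = alpha, score
print(f"Best alpha: {best_alpha:.5f}, test accuracy: {best_score:.3f}")
```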
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split, GridSearchCV, validation_curve
from sklearn.datasets import load_breast_cancer
import matplotlib.pyplot as plt

# Load data
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, test_size=0.2, random_state=42)

# Demonstrate overfitting
overfit_tree = DecisionTreeClassifier(random_state=42)  # No limits
overfit_tree.fit(X_train, y_train)
tuned_tree = DecisionTreeClassifier(max_depth=5, min_samples_leaf=10,
                                    random_state=42)
tuned_tree.fit(X_train, y_train)
print(f"Overfitted - Train: {overfit_tree.score(X_train, y_train):.3f}, "
      f"Test: {overfit_tree.score(X_test, y_test):.3f}")
print(f"Tuned - Train: {tuned_tree.score(X_train, y_train):.3f}, "
      f"Test: {tuned_tree.score(X_test, y_test):.3f}")

# Grid search for optimal hyperparameters
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 5, 10, 20],
    'min_samples_leaf': [1, 2, 5, 10],
    'criterion': ['gini', 'entropy']
}
grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1
)
grid_search.fit(X_train, y_train)
print(f"\nBest parameters: {grid_search.best_params_}")
print(f"Best CV score: {grid_search.best_score_:.3f}")
print(f"Test score: {grid_search.score(X_test, y_test):.3f}")

# Validation curve for max_depth
train_scores, val_scores = validation_curve(
    DecisionTreeClassifier(random_state=42),
    X_train, y_train,
    param_name='max_depth',
    param_range=range(1, 20),
    cv=5
)
plt.figure(figsize=(10, 6))
plt.plot(range(1, 20), train_scores.mean(axis=1), label='Training score')
plt.plot(range(1, 20), val_scores.mean(axis=1), label='Validation score')
plt.xlabel('Max Depth')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Validation Curve for max_depth')
plt.show()
Advantages, Limitations, and Real-World Applications
Advantages:
- Highly interpretable and easy to visualize
- No feature scaling or normalization required
- Handles both numerical and categorical data
- Captures non-linear relationships
- Fast prediction time
Limitations:
- Prone to overfitting without proper constraints
- Unstable—small data changes cause different trees
- Biased toward features with many levels
- Cannot extrapolate beyond training data range
- Greedy algorithm may miss optimal global tree
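The extrapolation limitation is easy to demonstrate with synthetic data: a tree fit on a linear trend predicts a flat value for any input beyond the training range, because every out-of-range sample lands in the same edge leaf:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Train on x in [0, 10] with a noisy linear target y = 2x
rng = np.random.default_rng(0)
X = np.linspace(0, 10, 200).reshape(-1, 1)
y = 2 * X.ravel() + rng.normal(0, 0.1, 200)

reg = DecisionTreeRegressor(max_depth=5).fit(X, y)

# Inside the training range the fit tracks the trend (close to 10 at x=5)...
print(reg.predict([[5.0]]))
# ...but beyond it, predictions stay flat at the rightmost leaf's mean,
# nowhere near the true values 100 and 200
print(reg.predict([[50.0]]), reg.predict([[100.0]]))
```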
When to Use:
- Need model interpretability for stakeholders
- Mixed data types without extensive preprocessing
- Quick baseline model before trying complex algorithms
- Feature interactions are important
Ensemble Extensions:
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_wine
import pandas as pd

# Load data
wine = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.2, random_state=42)

# Single decision tree
dt = DecisionTreeClassifier(max_depth=5, random_state=42)
dt.fit(X_train, y_train)
dt_score = dt.score(X_test, y_test)

# Random Forest (ensemble of trees)
rf = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)
rf.fit(X_train, y_train)
rf_score = rf.score(X_test, y_test)

print(f"Decision Tree Accuracy: {dt_score:.3f}")
print(f"Random Forest Accuracy: {rf_score:.3f}")
print(f"Improvement: {(rf_score - dt_score):.3f}")

# Feature importance comparison
importance_df = pd.DataFrame({
    'feature': wine.feature_names,
    'dt_importance': dt.feature_importances_,
    'rf_importance': rf.feature_importances_
}).sort_values('rf_importance', ascending=False)
print("\nTop 5 Important Features:")
print(importance_df.head())
Best Practices and Conclusion
Always start with constrained trees. Set max_depth between 3 and 10, min_samples_split between 10 and 50, and min_samples_leaf between 5 and 20, scaling with dataset size. Use cross-validation to find optimal values.
For feature engineering, decision trees handle missing values poorly—impute them first. They also don’t benefit from feature scaling, but categorical encoding matters. Use one-hot encoding for nominal features and ordinal encoding for ordered categories.
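The preprocessing advice above can be wired into a single pipeline. This is a minimal sketch with a hypothetical mixed-type DataFrame (the column names and categories are made up for illustration):

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from sklearn.tree import DecisionTreeClassifier

# Hypothetical data: a numeric column with a missing value,
# a nominal column, and an ordered categorical column
df = pd.DataFrame({
    'income': [40000, 55000, None, 72000],
    'city': ['NYC', 'LA', 'NYC', 'SF'],        # nominal -> one-hot
    'education': ['HS', 'BS', 'MS', 'BS'],     # ordered -> ordinal
})
y = [0, 1, 1, 1]

pre = ColumnTransformer([
    ('num', SimpleImputer(strategy='median'), ['income']),
    ('nom', OneHotEncoder(handle_unknown='ignore'), ['city']),
    ('ord', OrdinalEncoder(categories=[['HS', 'BS', 'MS']]), ['education']),
])
model = Pipeline([('pre', pre), ('tree', DecisionTreeClassifier(max_depth=3))])
model.fit(df, y)
print(model.predict(df))
```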
In production, monitor tree depth and node counts. Deep trees signal overfitting. Consider ensemble methods like Random Forests for better generalization, or Gradient Boosting when you need maximum accuracy and can sacrifice some interpretability.
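The depth and node counts worth monitoring are exposed directly by scikit-learn's fitted trees, for example:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
clf = DecisionTreeClassifier(random_state=42).fit(X, y)

# Structural complexity metrics worth logging for a deployed model
print("depth:", clf.get_depth())
print("leaves:", clf.get_n_leaves())
print("total nodes:", clf.tree_.node_count)
```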
Decision trees remain valuable for their transparency. When you need to explain predictions to non-technical stakeholders or comply with regulations requiring model interpretability, decision trees are often the right choice. Master them as foundational building blocks before moving to more complex ensemble methods.