PySpark - Decision Tree Classifier with MLlib
Key Insights
• Decision Trees in PySpark MLlib provide interpretable classification models that split on both numerical features and indexed categorical features, making them a good fit for production environments where model explainability matters.
• StringIndexer and VectorAssembler are the pipeline components that turn raw columns into the single feature vector MLlib classifiers require: StringIndexer converts categorical strings into numeric indices, and VectorAssembler combines all feature columns into one vector.
• Cross-validation with ParamGridBuilder and CrossValidator selects hyperparameters based on held-out folds, which can improve generalization, while MulticlassClassificationEvaluator provides metrics such as accuracy, weighted precision, and weighted recall for model assessment.
Setting Up the Environment
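PySpark’s MLlib library provides distributed machine learning algorithms that scale across a cluster. Before building a Decision Tree classifier, create a SparkSession and set the resource configuration you need, such as driver memory. Note that spark.driver.memory only takes effect if it is set before the driver JVM starts, for example when a plain Python script creates the session; with spark-submit, pass --driver-memory on the command line instead.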
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
spark = SparkSession.builder \
    .appName("DecisionTreeClassifier") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()
# Load sample data
data = spark.read.csv("customer_data.csv", header=True, inferSchema=True)
data.show(5)
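For this example, assume customer_data.csv contains the numeric columns age, income, and tenure_months, the categorical columns education and employment_status, and a string target column churn whose values "Yes" and "No" indicate whether the customer churned.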
Data Preprocessing and Feature Engineering
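Decision Trees in MLlib expect all features in a single vector column. Use StringIndexer to convert categorical string columns into numeric indices, then VectorAssembler to combine the numeric and indexed columns into that vector. StringIndexer also attaches metadata marking its output as categorical; VectorAssembler carries this metadata into the feature vector, so the tree can split on those indices as categories rather than as ordered numbers.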
# Display schema to understand data types
data.printSchema()
# Handle categorical variables
education_indexer = StringIndexer(
    inputCol="education",
    outputCol="education_index",
    handleInvalid="keep"
)
employment_indexer = StringIndexer(
    inputCol="employment_status",
    outputCol="employment_index",
    handleInvalid="keep"
)
# Index the label column (by default, the most frequent value becomes 0.0)
label_indexer = StringIndexer(
    inputCol="churn",
    outputCol="label"
)
# Combine all features into a single vector
feature_cols = [
    "age",
    "income",
    "tenure_months",
    "education_index",
    "employment_index"
]
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features",
    handleInvalid="skip"
)
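With handleInvalid="skip", VectorAssembler drops every row that has a null or NaN in any input column instead of raising an error. The same filtering applies when the fitted pipeline transforms test or production data, so rows with missing values are silently excluded from predictions. On the two categorical indexers, handleInvalid="keep" instead maps unseen or null categories to an extra index rather than failing.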
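To make these indexing and null-handling rules concrete, here is a plain-Python sketch that mimics, on a few in-memory rows, what StringIndexer (default frequencyDesc ordering, handleInvalid="keep") and VectorAssembler (handleInvalid="skip") do. It illustrates the documented behavior only and is not Spark's implementation; the helper names fit_string_indexer, index_value, and assemble_skip and the sample rows are hypothetical.

```python
import math
from collections import Counter
from typing import Dict, List, Optional

# Plain-Python illustration only; these helpers are hypothetical, not Spark APIs.


def fit_string_indexer(values: List[Optional[str]]) -> Dict[str, int]:
    """Mimic StringIndexer's default frequencyDesc ordering: the most frequent
    value gets index 0, and ties are broken alphabetically. Nulls are not indexed."""
    counts = Counter(v for v in values if v is not None)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: i for i, label in enumerate(ordered)}


def index_value(mapping: Dict[str, int], value: Optional[str]) -> int:
    """Mimic handleInvalid="keep": unseen or null values map to the extra index numLabels."""
    if value is None or value not in mapping:
        return len(mapping)
    return mapping[value]


def assemble_skip(rows: List[dict], cols: List[str]) -> List[List[float]]:
    """Mimic VectorAssembler with handleInvalid="skip": drop rows with a null or NaN input."""
    vectors = []
    for row in rows:
        values = [row.get(c) for c in cols]
        if any(v is None or (isinstance(v, float) and math.isnan(v)) for v in values):
            continue
        vectors.append([float(v) for v in values])
    return vectors


# Hypothetical sample values
education = ["Bachelor", "Master", "Bachelor", "PhD", "Master", "Bachelor"]
edu_map = fit_string_indexer(education)
print(edu_map)                             # {'Bachelor': 0, 'Master': 1, 'PhD': 2}
print(index_value(edu_map, "HighSchool"))  # 3 (unseen category)
print(index_value(edu_map, None))          # 3 (null)

rows = [
    {"age": 34, "income": 52000.0, "education_index": 0},
    {"age": None, "income": 61000.0, "education_index": 1},
    {"age": 45, "income": float("nan"), "education_index": 2},
]
print(assemble_skip(rows, ["age", "income", "education_index"]))  # [[34.0, 52000.0, 0.0]]
```

Because label_indexer uses the same default ordering, the more frequent churn value becomes label 0.0. In a typical churn dataset where most customers stay, that means "No" maps to 0.0 and "Yes" to 1.0.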
Building the Decision Tree Model
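Configure the DecisionTreeClassifier with parameters that control tree depth, the impurity measure used to score candidate splits, the minimum number of instances each child node must receive, and the number of bins for continuous features. These parameters directly control model complexity and therefore how well the tree generalizes.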
# Initialize Decision Tree classifier
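dt = DecisionTreeClassifier(
    featuresCol="features",
    labelCol="label",
    maxDepth=5,               # depth 0 would be a single leaf
    minInstancesPerNode=20,   # each child of a split must get at least 20 rows
    impurity="gini",          # or "entropy"
    maxBins=32                # bins for discretizing continuous features
)
# Create ML Pipeline
pipeline = Pipeline(stages=[
    education_indexer,
    employment_indexer,
    label_indexer,
    assembler,
    dt
])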
# Split data into training and test sets
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)
print(f"Training set size: {train_data.count()}")
print(f"Test set size: {test_data.count()}")
# Train the model
model = pipeline.fit(train_data)
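The maxBins parameter sets how many bins continuous features are discretized into when the algorithm searches for split thresholds: more bins allow finer thresholds but cost more computation and memory. It must also be at least as large as the number of categories in any categorical feature, otherwise training fails.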
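The impurity and maxBins settings work together during split selection: candidate thresholds for a continuous feature come from a limited set of bin boundaries, and each candidate is scored by how much it lowers impurity. The plain-Python sketch below illustrates that idea for a single feature using Gini impurity or entropy. It is a simplification under an assumed binning scheme (midpoints when there are few distinct values, otherwise evenly spaced cut points over the sorted distinct values); Spark's actual threshold selection samples the data and differs in detail, and all function names here are hypothetical.

```python
import math
from collections import Counter
from typing import List, Optional, Sequence, Tuple

# Simplified split search for one continuous feature; not Spark's exact algorithm.


def gini(labels: Sequence[int]) -> float:
    """Gini impurity: 1 - sum over classes of p_k^2."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels: Sequence[int]) -> float:
    """Entropy in bits: -sum over classes of p_k * log2(p_k)."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def candidate_thresholds(values: Sequence[float], max_bins: int) -> List[float]:
    """Assumed scheme: midpoints between distinct values when there are few of them,
    otherwise max_bins - 1 evenly spaced cut points over the sorted distinct values."""
    distinct = sorted(set(values))
    if len(distinct) <= max_bins:
        return [(a + b) / 2 for a, b in zip(distinct, distinct[1:])]
    step = len(distinct) / max_bins
    return [distinct[int(i * step)] for i in range(1, max_bins)]


def best_split(values, labels, max_bins, impurity=gini) -> Tuple[Optional[float], float]:
    """Return (threshold, impurity gain) of the best 'value <= threshold' split."""
    n = len(labels)
    parent = impurity(labels)
    best_threshold, best_gain = None, 0.0
    for t in candidate_thresholds(values, max_bins):
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        if not left or not right:
            continue
        # Children's impurity, weighted by the share of rows each child receives
        child = (len(left) * impurity(left) + len(right) * impurity(right)) / n
        if parent - child > best_gain:
            best_threshold, best_gain = t, parent - child
    return best_threshold, best_gain


# Hypothetical data: short-tenure customers churned (1), long-tenure ones stayed (0)
tenure = [1, 2, 3, 5, 8, 13, 21, 34]
churned = [1, 1, 1, 1, 0, 0, 0, 0]
print(gini(churned), entropy(churned))           # 0.5 1.0
print(best_split(tenure, churned, max_bins=32))  # (6.5, 0.5)
print(best_split(tenure, churned, max_bins=4))   # coarser bins: a weaker split
```

With max_bins=4 the sketch can only try the cut points 3, 8, and 21, none of which separates the classes perfectly, which is the trade-off the maxBins parameter controls.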
Making Predictions and Model Evaluation
After training, generate predictions on the test set, extract the fitted Decision Tree model from the last pipeline stage, and evaluate performance with several metrics.
# Make predictions
predictions = model.transform(test_data)
# Display sample predictions
predictions.select(
    "features",
    "label",
    "prediction",
    "probability"
).show(10, truncate=False)
# Extract the fitted Decision Tree model (the last pipeline stage) for inspection
dt_model = model.stages[-1]
# Evaluate accuracy
accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy"
)
accuracy = accuracy_evaluator.evaluate(predictions)
print(f"Test Accuracy: {accuracy:.4f}")
# Evaluate weighted precision (per-class precision weighted by class frequency)
precision_evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="weightedPrecision"
)
precision = precision_evaluator.evaluate(predictions)
print(f"Weighted Precision: {precision:.4f}")
# Evaluate weighted recall (per-class recall weighted by class frequency)
recall_evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="weightedRecall"
)
recall = recall_evaluator.evaluate(predictions)
print(f"Weighted Recall: {recall:.4f}")
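The weightedPrecision and weightedRecall metrics average per-class scores, weighting each class by how often it occurs among the true labels. The plain-Python sketch below reproduces that calculation on a handful of hypothetical (label, prediction) pairs. One consequence it makes visible is that weighted recall is mathematically equal to accuracy, so the accuracy and weighted recall printed above should agree up to floating-point rounding. The function name weighted_metrics and the data are illustrative, not part of Spark.

```python
from collections import Counter
from typing import Dict, List, Tuple


def weighted_metrics(pairs: List[Tuple[float, float]]) -> Dict[str, float]:
    """Compute accuracy, weighted precision, and weighted recall from (label, prediction) pairs."""
    total = len(pairs)
    label_counts = Counter(label for label, _ in pairs)
    pred_counts = Counter(pred for _, pred in pairs)
    true_pos = Counter(label for label, pred in pairs if label == pred)

    weighted_precision = 0.0
    weighted_recall = 0.0
    for label, n_label in label_counts.items():
        tp = true_pos[label]
        # This sketch uses precision 0.0 for a class that is never predicted
        precision = tp / pred_counts[label] if pred_counts[label] else 0.0
        recall = tp / n_label
        weight = n_label / total  # class frequency among the true labels
        weighted_precision += precision * weight
        weighted_recall += recall * weight

    return {
        "accuracy": sum(true_pos.values()) / total,
        "weightedPrecision": weighted_precision,
        "weightedRecall": weighted_recall,
    }


# Hypothetical results: 8 rows of class 0.0 (6 correct), 2 rows of class 1.0 (1 correct)
pairs = [(0.0, 0.0)] * 6 + [(0.0, 1.0)] * 2 + [(1.0, 1.0)] * 1 + [(1.0, 0.0)] * 1
metrics = weighted_metrics(pairs)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```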
Hyperparameter Tuning with Cross-Validation
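Use ParamGridBuilder and CrossValidator to search a grid of hyperparameters systematically. Each combination is scored by its average metric across the folds, which reduces the risk of picking settings that only suit one particular train/validation split; the best combination is then refit on the full training set and returned as bestModel.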
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
# Create parameter grid
param_grid = ParamGridBuilder() \
    .addGrid(dt.maxDepth, [3, 5, 7, 10]) \
    .addGrid(dt.minInstancesPerNode, [10, 20, 50]) \
    .addGrid(dt.impurity, ["gini", "entropy"]) \
    .build()
# Configure cross-validator
cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=accuracy_evaluator,
    numFolds=5,
    seed=42,
    parallelism=4  # number of models fitted in parallel
)
# Train with cross-validation
cv_model = cv.fit(train_data)
# Get best model (refit on all of train_data with the best parameters)
best_model = cv_model.bestModel
# Evaluate best model
best_predictions = best_model.transform(test_data)
best_accuracy = accuracy_evaluator.evaluate(best_predictions)
print(f"Best Model Accuracy: {best_accuracy:.4f}")
# Extract best parameters (fitted models expose these getters in Spark 3.0+)
best_dt_model = best_model.stages[-1]
print(f"Best maxDepth: {best_dt_model.getMaxDepth()}")
print(f"Best minInstancesPerNode: {best_dt_model.getMinInstancesPerNode()}")
print(f"Best impurity: {best_dt_model.getImpurity()}")
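Cross-validation cost grows multiplicatively with the grid. The short sketch below counts the parameter combinations the grid above produces and the number of pipeline fits CrossValidator performs, assuming one fit per combination per fold plus the final refit of the best combination on the whole training set. It uses itertools.product as a stand-in for ParamGridBuilder.build(), which likewise produces the Cartesian product of the added values.

```python
from itertools import product

# Values copied from the ParamGridBuilder calls above
grid_values = {
    "maxDepth": [3, 5, 7, 10],
    "minInstancesPerNode": [10, 20, 50],
    "impurity": ["gini", "entropy"],
}
num_folds = 5

# Cartesian product of all parameter values, one dict per combination
param_maps = [dict(zip(grid_values, combo)) for combo in product(*grid_values.values())]

fits_during_cv = len(param_maps) * num_folds
total_fits = fits_during_cv + 1  # plus the refit of the best combination on all training data

print(len(param_maps))  # 24
print(fits_during_cv)   # 120
print(total_fits)       # 121
```

Trimming any one list of values shrinks the work proportionally, which is often the cheapest way to keep tuning time manageable on large datasets.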
Feature Importance Analysis
Decision Trees provide feature importance scores that indicate which variables contribute most to predictions; in MLlib these scores are normalized to sum to 1. This interpretability helps when explaining the model to business stakeholders.
# Get feature importances
feature_importance = best_dt_model.featureImportances
# Create feature importance DataFrame
import pandas as pd
importance_df = pd.DataFrame({
    'feature': feature_cols,
    # Each input column fills one vector slot, so the order matches feature_cols
    'importance': feature_importance.toArray()
})
importance_df = importance_df.sort_values('importance', ascending=False)
print("\nFeature Importances:")
print(importance_df)
# Print a text description of the learned splits (toDebugString is a property)
print("\nDecision Tree Structure:")
print(best_dt_model.toDebugString)
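According to the MLlib documentation, a tree's importance for feature j is the sum of the impurity gains of all nodes that split on j, each gain scaled by the number of training instances reaching that node, with the totals normalized to sum to 1. The plain-Python sketch below applies that rule to a small hypothetical tree described as a list of split nodes; the node values and the function name feature_importances are illustrative only.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical split nodes of a fitted tree: (feature used, impurity gain, rows reaching node)
nodes: List[Tuple[str, float, int]] = [
    ("tenure_months", 0.20, 1000),
    ("income", 0.10, 600),
    ("tenure_months", 0.05, 400),
    ("age", 0.04, 250),
]
feature_names = ["age", "income", "tenure_months", "education_index", "employment_index"]


def feature_importances(
    split_nodes: Sequence[Tuple[str, float, int]], names: Sequence[str]
) -> Dict[str, float]:
    """Sum gain * node size per feature, then normalize the totals to sum to 1."""
    raw = {name: 0.0 for name in names}
    for feature, gain, count in split_nodes:
        raw[feature] += gain * count
    total = sum(raw.values())
    return {name: (value / total if total > 0 else 0.0) for name, value in raw.items()}


importances = feature_importances(nodes, feature_names)
for name, score in sorted(importances.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:>17}: {score:.4f}")
```

Features the tree never splits on, such as the two indexed columns in this hypothetical tree, receive an importance of exactly 0.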
Handling Class Imbalance
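For imbalanced datasets, apply class weights or resample the training data so the tree is not biased toward the majority class. The code below uses instance weights, which DecisionTreeClassifier reads through its weightCol parameter (available in Spark 3.0 and later).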
# Calculate class distribution
class_counts = train_data.groupBy("churn").count()
class_counts.show()
# Apply class weights
from pyspark.sql.functions import when
total_count = train_data.count()
positive_count = train_data.filter(train_data.churn == "Yes").count()
negative_count = total_count - positive_count
balance_ratio = negative_count / positive_count
# Create weighted dataset
weighted_data = train_data.withColumn(
    "weight",
    when(train_data.churn == "Yes", balance_ratio).otherwise(1.0)
)
# Configure a classifier that reads per-row weights (weightCol requires Spark 3.0+)
dt_weighted = DecisionTreeClassifier(
    featuresCol="features",
    labelCol="label",
    weightCol="weight",
    maxDepth=5
)
# Update pipeline with weighted classifier
weighted_pipeline = Pipeline(stages=[
    education_indexer,
    employment_indexer,
    label_indexer,
    assembler,
    dt_weighted
])
weighted_model = weighted_pipeline.fit(weighted_data)
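The weighting above gives each "Yes" row the weight negative_count / positive_count and each "No" row the weight 1.0, so both classes contribute the same total weight to the tree's impurity calculations. The plain-Python sketch below checks that arithmetic with hypothetical counts and adds a guard for a training split with no positive rows, a case in which the code above would raise ZeroDivisionError. The helper name class_weights is hypothetical.

```python
from typing import Dict


def class_weights(positive_count: int, negative_count: int) -> Dict[str, float]:
    """Return per-class weights that equalize total class weight: the minority
    ("Yes") rows get negative/positive, the majority ("No") rows keep 1.0."""
    if positive_count == 0:
        raise ValueError("Training data contains no positive (churn == 'Yes') rows")
    return {"Yes": negative_count / positive_count, "No": 1.0}


# Hypothetical training-set counts
positive_count, negative_count = 150, 850
weights = class_weights(positive_count, negative_count)

print(weights["Yes"])                    # 5.666666666666667
print(positive_count * weights["Yes"])   # total weight of "Yes" rows (about 850)
print(negative_count * weights["No"])    # total weight of "No" rows: 850.0
```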
Model Persistence
Save trained models to disk for deployment in production environments. PySpark supports saving both individual models and complete pipelines.
# Save the complete pipeline model
model_path = "hdfs://path/to/model/dt_classifier_model"
best_model.write().overwrite().save(model_path)
# Load the model for inference
from pyspark.ml import PipelineModel
loaded_model = PipelineModel.load(model_path)
# Use loaded model for predictions
new_predictions = loaded_model.transform(test_data)
new_predictions.select("prediction", "probability").show(5)
Decision Tree classifiers in PySpark MLlib combine scalability with interpretability, making them practical for production machine learning workflows. The pipeline architecture ensures reproducible transformations, while cross-validation and feature importance analysis provide the insights needed for model refinement and stakeholder communication.