PySpark - Get Unique Values from Column
Key Insights
- Use distinct() for straightforward unique value extraction, but prefer dropDuplicates() when you need fine-grained control over which columns to deduplicate or when working with multiple columns simultaneously.
- Use countDistinct() or approx_count_distinct() instead of distinct().count() when you only need the count; this avoids materializing a full DataFrame of unique values just to count it.
- Avoid calling collect() on large result sets of unique values; use take() with a limit or write results to distributed storage to prevent driver memory overflow in production environments.
Introduction
Extracting unique values from DataFrame columns is a fundamental operation in PySpark that serves multiple critical purposes. Whether you’re profiling data quality, validating business rules, building categorical features for machine learning models, or generating lookup tables, understanding how to efficiently retrieve distinct values is essential for any data engineer or analyst working with large-scale datasets.
The challenge with PySpark compared to Pandas is that you’re working in a distributed environment. Operations that seem trivial—like getting a list of unique countries from a customer table—require careful consideration of data shuffling, memory constraints, and cluster resources. Choose the wrong approach, and you’ll either crash your driver node or waste hours waiting for unnecessary computations.
This article covers the practical techniques for extracting unique values in PySpark, with real-world examples and performance considerations that matter when working with production data at scale.
Using distinct() Method
The distinct() method is the most straightforward approach for getting unique values from a column. It returns a new DataFrame containing only the distinct rows based on all columns in the DataFrame (or the selected subset).
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder.appName("UniqueValues").getOrCreate()
# Create sample DataFrame
data = [
("USA", "New York"),
("USA", "California"),
("Canada", "Toronto"),
("USA", "New York"),
("Mexico", "Mexico City"),
("Canada", "Vancouver")
]
df = spark.createDataFrame(data, ["country", "city"])
# Get unique countries
unique_countries = df.select("country").distinct()
unique_countries.show()
# Convert to Python list
countries_list = [row.country for row in unique_countries.collect()]
print(countries_list)  # e.g. ['USA', 'Canada', 'Mexico'] (order not guaranteed)
The distinct() method works on the entire row of the selected DataFrame. When you call df.select("country").distinct(), PySpark creates a new DataFrame with only the country column, then removes duplicate rows. This triggers a shuffle operation across your cluster to identify duplicates.
For getting a sorted list of unique values, chain the orderBy() method:
unique_countries_sorted = (df.select("country")
.distinct()
.orderBy("country")
.collect())
Using dropDuplicates() Method
The dropDuplicates() method provides more flexibility than distinct(), especially when working with DataFrames containing multiple columns. You can specify exactly which columns to consider when identifying duplicates.
# Drop duplicates based on country column only
unique_by_country = df.dropDuplicates(["country"])
unique_by_country.show()
# This keeps all columns but removes rows with duplicate countries
# Output includes city information for the first occurrence of each country
The key difference: distinct() considers all columns in the DataFrame, while dropDuplicates() lets you specify which columns matter for uniqueness. When you call dropDuplicates(["country"]), PySpark keeps one row for each unique country value, preserving all other columns; which row survives is whichever is encountered first, and that ordering is not deterministic across a distributed shuffle.
# Get unique country-city combinations
unique_combinations = df.dropDuplicates(["country", "city"])
# Equivalent to:
unique_combinations_alt = df.distinct()
# These produce identical results when all columns are considered
For single-column unique value extraction, distinct() after selecting the column is cleaner. Use dropDuplicates() when you need to preserve additional columns or specify custom deduplication logic.
Counting Unique Values
Often you don’t need the actual unique values—just the count. This is critical for data profiling and validation checks. Using countDistinct() is significantly more efficient than materializing all unique values and counting them.
from pyspark.sql.functions import countDistinct, approx_count_distinct
# Efficient: count without materializing all unique values
country_count = df.select(countDistinct("country")).collect()[0][0]
print(f"Number of unique countries: {country_count}")
# Less efficient: builds a DataFrame of all distinct values just to count it
country_count_bad = df.select("country").distinct().count()
# For very large datasets, use approximate counting
approx_country_count = df.select(approx_count_distinct("country", rsd=0.05)).collect()[0][0]
print(f"Approximate unique countries: {approx_country_count}")
The approx_count_distinct() function uses the HyperLogLog++ algorithm with a configurable maximum relative standard deviation (rsd). Setting rsd=0.05 allows roughly 5% error, which is acceptable for many use cases and can be dramatically faster on datasets with billions of rows.
# Count unique values for multiple columns in one pass
unique_counts = df.select(
countDistinct("country").alias("unique_countries"),
countDistinct("city").alias("unique_cities")
)
unique_counts.show()
Getting Unique Values from Multiple Columns
When you need unique combinations across multiple columns, you have several approaches depending on your use case.
# Get unique country-city pairs
unique_pairs = df.select("country", "city").distinct()
unique_pairs.show()
# Convert to list of tuples
pairs_list = [(row.country, row.city) for row in unique_pairs.collect()]
print(pairs_list)
# Create a combined column for unique values
from pyspark.sql.functions import concat_ws
unique_combined = (df.select(concat_ws("-", "country", "city").alias("location"))
.distinct()
.collect())
locations = [row.location for row in unique_combined]
print(locations) # ['USA-New York', 'USA-California', ...]
For creating lookup dictionaries or mapping tables:
# Create a dictionary mapping countries to their cities
from collections import defaultdict
country_cities = defaultdict(set)
for row in df.collect():
country_cities[row.country].add(row.city)
# Convert sets to lists
country_cities = {k: list(v) for k, v in country_cities.items()}
print(country_cities)
# {'USA': ['New York', 'California'], 'Canada': ['Toronto', 'Vancouver'], ...}
Performance Considerations
Understanding the performance implications of different approaches is crucial when working with large datasets. The wrong method can mean the difference between a query that completes in seconds versus one that runs for hours or crashes.
# BAD: Collecting large result sets to driver
# This will crash if you have millions of unique values
unique_values = df.select("country").distinct().collect()
# BETTER: Limit the results
unique_values_sample = df.select("country").distinct().take(100)
# BEST: Write to distributed storage
df.select("country").distinct().write.parquet("s3://bucket/unique_countries/")
# Compare execution plans
df.select("country").distinct().explain()
df.dropDuplicates(["country"]).explain()
For operations you’ll perform multiple times, cache the DataFrame:
# Cache when you'll reuse the unique values
unique_df = df.select("country").distinct().cache()
# First action materializes and caches
count1 = unique_df.count()
# Subsequent actions use cached data
countries = unique_df.collect()
count2 = unique_df.count() # Much faster
# Don't forget to unpersist when done
unique_df.unpersist()
Use repartition() or coalesce() strategically when you know the result set is small:
# If unique values are few, reduce partitions to avoid small files
(df.select("country")
   .distinct()
   .coalesce(1)
   .write
   .csv("output/countries"))
Practical Example: Complete Workflow
Let’s walk through a realistic scenario: analyzing web server logs to extract unique user agents and IP addresses for security auditing.
from pyspark.sql import SparkSession
from pyspark.sql.functions import countDistinct, approx_count_distinct, col
from pyspark.sql.types import StructType, StructField, StringType
# Initialize Spark
spark = SparkSession.builder \
.appName("LogAnalysis") \
.config("spark.sql.shuffle.partitions", "200") \
.getOrCreate()
# Sample log data
log_data = [
("192.168.1.1", "Mozilla/5.0", "2024-01-15 10:30:00", "/home"),
("192.168.1.2", "Chrome/96.0", "2024-01-15 10:31:00", "/products"),
("192.168.1.1", "Mozilla/5.0", "2024-01-15 10:32:00", "/cart"),
("10.0.0.5", "Safari/15.0", "2024-01-15 10:33:00", "/home"),
("192.168.1.2", "Chrome/96.0", "2024-01-15 10:34:00", "/checkout"),
("10.0.0.5", "Safari/15.0", "2024-01-15 10:35:00", "/products"),
]
schema = StructType([
StructField("ip_address", StringType(), True),
StructField("user_agent", StringType(), True),
StructField("timestamp", StringType(), True),
StructField("endpoint", StringType(), True)
])
logs_df = spark.createDataFrame(log_data, schema)
# Step 1: Get summary statistics
print("=== Summary Statistics ===")
summary = logs_df.select(
countDistinct("ip_address").alias("unique_ips"),
countDistinct("user_agent").alias("unique_user_agents"),
countDistinct("endpoint").alias("unique_endpoints")
)
summary.show()
# Step 2: Extract unique IP addresses (security audit)
print("\n=== Unique IP Addresses ===")
unique_ips = logs_df.select("ip_address").distinct().orderBy("ip_address")
unique_ips.show()
# Step 3: Get IP-UserAgent combinations (device fingerprinting)
print("\n=== Unique IP-UserAgent Combinations ===")
unique_combinations = (logs_df.select("ip_address", "user_agent")
.distinct()
.orderBy("ip_address"))
unique_combinations.show(truncate=False)
# Step 4: Create lookup for IPs to user agents
ip_to_agents = {}
for row in unique_combinations.collect():
if row.ip_address not in ip_to_agents:
ip_to_agents[row.ip_address] = []
ip_to_agents[row.ip_address].append(row.user_agent)
print("\n=== IP to User Agents Mapping ===")
for ip, agents in ip_to_agents.items():
print(f"{ip}: {agents}")
# Step 5: Find IPs with multiple user agents (potential security concern)
from pyspark.sql.functions import collect_set, size
suspicious_ips = (logs_df.groupBy("ip_address")
.agg(collect_set("user_agent").alias("user_agents"))
.filter(size(col("user_agents")) > 1))
print("\n=== IPs with Multiple User Agents ===")
suspicious_ips.show(truncate=False)
This complete example demonstrates practical patterns: getting summary statistics without materializing data, extracting unique values for reporting, creating lookup structures, and identifying anomalies—all common tasks in production data pipelines.
The key takeaway is to always consider whether you need the actual values or just aggregates, and to be mindful of memory constraints when collecting results to the driver. Use PySpark’s distributed processing capabilities by writing results to storage systems rather than pulling everything into memory when working with production-scale data.