PySpark - Distinct Values in Column
Key Insights
- PySpark offers three primary methods for finding distinct values: select().distinct(), dropDuplicates(), and countDistinct(), each optimized for different use cases and performance profiles
- Collecting distinct values to the driver requires careful memory management: use collect() only when the result set is guaranteed to be small; otherwise leverage distributed operations
- For production workloads, always filter nulls explicitly and consider partitioning strategies before executing distinct operations on large datasets to avoid performance bottlenecks
Introduction & Use Cases
Finding distinct values in PySpark columns is a fundamental operation in big data processing. Whether you’re profiling a new dataset, validating data quality, removing duplicates, or analyzing categorical distributions, understanding how to efficiently extract unique values is essential.
Common scenarios include identifying unique customer IDs for deduplication, discovering all possible values in a categorical field for validation, counting distinct events for analytics, and detecting data anomalies through cardinality checks. Each scenario demands different approaches based on dataset size, cluster resources, and whether you need the actual values or just counts.
Let’s start with a sample dataset that we’ll use throughout this article:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct
spark = SparkSession.builder.appName("DistinctValues").getOrCreate()
# Sample data with duplicates
data = [
("Alice", "Engineering", "USA"),
("Bob", "Sales", "UK"),
("Charlie", "Engineering", "USA"),
("David", "Marketing", "Canada"),
("Eve", "Sales", "USA"),
("Frank", "Engineering", "UK"),
("Alice", "Engineering", "USA"), # Duplicate
("Bob", "Sales", "UK"), # Duplicate
]
df = spark.createDataFrame(data, ["name", "department", "country"])
df.show()
This creates a DataFrame with intentional duplicates that we’ll use to demonstrate various distinct value extraction techniques.
Basic distinct() Method
The most straightforward approach to finding distinct values is using select() combined with distinct(). This method returns a new DataFrame containing only unique rows for the selected column.
# Get distinct departments
distinct_departments = df.select("department").distinct()
distinct_departments.show()
# Output:
# +-----------+
# | department|
# +-----------+
# |Engineering|
# |      Sales|
# |  Marketing|
# +-----------+
For counting distinct values without materializing them, use countDistinct():
from pyspark.sql.functions import countDistinct
# Count distinct values
dept_count = df.select(countDistinct("department").alias("unique_departments"))
dept_count.show()
# Output:
# +------------------+
# |unique_departments|
# +------------------+
# |                 3|
# +------------------+
# Count distinct values for multiple columns in one pass
distinct_counts = df.select(
countDistinct("department").alias("unique_depts"),
countDistinct("country").alias("unique_countries"),
countDistinct("name").alias("unique_names")
)
distinct_counts.show()
The countDistinct() function is significantly more efficient when you only need cardinality information, as it avoids materializing the actual distinct values and performs the aggregation in a single pass.
Using dropDuplicates() for Column-Specific Deduplication
While distinct() operates on the entire row of selected columns, dropDuplicates() provides more granular control. You can specify which columns to consider for uniqueness while retaining all other columns in the output.
# Drop duplicates based on the department column only
# Keeps one arbitrary row per department (Spark does not guarantee which)
dedup_by_dept = df.dropDuplicates(["department"])
dedup_by_dept.show()
# Output:
# +-----+-----------+-------+
# | name| department|country|
# +-----+-----------+-------+
# |Alice|Engineering|    USA|
# |  Bob|      Sales|     UK|
# |David|  Marketing| Canada|
# +-----+-----------+-------+
# Drop duplicates based on multiple columns
dedup_by_dept_country = df.dropDuplicates(["department", "country"])
dedup_by_dept_country.show()
The key difference: select("department").distinct() returns only the department column, while dropDuplicates(["department"]) returns entire rows where department values are unique. Use dropDuplicates() when you need to preserve the full row context.
Collecting Distinct Values to a List
For smaller result sets, you often need distinct values as a Python list for further processing, validation, or display purposes. PySpark provides several approaches with different performance characteristics.
# Method 1: Using collect() with list comprehension
distinct_depts_list = [row.department for row in
df.select("department").distinct().collect()]
print(distinct_depts_list)
# Output (order may vary): ['Engineering', 'Sales', 'Marketing']
# Method 2: Using rdd.map() and collect()
distinct_countries = (df.select("country")
.distinct()
.rdd
.map(lambda row: row[0])
.collect())
print(distinct_countries)
# Output (order may vary): ['USA', 'UK', 'Canada']
# Method 3: Using toPandas() for smaller datasets
distinct_depts_pandas = (df.select("department")
.distinct()
.toPandas()["department"]
.tolist())
print(distinct_depts_pandas)
Critical warning: collect() brings all data to the driver node. Only use this when you’re certain the distinct count is small (typically under 10,000 values). For large cardinality columns, this will cause out-of-memory errors.
For safer collection with limits:
# Limit distinct values before collecting
safe_distinct = (df.select("department")
.distinct()
.limit(100)
.rdd
.map(lambda row: row[0])
.collect())
Performance Considerations & Best Practices
Distinct operations trigger shuffles, which are expensive in distributed computing. Understanding performance implications helps you write efficient code.
Handling null values: distinct() treats null as its own distinct value, while aggregate functions like countDistinct() ignore nulls entirely. Filter them explicitly if needed:
# Add some null values to demonstrate
from pyspark.sql.functions import when
df_with_nulls = df.withColumn(
"department",
when(col("name") == "Eve", None).otherwise(col("department"))
)
# Get distinct values excluding nulls
distinct_non_null = (df_with_nulls
.select("department")
.filter(col("department").isNotNull())
.distinct())
distinct_non_null.show()
Performance optimization strategies:
# 1. Cache when performing multiple distinct operations
df.cache()
dept_distinct = df.select("department").distinct().count()
country_distinct = df.select("country").distinct().count()
df.unpersist()
# 2. Use approximate distinct for massive datasets
from pyspark.sql.functions import approx_count_distinct
# Faster but approximate; rsd caps the relative standard deviation (here 5%)
approx_count = df.select(
approx_count_distinct("name", rsd=0.05).alias("approx_unique_names")
)
approx_count.show()
# 3. Partition before distinct for very large datasets
df_partitioned = df.repartition(10, "department")
distinct_partitioned = df_partitioned.select("department").distinct()
The approx_count_distinct() function uses the HyperLogLog++ algorithm and is significantly faster for high-cardinality columns, making it ideal for data profiling where exact counts aren’t critical.
Advanced: Multiple Columns & Grouping
Real-world scenarios often require finding distinct combinations across multiple columns or distinct counts within groups.
# Distinct combinations of multiple columns
distinct_combinations = df.select("department", "country").distinct()
distinct_combinations.show()
# Output shows unique department-country pairs
# +-----------+-------+
# | department|country|
# +-----------+-------+
# |Engineering|    USA|
# |      Sales|     UK|
# |  Marketing| Canada|
# |Engineering|     UK|
# |      Sales|    USA|
# +-----------+-------+
# Group-based distinct counts
dept_country_analysis = (df.groupBy("department")
.agg(countDistinct("country").alias("countries_present"),
countDistinct("name").alias("unique_employees")))
dept_country_analysis.show()
# Output:
# +-----------+-----------------+----------------+
# | department|countries_present|unique_employees|
# +-----------+-----------------+----------------+
# |Engineering|                2|               3|
# |      Sales|                2|               2|
# |  Marketing|                1|               1|
# +-----------+-----------------+----------------+
# Find departments present in multiple countries
multi_country_depts = (dept_country_analysis
.filter(col("countries_present") > 1)
.select("department"))
multi_country_depts.show()
To retrieve the distinct values themselves within each group, rather than just their counts, combine grouping with collect_set():
from pyspark.sql.functions import collect_set
# Collect all distinct countries per department
dept_countries = (df.groupBy("department")
.agg(collect_set("country").alias("countries")))
dept_countries.show(truncate=False)
# Output:
# +------------+------------+
# |department |countries |
# +------------+------------+
# |Engineering |[USA, UK] |
# |Sales |[UK, USA] |
# |Marketing |[Canada] |
# +------------+------------+
Conclusion & Summary
Choosing the right method for finding distinct values depends on your specific requirements:
Use select().distinct() when you need distinct values from specific columns and want only those columns in the result. This is ideal for simple distinct value extraction.
Use dropDuplicates() when you need to deduplicate based on specific columns but retain all columns in your DataFrame. Perfect for removing duplicate records while preserving row context.
Use countDistinct() when you only need the count of unique values, not the values themselves. This is the most performant option for cardinality checks.
Use approx_count_distinct() for massive datasets where approximate counts are acceptable, offering significant performance improvements.
Always consider your data size before calling collect(). For production systems, implement safeguards like limits or use distributed operations that keep data in the cluster. Filter nulls explicitly when they shouldn’t be counted as distinct values, and leverage caching when performing multiple distinct operations on the same DataFrame.
By understanding these methods and their performance characteristics, you can efficiently handle distinct value operations at any scale in PySpark.