PySpark - Count Distinct Values

Key Insights

  • PySpark offers three distinct counting methods: countDistinct() for exact counts, approx_count_distinct() for faster approximate results on massive datasets, and chaining distinct().count() for specific use cases requiring the actual distinct rows.
  • Counting distinct combinations across multiple columns requires passing all columns as arguments to countDistinct(), which treats them as a composite key rather than counting each column separately.
  • For datasets exceeding millions of rows, approx_count_distinct() with HyperLogLog can deliver 99%+ accuracy while reducing computation time by 50-80%, making it the superior choice for exploratory analysis and dashboards.

Introduction

Counting distinct values is a fundamental operation in data analysis, whether you’re calculating unique customer counts, identifying the number of distinct products sold, or measuring unique daily active users. In distributed computing environments like PySpark, this operation becomes more complex due to data partitioning across cluster nodes, but PySpark provides several efficient methods to handle it.

Understanding when and how to use each distinct counting method directly impacts query performance and resource utilization. A poorly chosen approach can mean the difference between a query that completes in seconds versus one that runs for hours on large datasets. This article covers practical techniques for counting distinct values in PySpark, from basic operations to advanced optimization strategies.

Let’s start by creating a sample DataFrame that we’ll use throughout our examples:

from pyspark.sql import SparkSession
from pyspark.sql.functions import countDistinct, approx_count_distinct, col

spark = SparkSession.builder.appName("DistinctCount").getOrCreate()

data = [
    ("2024-01-01", "customer_1", "product_A", "NY"),
    ("2024-01-01", "customer_1", "product_B", "NY"),
    ("2024-01-01", "customer_2", "product_A", "CA"),
    ("2024-01-02", "customer_1", "product_A", "NY"),
    ("2024-01-02", "customer_3", "product_C", "TX"),
    ("2024-01-02", "customer_2", "product_A", "CA"),
    ("2024-01-03", "customer_3", "product_B", "TX"),
    ("2024-01-03", "customer_4", "product_A", "NY"),
]

df = spark.createDataFrame(data, ["date", "customer_id", "product_id", "state"])
df.show()

Basic Distinct Count with countDistinct()

The countDistinct() function from the pyspark.sql.functions module is the most straightforward way to count unique values in a column. It works within aggregation operations, making it ideal for use with agg() or select() methods.

from pyspark.sql.functions import countDistinct

# Count distinct customers
distinct_customers = df.agg(countDistinct("customer_id").alias("unique_customers"))
distinct_customers.show()
# Output: unique_customers = 4

# Alternative syntax using select
distinct_products = df.select(countDistinct("product_id").alias("unique_products"))
distinct_products.show()
# Output: unique_products = 3

The countDistinct() function performs an exact count by shuffling data across the cluster to identify all unique values. This guarantees accuracy but requires more computational resources as dataset size grows. The alias() method provides a readable column name for your result.

Counting Distinct Values Across Multiple Columns

Real-world scenarios often require counting distinct values for multiple columns simultaneously or finding unique combinations of columns. PySpark handles both cases elegantly.

To count distinct values for multiple columns separately in a single query:

# Count distinct values for multiple columns
multi_distinct = df.agg(
    countDistinct("customer_id").alias("unique_customers"),
    countDistinct("product_id").alias("unique_products"),
    countDistinct("state").alias("unique_states")
)
multi_distinct.show()
# Output: unique_customers=4, unique_products=3, unique_states=3

To count distinct combinations of multiple columns (treating them as a composite key):

# Count distinct customer-product combinations
distinct_combinations = df.agg(
    countDistinct("customer_id", "product_id").alias("unique_customer_product_pairs")
)
distinct_combinations.show()
# Output: unique_customer_product_pairs = 6

# This is different from counting each column separately
# Customer_1 bought products A and B (2 combinations)
# Customer_2 bought product A (1 combination)
# Customer_3 bought products C and B (2 combinations)
# Customer_4 bought product A (1 combination)
# Total: 6 unique pairs

This distinction is crucial: passing multiple columns to countDistinct() counts unique combinations, not the sum of individual distinct counts.

Using distinct() and count() Methods

An alternative approach chains the distinct() and count() methods. While this achieves the same result as countDistinct(), it offers different capabilities and use cases.

# Using distinct().count() for a single column
distinct_customer_count = df.select("customer_id").distinct().count()
print(f"Distinct customers: {distinct_customer_count}")
# Output: 4

# For multiple columns (combination)
distinct_pairs = df.select("customer_id", "product_id").distinct().count()
print(f"Distinct customer-product pairs: {distinct_pairs}")
# Output: 6

# Key difference: distinct() returns a DataFrame with unique rows
unique_customers_df = df.select("customer_id").distinct()
unique_customers_df.show()
# This gives you the actual distinct customer IDs, not just the count

Use distinct().count() when you need the actual distinct rows for further processing, not just the count. For example, if you need to join the distinct values with another DataFrame or perform additional transformations, this approach provides more flexibility. However, for pure counting operations, countDistinct() is more concise and equally performant.

Approximate Distinct Counts with approx_count_distinct()

When working with massive datasets containing millions or billions of rows, exact distinct counts can be computationally expensive. The approx_count_distinct() function uses the HyperLogLog algorithm to provide fast, approximate counts with controllable accuracy.

from pyspark.sql.functions import approx_count_distinct

# Approximate distinct count with default precision
approx_customers = df.agg(
    approx_count_distinct("customer_id").alias("approx_unique_customers")
)
approx_customers.show()

# Compare exact vs approximate
comparison = df.agg(
    countDistinct("customer_id").alias("exact_count"),
    approx_count_distinct("customer_id").alias("approx_count"),
    approx_count_distinct("customer_id", rsd=0.01).alias("approx_high_precision")
)
comparison.show()

The rsd parameter (relative standard deviation) controls accuracy. Lower values increase precision but require more memory. The default rsd=0.05 permits roughly 5% relative error, while rsd=0.01 tightens that to about 1%. For datasets with millions of distinct values, approx_count_distinct() can be 50-80% faster than exact counting.

Use approximate counts for:

  • Exploratory data analysis where exact precision isn’t critical
  • Real-time dashboards requiring fast refresh rates
  • Datasets with billions of rows where exact counts are prohibitively expensive

Stick with exact counts for:

  • Financial reporting requiring audit trails
  • Compliance reports where precision is mandatory
  • Small to medium datasets where performance difference is negligible

GroupBy with Distinct Counts

Analytical queries frequently require distinct counts within groups—for example, counting unique products per customer or unique customers per state.

# Count distinct products per customer
products_per_customer = df.groupBy("customer_id").agg(
    countDistinct("product_id").alias("distinct_products_purchased")
)
products_per_customer.show()

# Count distinct customers per state
customers_per_state = df.groupBy("state").agg(
    countDistinct("customer_id").alias("unique_customers")
)
customers_per_state.show()

# Multiple aggregations with distinct counts
from pyspark.sql.functions import concat_ws

daily_metrics = df.groupBy("date").agg(
    countDistinct("customer_id").alias("daily_active_customers"),
    countDistinct("product_id").alias("products_sold"),
    # approx_count_distinct() accepts a single column (plus rsd), so
    # combine the pair into one expression before approximating
    approx_count_distinct(
        concat_ws("|", "customer_id", "product_id")
    ).alias("approx_transactions")
)
daily_metrics.show()

Combining groupBy() with distinct counts is powerful for cohort analysis, user behavior tracking, and market segmentation. You can mix exact and approximate counts in the same aggregation to balance performance and accuracy based on each metric’s importance.

Performance Considerations and Best Practices

Distinct counting in distributed systems involves shuffling data across partitions, which can create performance bottlenecks. Here are optimization strategies:

Handle null values explicitly:

# countDistinct() ignores nulls, but distinct().count() includes a null row
df_with_nulls = df.union(
    spark.createDataFrame([("2024-01-04", None, "product_D", "FL")], df.schema)
)

# countDistinct() already skips the null customer, so this returns 4
df_with_nulls.agg(
    countDistinct("customer_id").alias("unique_customers")
).show()

# With distinct().count(), filter out nulls explicitly to match
non_null_count = (
    df_with_nulls.filter(col("customer_id").isNotNull())
    .select("customer_id").distinct().count()
)
print(non_null_count)  # 4

Analyze query execution plans:

# Compare execution plans
df.agg(countDistinct("customer_id")).explain()
df.select("customer_id").distinct().explain()  # inspect the plan before calling count()

# For large datasets, consider partitioning strategy
df.repartition("state").groupBy("state").agg(
    countDistinct("customer_id")
).explain()

Choose the right method:

  • Use countDistinct() for most use cases—it’s concise and performant
  • Use distinct().count() when you need the actual distinct rows
  • Use approx_count_distinct() for datasets with >10 million rows when approximate results are acceptable
  • Consider caching DataFrames if you’re performing multiple distinct count operations: df.cache()

Optimize for skewed data:

When certain values appear far more frequently than others (data skew), distinct counting can create hot partitions. Consider salting techniques or using repartition() to distribute the load evenly across your cluster.

Mastering these distinct counting techniques enables you to write efficient PySpark queries that scale from prototype to production, handling everything from thousands to billions of rows with appropriate performance characteristics.
