PySpark - Get Number of Rows in DataFrame (count)
Key Insights
- The count() method triggers a full DataFrame scan and can be expensive on large datasets—cache your DataFrame if you need to call count multiple times
- PySpark offers multiple counting approaches, including df.count(), SQL-style count(*), and filtered counts with filter().count(), each with different use cases
- For approximate counts on massive datasets, consider approx_count_distinct(), which trades accuracy for significant performance gains using the HyperLogLog algorithm
Understanding Row Counts in PySpark
Counting rows is one of the most fundamental operations you’ll perform with PySpark DataFrames. Whether you’re validating data ingestion, monitoring pipeline health, or debugging transformations, knowing exactly how many records you’re working with is essential. Unlike pandas where counting is nearly instantaneous on in-memory data, PySpark’s distributed nature means counting can trigger expensive cluster-wide computations.
Understanding the nuances of different counting methods, their performance characteristics, and when to use each approach will save you significant execution time and compute costs in production environments.
The Basic count() Method
The most straightforward way to get the number of rows is using the count() method directly on a DataFrame. This method returns a simple integer representing the total row count.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Initialize Spark session
spark = SparkSession.builder.appName("CountExample").getOrCreate()
# Create sample DataFrame
data = [
    (1, "Alice", 34),
    (2, "Bob", 45),
    (3, "Charlie", 29),
    (4, "Diana", 38),
    (5, "Eve", 41)
]
df = spark.createDataFrame(data, ["id", "name", "age"])
# Basic count operation
total_rows = df.count()
print(f"Total rows: {total_rows}") # Output: Total rows: 5
The count() method is an action, not a transformation. This means it immediately triggers execution of the entire DataFrame computation graph. If your DataFrame is the result of multiple transformations, all those operations will execute when you call count().
For large datasets spanning terabytes, a single count operation can take minutes or even hours. This isn’t a flaw—it’s the reality of scanning distributed data across hundreds of partitions.
Alternative Counting Methods
PySpark provides several other ways to count rows, each with subtle differences in syntax and use cases.
Using select with count()
You can use SQL-style counting with the select() method combined with the count() function from pyspark.sql.functions:
from pyspark.sql.functions import count, lit
# Count using select with count(*)
row_count = df.select(count("*")).collect()[0][0]
print(f"Count using select: {row_count}")
# Count using a specific column
row_count_col = df.select(count("id")).collect()[0][0]
print(f"Count using column: {row_count_col}")
# Count with literal
row_count_lit = df.select(count(lit(1))).collect()[0][0]
print(f"Count using literal: {row_count_lit}")
The key difference: count("*") counts all rows including nulls, while count("column_name") counts only non-null values in that specific column. This distinction is crucial when dealing with datasets containing missing values.
SQL Queries
If you prefer SQL syntax, you can register your DataFrame as a temporary view and use standard SQL:
# Register DataFrame as temporary view
df.createOrReplaceTempView("people")
# Count using SQL
sql_count = spark.sql("SELECT COUNT(*) as total FROM people").collect()[0][0]
print(f"SQL count: {sql_count}")
# More complex SQL count
age_stats = spark.sql("""
    SELECT
        COUNT(*) as total,
        COUNT(DISTINCT age) as unique_ages
    FROM people
""").collect()[0]
print(f"Total: {age_stats['total']}, Unique ages: {age_stats['unique_ages']}")
Performance Considerations and Caching
Count operations can become a performance bottleneck if not handled carefully. Since count() is an action that triggers full computation, calling it multiple times on the same DataFrame repeats all transformations.
import time
# Create a larger DataFrame with transformations
large_df = spark.range(0, 10000000).withColumn("doubled", col("id") * 2)
# Without caching - multiple counts
start = time.time()
count1 = large_df.count()
count2 = large_df.count()
count3 = large_df.count()
no_cache_time = time.time() - start
print(f"Without cache: {no_cache_time:.2f} seconds")
# With caching - multiple counts
large_df_cached = large_df.cache()
start = time.time()
count1 = large_df_cached.count() # First count triggers caching
count2 = large_df_cached.count() # Subsequent counts use cache
count3 = large_df_cached.count()
cache_time = time.time() - start
print(f"With cache: {cache_time:.2f} seconds")
# Clean up cache
large_df_cached.unpersist()
The first count with caching takes longer because it computes and stores the data, but subsequent counts are nearly instantaneous. Use caching when you need to perform multiple operations (not just counts) on the same DataFrame.
Important: Only cache DataFrames you’ll reuse multiple times. Caching consumes cluster memory, and unnecessary caching can actually degrade performance.
Counting with Conditions
Often you need to count rows meeting specific criteria rather than the total count. PySpark makes this straightforward with filter() or where() combined with count().
# Count rows where age > 35
older_count = df.filter(col("age") > 35).count()
print(f"People older than 35: {older_count}")
# Using where() - functionally identical to filter()
younger_count = df.where(col("age") <= 35).count()
print(f"People 35 or younger: {younger_count}")
# Multiple conditions
specific_count = df.filter(
    (col("age") > 30) & (col("name").startswith("A"))
).count()
print(f"People over 30 whose name starts with A: {specific_count}")
# Count nulls in a column
df_with_nulls = spark.createDataFrame([
    (1, "Alice", 34),
    (2, None, 45),
    (3, "Charlie", None),
    (4, "Diana", 38)
], ["id", "name", "age"])
null_names = df_with_nulls.filter(col("name").isNull()).count()
non_null_ages = df_with_nulls.filter(col("age").isNotNull()).count()
print(f"Null names: {null_names}, Non-null ages: {non_null_ages}")
For grouped counts, use groupBy() with count():
# Create DataFrame with categories
category_data = [
    (1, "Alice", "Engineering"),
    (2, "Bob", "Sales"),
    (3, "Charlie", "Engineering"),
    (4, "Diana", "Sales"),
    (5, "Eve", "Marketing")
]
df_categories = spark.createDataFrame(category_data, ["id", "name", "department"])
# Count by department
dept_counts = df_categories.groupBy("department").count()
dept_counts.show()
# Multiple grouping columns
age_dept_data = [
    (1, "Alice", "Engineering", 34),
    (2, "Bob", "Sales", 45),
    (3, "Charlie", "Engineering", 34)
]
df_multi = spark.createDataFrame(age_dept_data, ["id", "name", "dept", "age"])
multi_counts = df_multi.groupBy("dept", "age").count()
multi_counts.show()
Common Pitfalls and Best Practices
Empty DataFrames
Always handle the possibility of empty DataFrames, especially when working with filtered data:
# Safe counting with validation
def safe_count(dataframe, operation_name="operation"):
    row_count = dataframe.count()
    if row_count == 0:
        print(f"Warning: {operation_name} returned empty DataFrame")
    return row_count
# Example usage
filtered_df = df.filter(col("age") > 100)  # Likely empty
result_count = safe_count(filtered_df, "age filter > 100")  # avoid shadowing the imported count()
Approximate Counts for Massive Datasets
When exact counts aren’t critical and you’re working with billions of rows, approximate counting can save enormous amounts of time:
from pyspark.sql.functions import approx_count_distinct, countDistinct
# For distinct counts on large datasets
large_dataset = spark.range(0, 100000000)
# Exact distinct count (slow)
exact_distinct = large_dataset.select(countDistinct("id")).collect()[0][0]
# Approximate distinct count (much faster)
approx_distinct = large_dataset.select(
    approx_count_distinct("id", rsd=0.05)
).collect()[0][0]
print(f"Exact: {exact_distinct}, Approximate: {approx_distinct}")
The rsd parameter sets the maximum relative standard deviation of the estimate—lower values mean higher accuracy but more memory and computation. The default of 0.05 corresponds to a relative error of roughly 5%, which is acceptable for most monitoring and analytics use cases.
Avoid Count in Conditionals
Never use count() just to check if a DataFrame is empty. Use first() or take(1) instead:
# Bad - triggers full scan
if df.count() > 0:
    process_data(df)
# Good - stops after finding first row
if df.first() is not None:
    process_data(df)
# Also good
if len(df.take(1)) > 0:
    process_data(df)
Partition Awareness
Understanding your data’s partitioning can help optimize counts:
# Check number of partitions
num_partitions = df.rdd.getNumPartitions()
print(f"DataFrame has {num_partitions} partitions")
# Repartitioning triggers a full shuffle, so do it only when the DataFrame
# will be reused for other work, never just to speed up a single count
optimized_df = df.repartition(200)  # Adjust based on cluster size and data volume
Counting is a fundamental operation in PySpark, but it’s not free. Choose your counting method based on your specific needs: use df.count() for simplicity, cache when counting repeatedly, leverage approximate counts for massive datasets, and always consider the performance implications in distributed environments. Understanding these nuances separates developers who write PySpark code from those who write efficient, production-ready PySpark applications.