PySpark - Collect List and Collect Set
Key Insights
- collect_list() preserves all values including duplicates and generally maintains insertion order, while collect_set() returns only unique values with no guaranteed ordering; choose based on whether you need complete history or distinct items
- Both functions can cause memory issues with skewed data since all values for a group must fit in a single executor's memory; consider limiting collection size or using alternative approaches for highly skewed datasets
- Window functions combined with collect operations enable advanced patterns like running aggregations and partitioned collections without full groupBy operations, giving you finer control over data aggregation scope
Understanding Collection Aggregations in PySpark
When working with grouped data in PySpark, you often need to aggregate multiple rows into a single array column. While functions like sum() and count() reduce values to scalars, collect_list() and collect_set() preserve the individual values by collecting them into arrays. This is essential for tasks like building user activity histories, creating feature vectors for machine learning, or consolidating related records.
The fundamental difference is simple: collect_list() keeps everything including duplicates, while collect_set() returns only unique values. Let’s see why this matters with a practical example:
from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_list, collect_set
spark = SparkSession.builder.appName("CollectExample").getOrCreate()
# Sample transaction data
data = [
("user1", "laptop", 1200),
("user1", "mouse", 25),
("user1", "laptop", 1200), # duplicate purchase
("user2", "keyboard", 75),
("user2", "mouse", 25),
("user3", "laptop", 1200)
]
df = spark.createDataFrame(data, ["user_id", "product", "amount"])
df.show()
This gives us a typical transactional dataset where users may have multiple purchases, including duplicate items. Now we need to aggregate these purchases per user.
collect_list() - Preserving All Values
The collect_list() function aggregates all values in a group into an array, preserving duplicates and, in most cases, the order in which rows are encountered (though Spark does not guarantee this across shuffles). This is useful when the sequence or frequency of events matters.
# Collect all products purchased by each user
user_purchases = df.groupBy("user_id").agg(
collect_list("product").alias("all_products"),
collect_list("amount").alias("all_amounts")
)
user_purchases.show(truncate=False)
# Output:
# +-------+-----------------------+----------------+
# |user_id|all_products           |all_amounts     |
# +-------+-----------------------+----------------+
# |user1  |[laptop, mouse, laptop]|[1200, 25, 1200]|
# |user2  |[keyboard, mouse]      |[75, 25]        |
# |user3  |[laptop]               |[1200]          |
# +-------+-----------------------+----------------+
Notice that user1’s duplicate laptop purchase is preserved. This is exactly what you want when analyzing purchase patterns, calculating total spending, or tracking behavioral sequences.
Handling Null Values
Despite what you might expect, collect_list() ignores null values (it follows Hive's collect_list/collect_set semantics), so nulls are silently dropped from the result:
data_with_nulls = [
("user1", "laptop"),
("user1", None),
("user1", "mouse"),
("user2", "keyboard")
]
df_nulls = spark.createDataFrame(data_with_nulls, ["user_id", "product"])
df_nulls.groupBy("user_id").agg(
collect_list("product").alias("products")
).show(truncate=False)
# Output:
# +-------+---------------+
# |user_id|products       |
# +-------+---------------+
# |user1  |[laptop, mouse]|
# |user2  |[keyboard]     |
# +-------+---------------+
The null row for user1 is simply gone. If you need a marker for missing values to survive aggregation, substitute a sentinel with coalesce() first; if you wanted the nulls excluded, no extra work is required.
Ordering Considerations
collect_list() typically returns values in the order rows are processed within each partition, but Spark documents the function as non-deterministic: the shuffle triggered by groupBy can reorder rows, so even an explicit sort beforehand is not contractually guaranteed to survive the aggregation. In practice, sorting first usually works on small-to-medium jobs:
from pyspark.sql.functions import col
# Ensure chronological order before collecting
ordered_purchases = (df
.orderBy("user_id", col("amount").desc())
.groupBy("user_id")
.agg(collect_list("product").alias("products_by_price"))
)
ordered_purchases.show(truncate=False)
collect_set() - Unique Values Only
When you only care about which distinct items appear in a group—not how many times or in what order—collect_set() is your function. It automatically deduplicates values and returns an unordered array.
# Collect unique products per user
unique_purchases = df.groupBy("user_id").agg(
collect_set("product").alias("unique_products")
)
unique_purchases.show(truncate=False)
# Output:
# +-------+-----------------+
# |user_id|unique_products  |
# +-------+-----------------+
# |user1  |[laptop, mouse]  |  # duplicate laptop removed
# |user2  |[keyboard, mouse]|
# |user3  |[laptop]         |
# +-------+-----------------+
Notice user1 now has only two distinct products instead of three total purchases. This is perfect for answering questions like “What unique products has each user bought?” or creating distinct category lists.
Side-by-Side Comparison
Here’s a direct comparison showing both functions on the same data:
comparison = df.groupBy("user_id").agg(
collect_list("product").alias("all_products"),
collect_set("product").alias("unique_products"),
collect_list("amount").alias("all_amounts"),
collect_set("amount").alias("unique_amounts")
)
comparison.show(truncate=False)
# Output:
# +-------+-----------------------+-----------------+----------------+--------------+
# |user_id|all_products           |unique_products  |all_amounts     |unique_amounts|
# +-------+-----------------------+-----------------+----------------+--------------+
# |user1  |[laptop, mouse, laptop]|[laptop, mouse]  |[1200, 25, 1200]|[1200, 25]    |
# |user2  |[keyboard, mouse]      |[keyboard, mouse]|[75, 25]        |[75, 25]      |
# |user3  |[laptop]               |[laptop]         |[1200]          |[1200]        |
# +-------+-----------------------+-----------------+----------------+--------------+
This clearly illustrates the deduplication behavior of collect_set() versus the complete preservation of collect_list().
Window Functions with Collect Operations
Beyond simple groupBy operations, you can use collect functions with window specifications for more sophisticated aggregations. This lets you partition and order data without collapsing rows.
from pyspark.sql.window import Window
# Create a window partitioned by user, ordered by amount
window_spec = Window.partitionBy("user_id").orderBy(col("amount").desc())
# Collect products in order of price within each user's partition
windowed_df = df.withColumn(
"products_by_price",
collect_list("product").over(window_spec)
)
windowed_df.show(truncate=False)
# Each row shows accumulated products up to that point
This pattern is powerful for creating running collections or when you need to maintain the original row structure while adding aggregated context:
from pyspark.sql.functions import row_number
# Number transactions and collect all previous products
window_ordered = Window.partitionBy("user_id").orderBy("amount")
enriched = df.withColumn(
"transaction_number",
row_number().over(window_ordered)
).withColumn(
"purchase_history",
collect_list("product").over(window_ordered)
)
enriched.show(truncate=False)
This creates a running history where each row shows all products purchased up to and including that transaction.
Performance Considerations and Best Practices
While collect functions are powerful, they come with important performance implications. All values for a group must fit in memory on a single executor, which can cause issues with skewed data.
Memory and Skew Management
If one user has 10,000 transactions while others have 10, that user’s data could overwhelm an executor:
# Risky with skewed data
all_events = df.groupBy("user_id").agg(
collect_list("product").alias("products") # Could be huge!
)
# Safer: limit collection size
# Safer: cap the array size after collecting
from pyspark.sql.functions import slice as array_slice  # avoid shadowing builtin slice
limited_events = df.groupBy("user_id").agg(
collect_list("product").alias("all_products")
).withColumn(
"first_products",
array_slice("all_products", 1, 100) # keep only the first 100 elements
)
collect_set() vs array_distinct()
Sometimes you have a choice between using collect_set() or collect_list() followed by array_distinct():
from pyspark.sql.functions import array_distinct
# Two approaches to get unique values
approach1 = df.groupBy("user_id").agg(
collect_set("product").alias("unique_products")
)
approach2 = df.groupBy("user_id").agg(
array_distinct(collect_list("product")).alias("unique_products")
)
Use collect_set() when you know you want distinct values from the start—it’s more efficient. Use array_distinct() when you need to deduplicate an already-collected array or when combining multiple operations.
Common Use Cases and Patterns
User Purchase History Aggregation
Building comprehensive user profiles is a classic use case:
from pyspark.sql.functions import sum as spark_sum  # Python's builtin sum() won't work here
user_profiles = df.groupBy("user_id").agg(
collect_set("product").alias("products_purchased"),
collect_list("amount").alias("transaction_amounts"),
spark_sum("amount").alias("total_spent")
)
user_profiles.show(truncate=False)
Event Log Consolidation
Consolidating logs or events into single records:
event_data = [
("session1", "page_view", "homepage"),
("session1", "click", "product_link"),
("session1", "page_view", "product_page"),
("session2", "page_view", "homepage")
]
events_df = spark.createDataFrame(event_data, ["session_id", "event_type", "page"])
session_summary = events_df.groupBy("session_id").agg(
collect_list("event_type").alias("event_sequence"),
collect_set("page").alias("pages_visited")
)
session_summary.show(truncate=False)
Creating Delimited Strings
Sometimes you need the collected values as a delimited string rather than an array:
from pyspark.sql.functions import concat_ws
# Convert array to comma-separated string
product_strings = df.groupBy("user_id").agg(
concat_ws(", ", collect_set("product")).alias("product_list")
)
product_strings.show(truncate=False)
# Output:
# +-------+---------------+
# |user_id|product_list |
# +-------+---------------+
# |user1 |laptop, mouse |
# |user2 |keyboard, mouse|
# |user3 |laptop |
# +-------+---------------+
This is particularly useful for export operations or when interfacing with systems that expect delimited values.
Choosing the Right Function
Your choice between collect_list() and collect_set() should be driven by your specific requirements:
- Use collect_list() when order matters, duplicates are meaningful, or you're tracking sequences
- Use collect_set() when you only need to know what distinct values exist
- Consider memory implications; both functions can be expensive with large groups
- For very large collections, explore alternatives like approximate algorithms or sampling
- Always test with production-scale data to identify skew issues before deployment
Understanding these collection functions unlocks powerful aggregation patterns in PySpark, enabling you to transform granular transactional data into rich, aggregated features suitable for analytics and machine learning pipelines.