How to Use Array Functions in PySpark

Key Insights

  • Array functions in PySpark eliminate the need for expensive explode-aggregate patterns, letting you manipulate nested data directly within DataFrame operations
  • The transform() function paired with lambda expressions is your most powerful tool for element-wise array manipulation, but watch for performance pitfalls with very large arrays
  • Understanding when to use arrays versus normalized rows is critical—arrays excel at read-heavy workloads but complicate joins and updates

Introduction to Array Types in PySpark

Arrays in PySpark represent ordered collections of elements with the same data type, stored within a single column. You’ll encounter them constantly when working with JSON data, denormalized schemas, or any situation where a one-to-many relationship makes more sense as nested data than as separate rows.

The most common scenarios for array columns include ingesting semi-structured data from APIs, storing tags or categories, maintaining event sequences, and representing hierarchical relationships without expensive joins.

Let’s start by creating a DataFrame with array columns:

from pyspark.sql import SparkSession
from pyspark.sql.functions import array, col, lit

spark = SparkSession.builder.appName("ArrayFunctions").getOrCreate()

# Method 1: Python lists, inferred as array columns
df = spark.createDataFrame([
    (1, "Alice", ["python", "scala", "java"]),
    (2, "Bob", ["python", "rust"]),
    (3, "Carol", ["java", "kotlin", "scala", "go"]),
], ["id", "name", "languages"])

# Method 2: From JSON (common real-world scenario)
json_data = [
    '{"user_id": 1, "scores": [85, 92, 78], "tags": ["premium", "active"]}',
    '{"user_id": 2, "scores": [90, 88], "tags": ["trial"]}',
]
df_json = spark.read.json(spark.sparkContext.parallelize(json_data))
df_json.printSchema()

The schema output confirms PySpark correctly infers array types from JSON:

root
 |-- scores: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- tags: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- user_id: long (nullable = true)

Basic Array Operations

Before transforming arrays, you need to inspect and filter them. PySpark provides straightforward functions for these fundamental operations.

from pyspark.sql.functions import size, array_contains, element_at

df = spark.createDataFrame([
    (1, "Alice", ["python", "scala", "java"]),
    (2, "Bob", ["python", "rust"]),
    (3, "Carol", ["java", "kotlin", "scala", "go"]),
    (4, "Dave", []),
], ["id", "name", "languages"])

# Get array length
df.select("name", "languages", size("languages").alias("num_languages")).show()

# Filter rows where array contains specific value
python_devs = df.filter(array_contains(col("languages"), "python"))
python_devs.show()

# Extract specific elements (1-indexed for element_at, 0-indexed for bracket notation)
df.select(
    "name",
    element_at("languages", 1).alias("first_lang"),  # 1-indexed
    element_at("languages", -1).alias("last_lang"),  # negative index from end
    col("languages")[0].alias("first_bracket")       # 0-indexed bracket notation
).show()

A critical gotcha: element_at() uses 1-based indexing while bracket notation uses 0-based indexing. Pick one convention and stick with it across your codebase.

Transforming Arrays

The transform() function is where array manipulation gets powerful. It applies a lambda expression to each element, similar to Python’s map().

from pyspark.sql.functions import transform, array_distinct, array_sort, array_remove, reverse, upper

# Sample data with scores
df = spark.createDataFrame([
    (1, [10, 20, 30, 20]),
    (2, [5, 15, 25]),
    (3, [100, 50, 75, 50]),
], ["id", "scores"])

# Double all values using transform
df.select(
    "scores",
    transform("scores", lambda x: x * 2).alias("doubled")
).show()

# Remove duplicates
df.select(
    "scores",
    array_distinct("scores").alias("unique_scores")
).show()

# Sort array (ascending by default)
df.select(
    "scores",
    array_sort("scores").alias("sorted_asc"),
    reverse(array_sort("scores")).alias("sorted_desc")
).show()

# Remove specific value
df.select(
    "scores",
    array_remove("scores", 20).alias("without_20")
).show()

For string arrays, you can chain transformations:

df_langs = spark.createDataFrame([
    (1, ["Python", "SCALA", "java"]),
], ["id", "languages"])

# Normalize to uppercase and sort
df_langs.select(
    transform("languages", lambda x: upper(x)).alias("upper_langs"),
    array_sort(transform("languages", lambda x: upper(x))).alias("sorted_upper")
).show(truncate=False)

Combining and Splitting Arrays

When working with multiple array columns, you’ll frequently need to merge, compare, or slice them.

from pyspark.sql.functions import array_union, array_intersect, array_except, concat, flatten, slice

df = spark.createDataFrame([
    (1, ["a", "b", "c"], ["b", "c", "d"]),
    (2, ["x", "y"], ["y", "z", "w"]),
], ["id", "arr1", "arr2"])

# Union: combine unique elements from both arrays
df.select("arr1", "arr2", array_union("arr1", "arr2").alias("union")).show()

# Intersection: elements present in both
df.select("arr1", "arr2", array_intersect("arr1", "arr2").alias("common")).show()

# Except: elements in arr1 but not in arr2
df.select("arr1", "arr2", array_except("arr1", "arr2").alias("only_in_arr1")).show()

# Concat: simple concatenation (keeps duplicates)
df.select("arr1", "arr2", concat("arr1", "arr2").alias("concatenated")).show()

# Slice: extract portion of array (1-indexed start, length)
df.select("arr1", slice("arr1", 1, 2).alias("first_two")).show()

# Flatten nested arrays
df_nested = spark.createDataFrame([
    (1, [["a", "b"], ["c", "d"]]),
], ["id", "nested"])
df_nested.select(flatten("nested").alias("flat")).show()

Exploding and Collecting Arrays

Converting between array columns and rows is essential for aggregations and joins. The explode family of functions handles this conversion.

from pyspark.sql.functions import explode, explode_outer, posexplode, collect_list, collect_set, col, upper

df = spark.createDataFrame([
    (1, "Alice", ["python", "scala"]),
    (2, "Bob", ["java"]),
    (3, "Carol", []),  # empty array
], ["id", "name", "languages"])

# explode: creates one row per element (drops empty arrays)
df.select("id", "name", explode("languages").alias("language")).show()

# explode_outer: preserves rows with empty/null arrays
df.select("id", "name", explode_outer("languages").alias("language")).show()

# posexplode: includes position index
df.select("id", posexplode("languages").alias("pos", "language")).show()

The reverse operation uses aggregation functions:

# Collect back to arrays after transformation
exploded = df.select("id", "name", explode("languages").alias("lang"))
exploded = exploded.withColumn("lang_upper", upper(col("lang")))

# collect_list keeps duplicates (element order is not guaranteed after a shuffle)
# collect_set removes duplicates
result = exploded.groupBy("id", "name").agg(
    collect_list("lang_upper").alias("languages_upper"),
    collect_set("lang_upper").alias("languages_set")
)
result.show(truncate=False)

Working with Arrays of Structs

Real-world data often contains arrays of complex objects. PySpark handles these elegantly with the filter() function and struct field access.

from pyspark.sql.functions import filter, transform, aggregate, col, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType

# Define schema for array of structs
schema = StructType([
    StructField("user_id", IntegerType()),
    StructField("orders", ArrayType(StructType([
        StructField("order_id", StringType()),
        StructField("amount", IntegerType()),
        StructField("status", StringType())
    ])))
])

data = [
    (1, [{"order_id": "A1", "amount": 100, "status": "completed"},
         {"order_id": "A2", "amount": 50, "status": "pending"}]),
    (2, [{"order_id": "B1", "amount": 200, "status": "completed"},
         {"order_id": "B2", "amount": 75, "status": "completed"}]),
]

df = spark.createDataFrame(data, schema)

# Filter array of structs: keep only completed orders
df.select(
    "user_id",
    filter("orders", lambda x: x.status == "completed").alias("completed_orders")
).show(truncate=False)

# Extract specific field from all structs in array
df.select(
    "user_id",
    transform("orders", lambda x: x.amount).alias("all_amounts")
).show()

# Combine filter and transform: sum of completed order amounts
df.select(
    "user_id",
    aggregate(
        filter("orders", lambda x: x.status == "completed"),
        lit(0),
        lambda acc, x: acc + x.amount
    ).alias("completed_total")
).show()

Note: aggregate() must be imported from pyspark.sql.functions. The Python wrappers for higher-order functions (transform(), filter(), aggregate()) were added in Spark 3.1; on Spark 2.4 through 3.0, the same SQL functions are reachable through expr().

Performance Considerations and Best Practices

Array operations aren’t free. Here’s what you need to know:

Avoid arrays when joining frequently. Joining on array contents requires expensive operations like array_contains() in join conditions or exploding before joining. If you’re joining more than reading, normalize your data.

Watch memory with collect_list(). This function accumulates every value for a group in executor memory before producing the array. For skewed groups that collect millions of elements, you'll hit memory limits. Consider window functions, capping the result, or iterative processing instead.

Prefer built-in functions over UDFs. A Python UDF that processes arrays serializes data between JVM and Python for every row. Native functions like transform() stay in the JVM and execute orders of magnitude faster.

# Bad: Python UDF (serializes each row between the JVM and Python)
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType
@udf(ArrayType(IntegerType()))
def double_array_udf(arr):
    return [x * 2 for x in arr] if arr else []

# Good: Native transform
transform("scores", lambda x: x * 2)

Use array_distinct() early. If you’re going to deduplicate anyway, do it before expensive transformations to reduce the work.

Consider broadcast for small lookup arrays. When checking membership against a small set of values, broadcast the lookup values rather than using repeated array_contains() calls.

Arrays in PySpark strike a balance between query flexibility and storage efficiency. Use them for read-heavy, nested data scenarios, but don’t force them where normalized tables would serve better. Master these functions, and you’ll handle semi-structured data with confidence.
