Spark Aggregations


Aggregating is the act of collecting something together and is a cornerstone of big data analytics. In an aggregation, you will specify a key or grouping and an aggregation function that specifies how you should transform one or more columns. This function must produce one result for each group, given multiple input values. Spark’s aggregation capabilities are sophisticated and mature, with a variety of different use cases and possibilities. In general, you use aggregations to summarize numerical data usually by means of some grouping. This might be a summation, a product, or simple counting. Also, with Spark you can aggregate any kind of value into an array, list, or map, as we will see in “Aggregating to Complex Types”.

In addition to working with any type of values, Spark also allows us to create the following grouping types:

  • The simplest grouping is to just summarize a complete DataFrame by performing an aggregation in a select statement
  • A “group by” allows you to specify one or more keys as well as one or more aggregation functions to transform the value columns
  • A “window” gives you the ability to specify one or more keys as well as one or more aggregation functions to transform the value columns. However, the rows input to the function are somehow related to the current row
  • A “grouping set,” which you can use to aggregate at multiple different levels. Grouping sets are available as a primitive in SQL and via rollups and cubes in DataFrames
  • A “rollup” makes it possible for you to specify one or more keys as well as one or more aggregation functions to transform the value columns, which will be summarized hierarchically
  • A “cube” allows you to specify one or more keys as well as one or more aggregation functions to transform the value columns, which will be summarized across all combinations of columns

Each grouping returns a RelationalGroupedDataset on which we specify our aggregations.

An important thing to consider is how exact you need an answer to be. When performing calculations over big data, it can be quite expensive to get an exact answer to a question, and it’s often much cheaper to request an approximation to a reasonable degree of accuracy. You’ll note that we mention some approximation functions throughout the book, and these are often a good opportunity to improve the speed of your Spark jobs, especially for interactive and ad hoc analysis.

Let’s begin by reading in our data on purchases, repartitioning the data to have far fewer partitions (because we know it’s a small volume of data stored in a lot of small files), and caching the results for rapid access:

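val df = spark.read.format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .load(retailDataPath) // retailDataPath is a placeholder: the source does not give the file location
    .coalesce(5) // far fewer partitions; 5 is an illustrative choice
df.cache()
df = [InvoiceNo: string, StockCode: string ... 6 more fields]

As mentioned, basic aggregations apply to an entire DataFrame. The simplest example is the count method:

df.count()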


Aggregate Functions

All aggregations are available as functions, in addition to the special cases that can appear on DataFrames or via .stat. You can find most aggregation functions in the org.apache.spark.sql.functions package.


count

We can do one of two things: specify a specific column to count, or count all the columns by using count(*) or count(1), which counts every row as the literal one. The following example counts a specific column:

import org.apache.spark.sql.functions.count

df.select(count("StockCode")).show()
|          541909|
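
One detail worth checking on your own data is how the two forms treat null values: counting a specific column skips nulls, whereas counting a literal counts every row. The following is a minimal sketch on a toy DataFrame; the local session setup, the name toy, and its values are assumptions made for illustration and are not part of the retail dataset.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, lit}

// Local session for the sketch; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("count-sketch").getOrCreate()
import spark.implicits._

// Hypothetical toy data: three rows, one of which has a null StockCode.
val toy = Seq[(Int, Option[String])](
    (1, Some("85123A")), (2, None), (3, Some("22138"))
).toDF("id", "StockCode")

val counts = toy.select(
    count("StockCode").alias("non_null_codes"), // null values are skipped
    count(lit(1)).alias("all_rows")             // every row is counted
).first()

val nonNullCodes = counts.getLong(0)
val allRows = counts.getLong(1)
```

If the two numbers differ on a real column, the difference is the number of null values in that column.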


countDistinct

Sometimes, the total number is not relevant; rather, it’s the number of unique groups that you want. To get this number, you can use the countDistinct function. This is a bit more relevant for individual columns:

import org.apache.spark.sql.functions.countDistinct

df.select(countDistinct("StockCode")).show()
|count(DISTINCT StockCode)|
|                     4070|


approx_count_distinct

There are times when an approximation to a certain degree of accuracy will work just fine, and for that, you can use the approx_count_distinct function. You will notice that approx_count_distinct takes another parameter with which you can specify the maximum estimation error allowed. In this case, we specify a rather large error and thus receive an answer that is quite far off but that completes more quickly than countDistinct. You will see much greater performance gains with larger datasets.

import org.apache.spark.sql.functions.approx_count_distinct

df.select(approx_count_distinct("StockCode", 0.1)).show()
|                            3364|

first and last

You can get the first and last values from a DataFrame by using these two obviously named functions. This will be based on the rows in the DataFrame, not on the values in the DataFrame:

import org.apache.spark.sql.functions.{first, last}

df.select(first("StockCode"), last("StockCode")).show()
|first(StockCode, false)|last(StockCode, false)|
|                 85123A|                 22138|

min and max

To extract the minimum and maximum values from a DataFrame, use the min and max functions:

import org.apache.spark.sql.functions.{min, max}

df.select(min("Quantity"), max("Quantity")).show()
|       -80995|        80995|


sum

Another simple task is to add all the values in a column using the sum function:

import org.apache.spark.sql.functions.sum

df.select(sum("Quantity")).show()
|      5176450|


sumDistinct

In addition to summing a total, you also can sum a distinct set of values by using the sumDistinct function:

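import org.apache.spark.sql.functions.sumDistinct

// sumDistinct is deprecated since Spark 3.2 in favor of sum_distinct
df.select(sumDistinct("Quantity")).show()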
|sum(DISTINCT Quantity)|
|                 29310|


avg

Although you can calculate the average by dividing sum by count, Spark provides an easier way to get that value via the avg or mean functions. In this example, we use alias in order to more easily reuse these columns later:

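import org.apache.spark.sql.functions.{sum, count, avg, expr}

df.select(
        count("Quantity").alias("total_transactions"),
        sum("Quantity").alias("total_purchases"),
        avg("Quantity").alias("avg_purchases"),
        expr("mean(Quantity)").alias("mean_purchases"))
    .selectExpr("total_purchases/total_transactions", "avg_purchases", "mean_purchases")
    .show()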
|(total_purchases / total_transactions)|   avg_purchases|  mean_purchases|
|                      9.55224954743324|9.55224954743324|9.55224954743324|

Variance and Standard Deviation

Calculating the mean naturally brings up questions about the variance and standard deviation. These are both measures of the spread of the data around the mean. The variance is the average of the squared differences from the mean, and the standard deviation is the square root of the variance. You can calculate these in Spark by using their respective functions. However, Spark provides both the sample formula and the population formula for each. These are different statistical formulae (the sample version divides by n - 1 rather than n), and we need to differentiate between them. By default, Spark uses the sample formula if you use the variance or stddev functions.

You can also specify these explicitly or refer to the population standard deviation or variance:

import org.apache.spark.sql.functions.{var_pop, stddev_pop}
import org.apache.spark.sql.functions.{var_samp, stddev_samp}

df.select(
    var_pop("Quantity"), var_samp("Quantity"),
    stddev_pop("Quantity"), stddev_samp("Quantity")
).show()
|47559.30364660923| 47559.39140929892|  218.08095663447835|   218.08115785023455|
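
To make the difference between the two formulae concrete, here is a minimal plain-Scala sketch that computes both by hand. It does not use Spark, and the quantities list is a hypothetical toy sample chosen so that the arithmetic is easy to follow.

```scala
// Hypothetical toy values; any sequence with at least two elements works.
val quantities = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)

val n = quantities.size
val mean = quantities.sum / n
val sumSquaredDiffs = quantities.map(q => math.pow(q - mean, 2)).sum

val varPop = sumSquaredDiffs / n         // population variance: divide by n
val varSamp = sumSquaredDiffs / (n - 1)  // sample variance: divide by n - 1
val stddevPop = math.sqrt(varPop)        // standard deviation is the square root
val stddevSamp = math.sqrt(varSamp)
```

The two results differ only by the factor n / (n - 1), which is why the population and sample values in the output above are so close: with several hundred thousand rows, that factor is nearly 1.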

skewness and kurtosis
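
Skewness and kurtosis describe the shape of a distribution: skewness measures the asymmetry of the values around the mean, and kurtosis measures how heavy the tails are. Spark exposes both as functions in org.apache.spark.sql.functions. The following is a minimal sketch on a toy column; the local session setup, the name toy, and its values are assumptions made for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{kurtosis, skewness}

// Local session for the sketch; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("shape-sketch").getOrCreate()
import spark.implicits._

// Hypothetical toy data with one large outlier on the right.
val toy = Seq(1, 2, 3, 4, 100).toDF("Quantity")

val shape = toy.select(skewness("Quantity"), kurtosis("Quantity")).first()
val skew = shape.getDouble(0) // positive here: the long tail is on the right
val kurt = shape.getDouble(1)
```

On the retail data, the same two functions can be applied to the Quantity column in a select, exactly like the other aggregate functions in this section.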

Covariance and Correlation

We discussed single-column aggregations, but some functions compare the interactions of the values in two different columns. Two of these functions are cov and corr, for covariance and correlation, respectively. Correlation measures the Pearson correlation coefficient, which is scaled between –1 and +1. The covariance is scaled according to the inputs in the data.

Like the var function, covariance can be calculated either as the sample covariance or the population covariance. Therefore it can be important to specify which formula you want to use. Correlation has no notion of this and therefore does not have calculations for population or sample. Here’s how they work:

import org.apache.spark.sql.functions.{corr, covar_pop, covar_samp}

df.select(
    corr("InvoiceNo", "Quantity"), covar_samp("InvoiceNo", "Quantity"), covar_pop("InvoiceNo", "Quantity")
).show()
|corr(InvoiceNo, Quantity)|covar_samp(InvoiceNo, Quantity)|covar_pop(InvoiceNo, Quantity)|
|     4.912186085640497E-4|             1052.7280543915997|            1052.7260778754955|

Aggregating to Complex Types

In Spark, you can perform aggregations not just on numerical values using formulas; you can also perform them on complex types. For example, we can collect a list of values present in a given column or only the unique values by collecting to a set.

You can use this to carry out some more programmatic access later on in the pipeline or pass the entire collection in a user-defined function (UDF):

import org.apache.spark.sql.functions.{collect_set, collect_list}

df.agg(collect_set("Country"), collect_list("Country")).show()
|[Portugal, Italy,...| [United Kingdom, ...|


Grouping

A more common task is to perform calculations based on groups in the data. This is typically done on categorical data for which we group our data on one column and perform some calculations on the other columns that end up in that group.

The best way to explain this is to begin performing some groupings. The first will be a count, just as we did before. We will group by each unique invoice number and get the count of items on that invoice. Note that this returns another DataFrame and is lazily performed.

We do this grouping in two phases. First we specify the column(s) on which we would like to group, and then we specify the aggregation(s). The first step returns a RelationalGroupedDataset, and the second step returns a DataFrame.
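
The two phases can be made explicit by giving each intermediate value a type annotation. This is a minimal sketch on a toy DataFrame; the local session setup, the name toy, and its rows are assumptions made for illustration.

```scala
import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.functions.sum

// Local session for the sketch; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("grouping-sketch").getOrCreate()
import spark.implicits._

// Hypothetical toy data: two invoices, three line items.
val toy = Seq(("536365", 6), ("536365", 8), ("536366", 2)).toDF("InvoiceNo", "Quantity")

// Phase 1: choose the grouping key(s). Nothing is computed yet.
val grouped: RelationalGroupedDataset = toy.groupBy("InvoiceNo")

// Phase 2: attach the aggregation(s), which yields a DataFrame again.
val totals: DataFrame = grouped.agg(sum("Quantity").alias("total_quantity"))

// Collect the small result to the driver so it can be inspected.
val result = totals.collect().map(r => r.getString(0) -> r.getLong(1)).toMap
```

Because the second phase returns an ordinary DataFrame, the result can be filtered, joined, or aggregated again like any other DataFrame.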

As mentioned, we can specify any number of columns on which we want to group:

df.groupBy("InvoiceNo", "CustomerId").count().show()
|   536846|     14573|   76|
|   537026|     12395|   12|
|   537883|     14437|    5|
|   538068|     17978|   12|
|   538279|     14952|    7|
|   538800|     16458|   10|
|   538942|     17346|   12|
|  C539947|     13854|    1|
|   540096|     13253|   16|
|   540530|     14755|   27|
|   541225|     14099|   19|
|   541978|     13551|    4|
|   542093|     17677|   16|
|   543188|     12567|   63|
|   543590|     17377|   19|
|  C543757|     13115|    1|
|  C544318|     12989|    1|
|   544578|     12365|    1|
|   545165|     16339|   20|
|   545289|     14732|   30|
only showing top 20 rows

Grouping with Expressions

As we saw earlier, counting is a bit of a special case because it exists as a method. Usually, though, we prefer to use the count function. Rather than passing that function as an expression into a select statement, we specify it within agg. This makes it possible for you to pass in arbitrary expressions that just need to have some aggregation specified. You can even do things like alias a column after transforming it for later use in your data flow:

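import org.apache.spark.sql.functions.{count, expr}

df.groupBy("InvoiceNo").agg(
    count("Quantity").alias("quan"),
    expr("count(Quantity)")).show()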

|   536596|   6|              6|
|   536938|  14|             14|
|   537252|   1|              1|
|   537691|  20|             20|
|   538041|   1|              1|
|   538184|  26|             26|
|   538517|  53|             53|
|   538879|  19|             19|
|   539275|   6|              6|
|   539630|  12|             12|
|   540499|  24|             24|
|   540540|  22|             22|
|  C540850|   1|              1|
|   540976|  48|             48|
|   541432|   4|              4|
|   541518| 101|            101|
|   541783|  35|             35|
|   542026|   9|              9|
|   542375|   6|              6|
|  C542604|   8|              8|
only showing top 20 rows

Grouping with Maps

It can be easier to specify your transformations as a series of Maps for which the key is the column, and the value is the aggregation function (as a string) that you would like to perform. You can reuse multiple column names if you specify them inline, as well:

df.groupBy("InvoiceNo")
    .agg("Quantity" -> "avg", "Quantity" -> "stddev_pop")
    .show()
|InvoiceNo|     avg(Quantity)|stddev_pop(Quantity)|
|   536596|               1.5|  1.1180339887498947|
|   536938|33.142857142857146|  20.698023172885524|
|   537252|              31.0|                 0.0|
|   537691|              8.15|   5.597097462078001|
|   538041|              30.0|                 0.0|
|   538184|12.076923076923077|   8.142590198943392|
|   538517|3.0377358490566038|  2.3946659604837897|
|   538879|21.157894736842106|  11.811070444356483|
|   539275|              26.0|  12.806248474865697|
|   539630|20.333333333333332|  10.225241100118645|
|   540499|              3.75|  2.6653642652865788|
|   540540|2.1363636363636362|  1.0572457590557278|
|  C540850|              -1.0|                 0.0|
|   540976|10.520833333333334|   6.496760677872902|
|   541432|             12.25|  10.825317547305483|
|   541518| 23.10891089108911|  20.550782784878713|
|   541783|11.314285714285715|   8.467657556242811|
|   542026| 7.666666666666667|   4.853406592853679|
|   542375|               8.0|  3.4641016151377544|
|  C542604|              -8.0|  15.173990905493518|
only showing top 20 rows

Window Functions

You can also use window functions to carry out some unique aggregations by computing an aggregation on a specific “window” of data, which you define by using a reference to the current data. This window specification determines which rows will be passed in to the function. This is a bit abstract and may sound similar to a standard group-by, so let’s differentiate them a bit more.

A group-by takes data, and every row can go only into one grouping. A window function calculates a return value for every input row of a table based on a group of rows, called a frame. Each row can fall into one or more frames. A common use case is to take a look at a rolling average of some value for which each row represents one day. If you were to do this, each row would end up in seven different frames. We cover defining frames a little later, but for your reference, Spark supports three kinds of window functions: ranking functions, analytic functions, and aggregate functions.
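
The idea that one row can belong to several frames can be illustrated without Spark. The following plain-Scala sketch uses a three-row trailing frame instead of seven days to keep it short; the values and the frame size are assumptions made for illustration.

```scala
// Hypothetical daily values, one per row, already ordered by day.
val values = Seq(10.0, 20.0, 30.0, 40.0, 50.0)
val frameSize = 3

// The frame for row i is the current row plus up to two preceding rows.
val frames = values.indices.map(i => math.max(0, i - frameSize + 1) to i)

// One output value per input row: the average over that row's frame.
val rollingAvg = frames.map(frame => frame.map(values).sum / frame.size)

// How many frames does each row participate in?
val membership = values.indices.map(j => frames.count(_.contains(j)))
```

Here the first rows each appear in three frames, whereas in a group-by every row would be counted in exactly one group. The output also keeps one value per input row rather than one value per group.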

To demonstrate, we will add a date column that will convert our invoice date into a column that contains only date information (not time information, too):

import org.apache.spark.sql.functions.{col, to_date}

val dfWithDate = df.withColumn("date", to_date(col("InvoiceDate"), "MM/dd/yyyy H:mm"))
dfWithDate = [InvoiceNo: string, StockCode: string ... 7 more fields]


The first step to a window function is to create a window specification. Note that the partition by is unrelated to the partitioning scheme concept that we have covered thus far. It’s just a similar concept that describes how we will be breaking up our group. The ordering determines the ordering within a given partition, and, finally, the frame specification (the rowsBetween statement) states which rows will be included in the frame based on its reference to the current input row. In the following example, we look at all previous rows up to the current row:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col

val windowSpec = Window
    .partitionBy("CustomerID", "date")
    .orderBy(col("Quantity").desc)
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
windowSpec = org.apache.spark.sql.expressions.WindowSpec@e25203


Now we want to use an aggregation function to learn more about each specific customer. An example might be establishing the maximum purchase quantity over all time. To answer this, we use the same aggregation functions that we saw earlier by passing a column name or expression. In addition, we indicate the window specification that defines to which frames of data this function will apply:

import org.apache.spark.sql.functions.max

val maxPurchaseQuantity = max(col("Quantity")).over(windowSpec)


You will notice that this returns a column (or expressions). We can now use this in a DataFrame select statement. Before doing so, though, we will create the purchase quantity rank. To do that we use the dense_rank function to determine which date had the maximum purchase quantity for every customer. We use dense_rank as opposed to rank to avoid gaps in the ranking sequence when there are tied values (or in our case, duplicate rows):

import org.apache.spark.sql.functions.{dense_rank, rank}

val purchaseDenseRank = dense_rank().over(windowSpec)
val purchaseRank = rank().over(windowSpec)

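This also returns a column that we can use in select statements. Now we can perform a select to view the calculated window values:

import org.apache.spark.sql.functions.col

dfWithDate
    .where("CustomerID IS NOT NULL")
    .orderBy("CustomerID")
    .select(col("CustomerID"), col("date"), col("Quantity"),
        purchaseRank.alias("quantityRank"), purchaseDenseRank.alias("quantityDenseRank"),
        maxPurchaseQuantity.alias("maxPurchaseQuantity"))
    .show()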
|CustomerID|      date|Quantity|quantityRank|quantityDenseRank|maxPurchaseQuantity|
|     12346|2011-01-18|   74215|           1|                1|              74215|
|     12346|2011-01-18|  -74215|           2|                2|              74215|
|     12347|2010-12-07|      36|           1|                1|                 36|
|     12347|2010-12-07|      30|           2|                2|                 36|
|     12347|2010-12-07|      24|           3|                3|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|      12|           4|                4|                 36|
|     12347|2010-12-07|       6|          17|                5|                 36|
|     12347|2010-12-07|       6|          17|                5|                 36|
only showing top 20 rows
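
In the output above, quantityRank jumps from 4 to 17 after the run of tied rows, whereas quantityDenseRank only moves to 5. The following plain-Scala sketch reproduces that behavior by hand on a shorter list; it does not use Spark, and the quantities are hypothetical.

```scala
// Hypothetical quantities for one customer and date, sorted descending.
val quantities = Seq(36, 30, 24, 12, 12, 12, 6)

// rank: 1 + the number of rows strictly ahead of this value, so ties leave a gap.
val rank = quantities.map(q => quantities.count(_ > q) + 1)

// dense_rank: 1 + the number of distinct values strictly ahead, so there are no gaps.
val distinctQuantities = quantities.distinct
val denseRank = quantities.map(q => distinctQuantities.count(_ > q) + 1)
```

The three tied rows share rank 4 under both schemes; the next row is ranked 7 by rank and 5 by dense_rank.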



Grouping Sets

Sometimes we want something a bit more complete: an aggregation across multiple groups. We achieve this by using grouping sets. Grouping sets are a low-level tool for combining sets of aggregations together. They give you the ability to create arbitrary aggregations in your group-by statements.

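Let’s work through an example to gain a better understanding. Here, we would like to get the total quantity of all stock codes and customers. Grouping sets, rollups, and cubes use null values to mark aggregation levels, so null values in the data itself would make the results ambiguous. We therefore drop rows that contain nulls first:

val dfNoNull = dfWithDate.na.drop()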
dfNoNull = [InvoiceNo: string, StockCode: string ... 7 more fields]
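
Grouping sets are written in SQL as a GROUPING SETS clause on the GROUP BY. The following is a minimal sketch on a toy table; the local session setup, the view name toyPurchases, and its rows are assumptions made for illustration and are not the retail data.

```scala
import org.apache.spark.sql.SparkSession

// Local session for the sketch; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("grouping-sets-sketch").getOrCreate()
import spark.implicits._

// Hypothetical toy data without null values.
val toy = Seq((1, "A", 5), (1, "B", 3), (2, "A", 7)).toDF("customerId", "stockCode", "Quantity")
toy.createOrReplaceTempView("toyPurchases")

// Two grouping sets: each (customerId, stockCode) pair, and the empty set for the grand total.
val grouped = spark.sql("""
    SELECT customerId, stockCode, sum(Quantity) AS total_quantity
    FROM toyPurchases
    GROUP BY customerId, stockCode GROUPING SETS ((customerId, stockCode), ())
""")

val rowCount = grouped.count()
val grandTotal = grouped
    .where("customerId IS NULL AND stockCode IS NULL")
    .first().getLong(2)
```

The grand-total row is identified by nulls in both key columns, which is why null values in the input data have to be removed beforehand.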



Rollups

When we set our grouping keys of multiple columns, Spark looks at those as well as the actual combinations that are visible in the dataset. A rollup is a multidimensional aggregation that performs a variety of group-by style calculations for us.

Let’s create a rollup that looks across time (with our new Date column) and space (with the Country column) and creates a new DataFrame that includes the grand total over all dates, the grand total for each date in the DataFrame, and the subtotal for each country on each date in the DataFrame:

import org.apache.spark.sql.functions.sum

val rolledUpDF = dfNoNull
    .rollup("Date", "Country")
    .agg(sum("Quantity"))
    .selectExpr("Date", "Country", "`sum(Quantity)` as total_quantity")
    .orderBy("Date")
rolledUpDF.show()

Now where you see the null values is where you’ll find the grand totals. A null in both rollup columns specifies the grand total across both of those columns:

rolledUpDF.where("Country IS NULL").show()

rolledUpDF.where("Date IS NULL").show()
|      Date|       Country|total_quantity|
|      null|          null|       5176450|
|2010-12-01|       Germany|           117|
|2010-12-01|        France|           449|
|2010-12-01|          EIRE|           243|
|2010-12-01|United Kingdom|         23949|
|2010-12-01|     Australia|           107|
|2010-12-01|          null|         26814|
|2010-12-01|        Norway|          1852|
|2010-12-01|   Netherlands|            97|
|2010-12-02|       Germany|           146|
|2010-12-02|          null|         21023|
|2010-12-02|          EIRE|             4|
|2010-12-02|United Kingdom|         20873|
|2010-12-03|         Italy|           164|
|2010-12-03|         Spain|           400|
|2010-12-03|       Germany|           170|
|2010-12-03|          null|         14830|
|2010-12-03|       Belgium|           528|
|2010-12-03|   Switzerland|           110|
|2010-12-03|        Poland|           140|
only showing top 20 rows

|      Date|Country|total_quantity|
|      null|   null|       5176450|
|2010-12-01|   null|         26814|
|2010-12-02|   null|         21023|
|2010-12-03|   null|         14830|
|2010-12-05|   null|         16395|
|2010-12-06|   null|         21419|
|2010-12-07|   null|         24995|
|2010-12-08|   null|         22741|
|2010-12-09|   null|         18431|
|2010-12-10|   null|         20297|
|2010-12-12|   null|         10565|
|2010-12-13|   null|         17623|
|2010-12-14|   null|         20098|
|2010-12-15|   null|         18229|
|2010-12-16|   null|         29632|
|2010-12-17|   null|         16069|
|2010-12-19|   null|          3795|
|2010-12-20|   null|         14965|
|2010-12-21|   null|         15467|
|2010-12-22|   null|          3192|
only showing top 20 rows

|null|   null|       5176450|

rolledUpDF = [Date: date, Country: string ... 1 more field]



Cube

A cube takes the rollup to a level deeper. Rather than treating elements hierarchically, a cube does the same thing across all dimensions. This means that it won’t just go by date over the entire time period, but also by country. A cube over Date and Country therefore produces all of the following:

  • The total across all dates and countries
  • The total for each date across all countries
  • The total for each country on each date
  • The total for each country across all dates

The method call is quite similar, but instead of calling rollup, we call cube.

This is a quick and easily accessible summary of nearly all of the information in our table, and it’s a great way to create a quick summary table that others can use later on.

import org.apache.spark.sql.functions.sum

dfNoNull
    .cube("Date", "Country")
    .agg(sum("Quantity"))
    .select("Date", "Country", "sum(Quantity)")
    .orderBy("Date").show()
|Date|             Country|sum(Quantity)|
|null|            Portugal|        16180|
|null|               Japan|        25218|
|null|           Australia|        83653|
|null|             Germany|       117448|
|null|             Lebanon|          386|
|null|                 RSA|          352|
|null|             Finland|        10666|
|null|              Cyprus|         6317|
|null|         Unspecified|         3300|
|null|                null|      5176450|
|null|               Spain|        26824|
|null|                 USA|         1034|
|null|           Hong Kong|         4769|
|null|           Singapore|         5234|
|null|             Denmark|         8188|
|null|     Channel Islands|         9479|
|null|  European Community|          497|
|null|United Arab Emirates|          982|
|null|              Norway|        19247|
|null|      Czech Republic|          592|
only showing top 20 rows
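
The difference between a rollup and a cube is easiest to see by looking for the one level that only the cube produces: totals per country across all dates. The following is a minimal sketch on a toy DataFrame; the local session setup, the name toy, and its rows are assumptions made for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

// Local session for the sketch; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("cube-sketch").getOrCreate()
import spark.implicits._

// Hypothetical toy data: two dates, two countries, no null values.
val toy = Seq(
    ("2010-12-01", "France", 10),
    ("2010-12-01", "Germany", 5),
    ("2010-12-02", "France", 7)
).toDF("Date", "Country", "Quantity")

val rolledUp = toy.rollup("Date", "Country").agg(sum("Quantity").alias("total"))
val cubed = toy.cube("Date", "Country").agg(sum("Quantity").alias("total"))

// Per-country totals across all dates: Date is null, Country is not.
val perCountry = "Date IS NULL AND Country IS NOT NULL"
val rollupPerCountry = rolledUp.where(perCountry).count()
val cubePerCountry = cubed.where(perCountry).count()
val rollupRows = rolledUp.count()
val cubeRows = cubed.count()
```

The rollup contains no per-country rows because it only aggregates from right to left through the key list, whereas the cube adds one such row for each country.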

Grouping Metadata

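Sometimes when using cubes and rollups, you want to be able to query the aggregation levels so that you can easily filter them down accordingly. We can do this by using the grouping_id, which gives us a column specifying the level of aggregation that we have in our result set. The query in the example that follows returns four distinct grouping IDs: 3 for the grand total, 2 for the total per stock code regardless of customer, 1 for the total per customer regardless of stock code, and 0 for each individual customerId and stockCode combination: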

import org.apache.spark.sql.functions.{grouping_id, sum, expr}

dfNoNull
    .cube("customerId", "stockCode")
    .agg(grouping_id(), sum("Quantity"))
    .orderBy(expr("grouping_id()").desc)
    .show()
|      null|     null|            3|      5176450|
|      null|    23217|            2|         1309|
|      null|   90059E|            2|           19|
|      null|    22295|            2|         2795|
|      null|    22919|            2|         1745|
|      null|    22207|            2|         1259|
|      null|    22265|            2|          540|
|      null|    84670|            2|           23|
|      null|   51014C|            2|         2505|
|      null|    22459|            2|          183|
|      null|   47590B|            2|         2244|
|      null|    22275|            2|           69|
|      null|    21201|            2|          849|
|      null|   17013D|            2|          506|
|      null|   84931A|            2|          135|
|      null|    21946|            2|           86|
|      null|    22522|            2|          795|
|      null|    21631|            2|          169|
|      null|    85200|            2|          189|
|      null|    23003|            2|        -8516|
only showing top 20 rows
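
Once the grouping ID is available as a column, selecting one aggregation level is an ordinary filter. The following is a minimal sketch on a toy DataFrame; the local session setup, the name toy, its rows, and the alias level are assumptions made for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, grouping_id, sum}

// Local session for the sketch; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("grouping-id-sketch").getOrCreate()
import spark.implicits._

// Hypothetical toy data without null values.
val toy = Seq((1, "A", 5), (1, "B", 3), (2, "A", 7)).toDF("customerId", "stockCode", "Quantity")

// Alias the grouping ID so that it is easy to reference afterwards.
val leveled = toy
    .cube("customerId", "stockCode")
    .agg(grouping_id().alias("level"), sum("Quantity").alias("total"))

// Level 3: both keys are aggregated away, which is the grand total.
val grandTotal = leveled.where(col("level") === 3).first().getAs[Long]("total")

// Level 1: stockCode is aggregated away, leaving one total per customer.
val perCustomer = leveled.where(col("level") === 1)
    .collect()
    .map(r => r.getAs[Int]("customerId") -> r.getAs[Long]("total"))
    .toMap
```

Filtering on the level is more robust than filtering on null keys, because it does not depend on the key columns being free of null values.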


Pivot

Pivots make it possible for you to convert a row into a column. For example, in our current data we have a Country column. With a pivot, we can aggregate according to some function for each of those given countries and display them in an easy-to-query way:

val pivoted = dfWithDate.groupBy("date").pivot("Country").sum()
pivoted = [date: date, Australia_sum(Quantity): bigint ... 113 more fields]


This DataFrame will now have a column for every combination of country, numeric variable, and a column specifying the date. For example, for USA we have the following columns: USA_sum(Quantity), USA_sum(UnitPrice), USA_sum(CustomerID). This represents one for each numeric column in our dataset (because we just performed an aggregation over all of them).

pivoted
    .where("date > '2011-12-05'")
    .select("date", "`USA_sum(Quantity)`")
    .show()
|      date|USA_sum(Quantity)|
|2011-12-06|             null|
|2011-12-09|             null|
|2011-12-08|             -196|
|2011-12-07|             null|
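
When only a few countries are of interest, the pivot values can be listed explicitly, which also fixes the set of output columns. The following is a minimal sketch on a toy DataFrame; the local session setup, the name toy, and its rows are assumptions made for illustration. Note that with a single aggregation, as here, the pivoted columns are named after the pivot values alone rather than in the value_aggregation form shown above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

// Local session for the sketch; in a notebook or spark-shell, `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("pivot-sketch").getOrCreate()
import spark.implicits._

// Hypothetical toy data: France has no purchases on the second date.
val toy = Seq(
    ("2011-12-07", "USA", 4),
    ("2011-12-07", "France", 6),
    ("2011-12-08", "USA", -2)
).toDF("date", "Country", "Quantity")

// One row per date, one column per listed country.
val pivotedToy = toy
    .groupBy("date")
    .pivot("Country", Seq("USA", "France"))
    .agg(sum("Quantity"))

val pivotColumns = pivotedToy.columns.toSeq
val franceMissing = pivotedToy.where("date = '2011-12-08'").first().isNullAt(2)
```

A date and country combination with no input rows yields null rather than zero, which matches the null entries in the output above.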

User-Defined Aggregation Functions

User-defined aggregation functions (UDAFs) are a way for users to define their own aggregation functions based on custom formulae or business rules. You can use UDAFs to compute custom calculations over groups of input data (as opposed to single rows). Spark maintains a single AggregationBuffer to store intermediate results for every group of input data. To create a UDAF, you must inherit from the UserDefinedAggregateFunction base class (deprecated as of Spark 3.0 in favor of Aggregator, but still the API used in this section) and implement the following methods:

  • inputSchema represents input arguments as a StructType
  • bufferSchema represents intermediate UDAF results as a StructType
  • dataType represents the return DataType
  • deterministic is a Boolean value that specifies whether this UDAF will return the same result for a given input
  • initialize allows you to initialize values of an aggregation buffer
  • update describes how you should update the internal buffer based on a given row
  • merge describes how two aggregation buffers should be merged
  • evaluate will generate the final result of the aggregation
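
The way these methods cooperate can be sketched in plain Scala before looking at the Spark class. The trait ToyAggregate and the helper run below are hypothetical stand-ins invented for this illustration; they are not Spark APIs, and unlike Spark's buffer they pass immutable values instead of mutating a buffer in place.

```scala
// A plain-Scala stand-in for the UDAF lifecycle; it does not use Spark's classes.
trait ToyAggregate[IN, BUF, OUT] {
  def initialize: BUF                      // starting value of a buffer
  def update(buffer: BUF, input: IN): BUF  // fold one input row into a buffer
  def merge(left: BUF, right: BUF): BUF    // combine two partial buffers
  def evaluate(buffer: BUF): OUT           // produce the final result
}

object ToyBoolAnd extends ToyAggregate[Boolean, Boolean, Boolean] {
  def initialize: Boolean = true
  def update(buffer: Boolean, input: Boolean): Boolean = buffer && input
  def merge(left: Boolean, right: Boolean): Boolean = left && right
  def evaluate(buffer: Boolean): Boolean = buffer
}

// Simulate partitions: fold each one with update, then combine the partials with merge.
def run[IN, BUF, OUT](agg: ToyAggregate[IN, BUF, OUT], partitions: Seq[Seq[IN]]): OUT = {
  val partials = partitions.map(_.foldLeft(agg.initialize)(agg.update))
  agg.evaluate(partials.foldLeft(agg.initialize)(agg.merge))
}

val allTrue = run(ToyBoolAnd, Seq(Seq(true, true), Seq(true)))
val oneFalse = run(ToyBoolAnd, Seq(Seq(true, true), Seq(false, true)))
```

The separation between update and merge is what allows the aggregation to run on each partition independently before the partial results are combined.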

The following example implements a BoolAnd, which will inform us whether all the rows (for a given column) are true; if they’re not, it will return false:

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class BoolAnd extends UserDefinedAggregateFunction {
    def inputSchema: StructType = StructType(
        StructField("value", BooleanType) :: Nil)
    def bufferSchema: StructType = StructType(
        StructField("result", BooleanType) :: Nil)
    def dataType: DataType = BooleanType
    def deterministic: Boolean = true
    // Start from true so that the first update decides the outcome.
    def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = true
    }
    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getAs[Boolean](0) && input.getAs[Boolean](0)
    }
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getAs[Boolean](0) && buffer2.getAs[Boolean](0)
    }
    def evaluate(buffer: Row): Any = {
        buffer(0)
    }
}

Now, we simply instantiate our class and/or register it as a function:

val ba = new BoolAnd
spark.udf.register("booland", ba)
import org.apache.spark.sql.functions._
spark.range(1)
    .selectExpr("explode(array(TRUE, TRUE, TRUE)) as t")
    .selectExpr("explode(array(TRUE, TRUE, TRUE)) as f", "t")
    .select(ba(col("t")), expr("booland(f)"))
    .show()
|      true|      true|

defined class BoolAnd
ba = BoolAnd@355368ef