Apache Spark: mapPartitions implementation in Java

Chandra Prakash
5 min read · Feb 27, 2021

In this blog, we will look at the use case of mapPartitions and its implementation in Spark's Java API. Before going forward, please check out the previous blog linked below, as this one is an extension of it —

Definition of mapPartitions

As per the Spark documentation, mapPartitions(func) is similar to map, but runs separately on each partition (block) of the RDD. When running on an RDD of type T, func must therefore be of type Iterator<T> => Iterator<U>: it accepts a single partition (as an iterator over elements of type T) and returns an iterator over elements of type U. T and U can be any data types and they do not have to be the same. It transforms an RDD[T] into an RDD[U].

mapPartitions() is a very powerful, distributed and efficient Spark mapper transformation, which processes one partition (instead of each RDD element) at a time and implements the Summarization design pattern — summarize each partition of a source RDD into a single element of the target RDD. The goal of this transformation is to process one partition at a time (although many partitions can be processed independently and concurrently), iterate over all of the partition's elements and summarize the result in a compact data structure such as a dictionary, a list of elements, a tuple, or a list of tuples. Therefore, there is a one-to-one mapping between the partitions of the source RDD and the partitions of the target RDD.

From a Java code perspective, this API can be used as —

rdd.mapPartitions(new FlatMapFunction<Iterator<Integer>, Integer>()
{
    @Override
    public Iterator<Integer> call(Iterator<Integer> integerIterator) throws Exception
    {
        // iterate over the elements of this partition and add 1 to each of them
        List<Integer> list = new ArrayList<>();
        while (integerIterator.hasNext())
        {
            list.add(integerIterator.next() + 1);
        }
        // return an iterator over the transformed elements of this partition
        return list.iterator();
    }
});

OR

public static class MapPartitions implements FlatMapFunction<Iterator<Integer>, Integer>
{
    @Override
    public Iterator<Integer> call(Iterator<Integer> integerIterator) throws Exception
    {
        // iterate over the elements of this partition and add 1 to each of them
        List<Integer> list = new ArrayList<>();
        while (integerIterator.hasNext())
        {
            list.add(integerIterator.next() + 1);
        }
        // return an iterator over the transformed elements of this partition
        return list.iterator();
    }
}
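With the static class form, the call site becomes a one-liner (a small usage sketch, assuming rdd is a JavaRDD<Integer> as in the first snippet):

// apply the reusable partition-level function to every partition of rdd
JavaRDD<Integer> incremented = rdd.mapPartitions(new MapPartitions());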

Benefit of using mapPartitions

mapPartitions helps us to achieve — Efficient Local Aggregation

Since mapPartitions() works at the partition level, it gives the user the opportunity to perform filtering and aggregation at a partition level. This local aggregation at the partition level greatly reduces the amount of shuffled data. With mapPartitions(), we reduce a partition into a small, contained data structure. It is evident that a reduction in the sort and shuffle of data results in improved efficiency and reliability of reduce operations. To achieve local aggregation, one needs to instantiate a HashMap<> for storing the aggregated values/objects against the corresponding grouping key(s). This HashMap<> is then repeatedly updated while iterating over the data collection of the underlying partition. Finally, an iterator over the aggregated values/objects (optionally along with the associated grouping key(s)) contained in the map is returned.
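As an illustration, here is a minimal sketch of this local aggregation pattern. It assumes a hypothetical JavaRDD<String> named wordsRDD where each element is a single word, and the usual imports (java.util.*, scala.Tuple2, org.apache.spark.api.java.JavaPairRDD, org.apache.spark.api.java.function.PairFlatMapFunction). One HashMap<> is built per partition, and only the small partial counts go into the shuffle:

JavaPairRDD<String, Long> partialCounts = wordsRDD.mapPartitionsToPair(
    new PairFlatMapFunction<Iterator<String>, String, Long>()
    {
        @Override
        public Iterator<Tuple2<String, Long>> call(Iterator<String> words)
        {
            // one HashMap per partition holds the local aggregation
            Map<String, Long> localCounts = new HashMap<>();
            while (words.hasNext())
            {
                localCounts.merge(words.next(), 1L, Long::sum);
            }
            // emit only a few (word, count) pairs for this partition
            List<Tuple2<String, Long>> out = new ArrayList<>();
            for (Map.Entry<String, Long> e : localCounts.entrySet())
            {
                out.add(new Tuple2<>(e.getKey(), e.getValue()));
            }
            return out.iterator();
        }
    });

// only the small partial maps get shuffled; the final merge per key is cheap
JavaPairRDD<String, Long> totals = partialCounts.reduceByKey((a, b) -> a + b);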

Let's talk about two things now —

  1. When should we use mapPartitions?
  2. And how can Spark's mapPartitions() be used to implement the Summarization design pattern?

The answers to the above two questions are below.

Spark's mapPartitions() transformation should be used when you want to extract minimal information (such as the minimum and maximum of a set of numbers, the top-10 URLs, or the count of DNA bases, which we will see shortly in this blog) from each partition, where each partition holds a large dataset. Suppose you want to find the minimum and maximum of all numbers in your input; using map() can be pretty inefficient, since you would generate a huge number of intermediate (key, value) pairs when the goal is just to find two numbers. Therefore, mapPartitions() is the right approach to solve the DNA Base Count problem without any scalability issues.
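For instance, here is a minimal sketch of the min/max case, assuming a hypothetical JavaRDD<Integer> named numbersRDD (plus scala.Tuple2 and java.util.Collections on the classpath): each partition is summarized into a single (min, max) pair, and the one-pair-per-partition results are then reduced into the global minimum and maximum.

JavaRDD<Tuple2<Integer, Integer>> minMaxPerPartition = numbersRDD.mapPartitions(
    new FlatMapFunction<Iterator<Integer>, Tuple2<Integer, Integer>>()
    {
        @Override
        public Iterator<Tuple2<Integer, Integer>> call(Iterator<Integer> numbers)
        {
            // an empty partition contributes nothing
            if (!numbers.hasNext())
            {
                return Collections.emptyIterator();
            }
            int min = Integer.MAX_VALUE;
            int max = Integer.MIN_VALUE;
            while (numbers.hasNext())
            {
                int n = numbers.next();
                min = Math.min(min, n);
                max = Math.max(max, n);
            }
            // summarize the whole partition as a single (min, max) pair
            return Collections.singletonList(new Tuple2<>(min, max)).iterator();
        }
    });

// combine the per-partition pairs into the global minimum and maximum
Tuple2<Integer, Integer> globalMinMax = minMaxPerPartition.reduce(
    (a, b) -> new Tuple2<>(Math.min(a._1(), b._1()), Math.max(a._2(), b._2())));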

The Summarization design pattern is all about grouping similar data together and then performing an operation such as calculating a statistic, building an index, or simply counting. In such scenarios, mapPartitions() can help us. We will see in this blog how we achieve this in our DNA Base Count problem.

Let's deep dive into the Java code for the DNA Base Count problem. Please refer to my blog linked at the beginning of this post to understand the problem statement and the dataset used —

package com.spark.rdd.tutorial;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;

import java.util.*;

public class DNABaseCount
{
    public static void main(String[] args)
    {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("app");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        JavaRDD<String> inputRDD = jsc.textFile("C:\\Users\\Lenovo\\Desktop\\log_.txt");

        JavaRDD<Map<Character, Long>> mapJavaRDD = inputRDD.mapPartitions(new FlatMapFunction<Iterator<String>, Map<Character, Long>>() {
            @Override
            public Iterator<Map<Character, Long>> call(Iterator<String> iterator) throws Exception {
                // one HashMap per partition: DNA letter -> frequency
                Map<Character, Long> baseCounts = new HashMap<Character, Long>();
                while (iterator.hasNext()) {
                    String value = iterator.next();
                    if (value.startsWith(">")) {
                        // skip FASTA description lines
                    } else {
                        String str = value.toUpperCase();
                        for (int i = 0; i < str.length(); ++i) {
                            char c = str.charAt(i);
                            Long count = baseCounts.get(c);
                            if (count == null) {
                                baseCounts.put(c, 1L);
                            } else {
                                baseCounts.put(c, count + 1L);
                            }
                        }
                    }
                }
                // summarize the whole partition as a single map element
                return Collections.singletonList(baseCounts).iterator();
            }
        });

        // one map per partition is collected back to the driver
        List<Map<Character, Long>> collect = mapJavaRDD.collect();
    }
}

Run the above program in debug mode in IntelliJ and inspect the collect variable: List<Map<Character, Long>> collect = mapJavaRDD.collect();
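The collect variable holds one map of base counts per partition. To get the overall counts, a small merge on the driver is enough; a minimal sketch reusing that variable:

// fold the per-partition maps into one overall Map<Character, Long>
Map<Character, Long> finalCounts = new HashMap<>();
for (Map<Character, Long> partitionCounts : collect)
{
    for (Map.Entry<Character, Long> e : partitionCounts.entrySet())
    {
        finalCounts.merge(e.getKey(), e.getValue(), Long::sum);
    }
}
System.out.println(finalCounts);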

Also, I have tested this code against a small dataset, but the above solution will scale to billions of records as well, since for an efficient transformation we are using a single HashMap<> per partition to aggregate DNA letters and their associated frequencies. For example, if our dataset has a total of 5 billion records and N = 50,000 partitions, then each partition will have about 100,000 FASTA records (5 billion = 50,000 * 100,000). Therefore, each func() will process (by means of iteration) about 100,000 FASTA records. Each partition will emit at most 6 (key, value) pairs, where the keys will be in {"a", "t", "c", "g", "n", "z"}, with n as the key for non-DNA letters and z as the key for the number of consumed DNA strings/sequences. We can also see that mapPartitions() can be considered a "Many-to-1" transformation, because it maps each partition of the source RDD (comprised of many elements — each partition may have thousands or millions of elements) into a single element of the target RDD.
