Friday, 28 July 2017

Java 8: Use Lambda to create a Pivot-like result

Let's assume we have a Java POJO Person with fields such as year, teamId and sum, and we want to create a pivot that shows the total sum for every year and teamId, like this:
----------------------------------------
         |     Test|  Develop|   Deploy|
----------------------------------------
     2015|     1450|     5800|      870|
----------------------------------------
     2016|     2500|    10000|        0|
----------------------------------------
     2017|        0|     3500|      250|
----------------------------------------

Define a simple CSV file which keeps the data:
year,teamId,sum
2015,Test,480
2015,Develop,800
2015,Deploy,300
.....
To create our pivot, we first need to read the data into an appropriate structure, Map<YearTeam, List<Person>>, where YearTeam is another simple POJO containing year and teamId (it must implement equals and hashCode so it can serve as a map key). The Collectors.groupingBy method does most of the work here.
Pattern pattern = Pattern.compile(",");
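// Declared without an initializer and assigned exactly once, so it stays
// effectively final and can be used inside the lambdas further down.
Map<YearTeam, List<Person>> map;
try (BufferedReader in = new BufferedReader(new FileReader(fileName))) {
   map = in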
    .lines()
    .skip(1)
    .map(line -> {
       String[] arr = pattern.split(line);
       return new Person(
                  Integer.parseInt(arr[0]), 
                  arr[1],
                  Integer.parseInt(arr[2]));
       })
    .collect(Collectors.groupingBy(x -> new YearTeam(x.getYear(), x.getTeamId())));
}
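The snippet above uses two classes that are not shown. A minimal sketch of what they might look like follows; the field names and types are assumptions inferred from the getters and the constructor call used above. The part that matters is that YearTeam overrides equals and hashCode: without them, groupingBy would put every row into its own group, and a later lookup with a freshly created YearTeam would find nothing.

```java
import java.util.Objects;

// Hypothetical POJO for one CSV row; the shape is inferred from the getters used above.
class Person {
    private final int year;
    private final String teamId;
    private final int sum;

    public Person(int year, String teamId, int sum) {
        this.year = year;
        this.teamId = teamId;
        this.sum = sum;
    }

    public int getYear() { return year; }
    public String getTeamId() { return teamId; }
    public int getSum() { return sum; }
}

// Hypothetical grouping key: two instances with the same year and teamId are equal.
class YearTeam {
    private final int year;
    private final String teamId;

    public YearTeam(int year, String teamId) {
        this.year = year;
        this.teamId = teamId;
    }

    public int getYear() { return year; }
    public String getTeamId() { return teamId; }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof YearTeam)) return false;
        YearTeam other = (YearTeam) o;
        return year == other.year && Objects.equals(teamId, other.teamId);
    }

    @Override
    public int hashCode() {
        // Must be consistent with equals so HashMap lookups land in the same bucket.
        return Objects.hash(year, teamId);
    }
}
```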
With this map in hand, we want to print the teams header first and then one row per year with the total sums. The first step is to obtain the (sorted) set of teams:
Set<String> teams = map
        .keySet()
        .stream()
        .map(x -> x.getTeamId())
        .collect(Collectors.toCollection(TreeSet::new));

We can define some useful parameters for pretty printing:
public static final int COLUMN_WIDTH = 9;
public static final String COLUMN_DELIMITER = "|";
public static final String LINE_DELIMITER = "-";
And some print utility methods:
private static void printLineDelimiter(int columns) {
    System.out.println();
    System.out.println(String.join("", Collections.nCopies(
               columns*(COLUMN_WIDTH +1), LINE_DELIMITER)));
}

private static void printTeamsHeader(Set<String> teams) {
    System.out.printf("%" + (COLUMN_WIDTH + 1) + "s", COLUMN_DELIMITER);
    teams.forEach(t -> System.out.printf("%" + (COLUMN_WIDTH + 1) + "s",
                                t + COLUMN_DELIMITER));
}

private static void printYear(int year) {
    System.out.printf("%" + (COLUMN_WIDTH +1) + "s" , year + COLUMN_DELIMITER);
}

private static void printTotal(long total) {
    System.out.printf("%" + COLUMN_WIDTH + "s", total);
}
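All of these helpers rely on the width part of a format specifier: "%10s" pads the value with spaces on the left until it is 10 characters wide, which is what keeps the columns aligned. Here is a small sketch of that behaviour. The cell helper is hypothetical (it is not one of the utilities above) and returns the string instead of printing it, so the padding is easy to inspect.

```java
class FormatSketch {
    static final int COLUMN_WIDTH = 9;
    static final String COLUMN_DELIMITER = "|";

    // Hypothetical helper: right-aligns value + delimiter in a cell that is
    // COLUMN_WIDTH + 1 characters wide, the same way printYear does.
    static String cell(Object value) {
        return String.format("%" + (COLUMN_WIDTH + 1) + "s", value + COLUMN_DELIMITER);
    }

    public static void main(String[] args) {
        // Both cells are padded on the left to the same width.
        System.out.println(cell("Test") + cell(2015));
    }
}
```

Note that the width is only a minimum: a team name longer than COLUMN_WIDTH characters is not truncated and would push the rest of the row to the right.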
With these in place, we can start by printing the teams header:
int columns = teams.size()+1;

printLineDelimiter(columns);
printTeamsHeader(teams);
printLineDelimiter(columns);
Next we need to print the actual data. To compute the total sum we will use Collectors.summingLong:
Set<Integer> years = map
        .keySet()
        .stream()
        .map(x -> x.getYear())
        .collect(Collectors.toCollection(TreeSet::new)); // sorted, so rows print in ascending year order

years
    .stream()
    .forEach(y -> {
        printYear(y);
        teams.stream().forEach(t -> {
            YearTeam yt = new YearTeam(y, t);
            List<Person> persons = map.get(yt);
            long total = persons == null ? 0 :
                    persons.stream()
                           .collect(Collectors.summingLong(Person::getSum));
            printTotal(total);
            System.out.print(COLUMN_DELIMITER);
        });
        printLineDelimiter(columns);
    });
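As a variation on the approach above, the totals can also be computed in a single pass by nesting two groupingBy collectors, which removes the need for the YearTeam key. This is only a sketch under a few assumptions: PivotSketch and its pivot method are hypothetical names, the nested Person class is a stand-in for the one used in this post, and the last input row is made up just to show an empty cell.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

class PivotSketch {

    // Minimal stand-in for the post's Person class (assumed shape).
    static class Person {
        private final int year;
        private final String teamId;
        private final int sum;

        Person(int year, String teamId, int sum) {
            this.year = year;
            this.teamId = teamId;
            this.sum = sum;
        }

        int getYear() { return year; }
        String getTeamId() { return teamId; }
        int getSum() { return sum; }
    }

    // Groups by year (kept sorted by the TreeMap), then by team, summing as it goes.
    static Map<Integer, Map<String, Long>> pivot(List<Person> persons) {
        return persons.stream()
                .collect(Collectors.groupingBy(
                        Person::getYear,
                        TreeMap::new,
                        Collectors.groupingBy(
                                Person::getTeamId,
                                Collectors.summingLong(Person::getSum))));
    }

    public static void main(String[] args) {
        List<Person> persons = Arrays.asList(
                new Person(2015, "Test", 480),      // rows from the CSV sample
                new Person(2015, "Develop", 800),
                new Person(2015, "Deploy", 300),
                new Person(2016, "Develop", 100));  // made-up row, only to show an empty cell

        Map<Integer, Map<String, Long>> totals = pivot(persons);
        // A team with no rows for a given year has no entry, so fall back to 0.
        System.out.println(totals.get(2016).getOrDefault("Test", 0L));
    }
}
```

The printing loop would then iterate over totals.entrySet() for the rows and call getOrDefault(team, 0L) for each column, instead of building a YearTeam for every cell.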