From 27a727bf72bc7d07e2629543f5ac02fd36073803 Mon Sep 17 00:00:00 2001 From: Matthew Jubb Date: Mon, 5 Sep 2022 16:27:32 -0400 Subject: [PATCH 1/4] support for multiple grouping and aggregation columns via pivot method --- .../tech/tablesaw/aggregate/PivotTable.java | 128 +++++++++++++----- .../main/java/tech/tablesaw/api/Table.java | 44 ++++-- .../tablesaw/aggregate/PivotTableTest.java | 75 +++++++++- 3 files changed, 202 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java b/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java index f92eceea8..e430187ed 100644 --- a/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java +++ b/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java @@ -3,10 +3,15 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.LinkedList; import java.util.Map; +import java.util.stream.Collector; +import java.util.stream.Collectors; + import tech.tablesaw.api.CategoricalColumn; import tech.tablesaw.api.DoubleColumn; import tech.tablesaw.api.NumericColumn; +import tech.tablesaw.columns.Column; import tech.tablesaw.api.Table; import tech.tablesaw.table.TableSlice; import tech.tablesaw.table.TableSliceGroup; @@ -36,66 +41,123 @@ public class PivotTable { * subgroups * @return A new, pivoted table */ + + + public static Table pivot( + Table table, + CategoricalColumn groupingColumn, + CategoricalColumn pivotColumn, + NumericColumn aggregatedColumns, + AggregateFunction aggregateFunction) { + + return pivot(table, List.of(groupingColumn), pivotColumn, List.of(aggregatedColumns), aggregateFunction); + } + + public static Table pivot( Table table, - CategoricalColumn column1, - CategoricalColumn column2, - NumericColumn values, + List> groupingColumns, + CategoricalColumn pivotColumn, + List> aggregatedColumns, AggregateFunction aggregateFunction) { - TableSliceGroup tsg = table.splitOn(column1); + boolean multiAggregated = 
aggregatedColumns.size() > 1; - Table pivotTable = Table.create("Pivot: " + column1.name() + " x " + column2.name()); - pivotTable.addColumns(column1.type().create(column1.name())); + TableSliceGroup tsg = table.splitOn(groupingColumns.toArray(CategoricalColumn[]::new)); - List valueColumnNames = getValueColumnNames(table, column2); + List groupingColumnNames = groupingColumns.stream().map(_c -> _c.name()).collect(Collectors.toList()); - for (String colName : valueColumnNames) { - pivotTable.addColumns(DoubleColumn.create(colName)); - } + Table pivotTable = Table.create("Pivot: " + String.join(",", groupingColumnNames) + " x " + pivotColumn.name()); - int valueIndex = table.columnIndex(column2); - int keyIndex = table.columnIndex(column1); + pivotTable.addColumns(groupingColumns.stream().map(_c -> _c.type().create(_c.name())).toArray(Column[]::new)); - String key; + List valueColumnNames = getValueColumnNames(table, pivotColumn); + + if(multiAggregated){ + for (String colName : valueColumnNames) + for(NumericColumn aggColumn : aggregatedColumns) { + pivotTable.addColumns(DoubleColumn.create(colName + "." 
+ aggColumn.name())); + } + } + else{ + for (String colName : valueColumnNames) { + pivotTable.addColumns(DoubleColumn.create(colName)); + } + } for (TableSlice slice : tsg.getSlices()) { - key = String.valueOf(slice.get(0, keyIndex)); - pivotTable.column(0).appendCell(key); + + for (int i = 0; i < groupingColumns.size(); i++) { + String key = String.valueOf(slice.get(0, table.columnIndex(groupingColumns.get(i)))); + pivotTable.column(i).appendCell(key); + } Map valueMap = - getValueMap(column1, column2, values, valueIndex, slice, aggregateFunction); + getValueMap(groupingColumns, pivotColumn, aggregatedColumns, slice, aggregateFunction); for (String columnName : valueColumnNames) { - Double aDouble = valueMap.get(columnName); - NumericColumn pivotValueColumn = pivotTable.numberColumn(columnName); - if (aDouble == null) { - pivotValueColumn.appendMissing(); - } else { - pivotValueColumn.appendObj(aDouble); - } - } - } + for (NumericColumn aggregatedColumn: aggregatedColumns) { + + String appendedColumnName; + + if(multiAggregated){ + appendedColumnName = columnName + "." 
+ aggregatedColumn.name(); + } else { + appendedColumnName = columnName; + } + + NumericColumn pivotValueColumn = pivotTable.numberColumn(appendedColumnName); + + Double aDouble = valueMap.get(appendedColumnName); + + if (aDouble == null) { + pivotValueColumn.appendMissing(); + } else { + pivotValueColumn.appendObj(aDouble); + } + } + } + } + return pivotTable; } private static Map getValueMap( - CategoricalColumn column1, - CategoricalColumn column2, - NumericColumn values, - int valueIndex, + List> groupingColumns, + CategoricalColumn pivotColumn, + List> aggregatedColumns, TableSlice slice, AggregateFunction function) { Table temp = slice.asTable(); - Table summary = temp.summarize(values.name(), function).by(column1.name(), column2.name()); + List> allKeyColumns = new LinkedList<>(groupingColumns); + allKeyColumns.add(pivotColumn); + + List aggregatedColumnNames = aggregatedColumns.stream().map(NumericColumn::name).collect(Collectors.toList()); + + Table summary = temp.summarize(aggregatedColumnNames, function).by(allKeyColumns.stream().map(CategoricalColumn::name).toArray(String[]::new)); Map valueMap = new HashMap<>(); - NumericColumn nc = summary.numberColumn(summary.columnCount() - 1); - for (int i = 0; i < summary.rowCount(); i++) { - valueMap.put(String.valueOf(summary.get(i, 1)), nc.getDouble(i)); + + + if(aggregatedColumns.size() == 1){ + + NumericColumn nc = summary.numberColumn(summary.columnCount() - 1); + for (int i = 0; i < summary.rowCount(); i++) { + valueMap.put(String.valueOf(summary.get(i, groupingColumns.size())), nc.getDouble(i)); + } + } + else{ + for (int i = 0; i < summary.rowCount(); i++) { + for (int k = 0; k < aggregatedColumns.size(); k++) { + NumericColumn nc = summary.numberColumn(groupingColumns.size() + k + 1); + valueMap.put(String.valueOf(summary.get(i, groupingColumns.size())) + "." 
+ aggregatedColumns.get(k).name(), nc.getDouble(i)); + } + } + } + return valueMap; } diff --git a/core/src/main/java/tech/tablesaw/api/Table.java b/core/src/main/java/tech/tablesaw/api/Table.java index f099adc09..91e645c80 100644 --- a/core/src/main/java/tech/tablesaw/api/Table.java +++ b/core/src/main/java/tech/tablesaw/api/Table.java @@ -864,20 +864,42 @@ public Table dropWhere(Selection selection) { return newTable; } - /** + + public Table pivot( + List> groupingColumn, + CategoricalColumn pivotColumn, + List> aggregatedColumn, + AggregateFunction aggregateFunction) { + return PivotTable.pivot(this, groupingColumn, pivotColumn, aggregatedColumn, aggregateFunction); + } + + public Table pivot( + List groupingColumnNames, + String pivotColumnName, + List aggregatedColumnNames, + AggregateFunction aggregateFunction) { + return pivot( + groupingColumnNames.stream().map(this::categoricalColumn).collect(Collectors.toList()), + categoricalColumn(pivotColumnName), + aggregatedColumnNames.stream().map(this::numberColumn).collect(Collectors.toList()), + aggregateFunction); + } + + /** * Returns a pivot on this table, where: The first column contains unique values from the index * column1 There are n additional columns, one for each unique value in column2 The values in each * of the cells in these new columns are the result of applying the given AggregateFunction to the * data in column3, grouped by the values of column1 and column2 */ public Table pivot( - CategoricalColumn column1, - CategoricalColumn column2, - NumericColumn column3, + CategoricalColumn groupingColumn, + CategoricalColumn pivotColumn, + NumericColumn aggregatedColumn, AggregateFunction aggregateFunction) { - return PivotTable.pivot(this, column1, column2, column3, aggregateFunction); + return PivotTable.pivot(this, groupingColumn, pivotColumn, aggregatedColumn, aggregateFunction); } + /** * Returns a pivot on this table, where: The first column contains unique values from the index * column1 There 
are n additional columns, one for each unique value in column2 The values in each @@ -885,14 +907,14 @@ public Table pivot( * data in column3, grouped by the values of column1 and column2 */ public Table pivot( - String column1Name, - String column2Name, - String column3Name, + String groupingColumnName, + String pivotColumnName, + String aggregatedColumnName, AggregateFunction aggregateFunction) { return pivot( - categoricalColumn(column1Name), - categoricalColumn(column2Name), - numberColumn(column3Name), + categoricalColumn(groupingColumnName), + categoricalColumn(pivotColumnName), + numberColumn(aggregatedColumnName), aggregateFunction); } diff --git a/core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java b/core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java index c6be7e3e4..037d2676e 100644 --- a/core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java +++ b/core/src/test/java/tech/tablesaw/aggregate/PivotTableTest.java @@ -6,11 +6,16 @@ import org.junit.jupiter.api.Test; import tech.tablesaw.api.Table; import tech.tablesaw.io.csv.CsvReadOptions; +import java.util.List; public class PivotTableTest { + /** + * Illustrate usage of pivot function with a single grouping, pivot and aggregated columns + * @throws Exception + */ @Test - public void pivot() throws Exception { + public void pivotSingle() throws Exception { Table t = Table.read() .csv(CsvReadOptions.builder("../data/bush.csv").missingValueIndicator(":").build()); @@ -30,4 +35,72 @@ public void pivot() throws Exception { assertTrue(pivot.columnNames().contains("2004")); assertEquals(6, pivot.rowCount()); } + + + @Test + public void pivotMultipleGroupAndAggregate() throws Exception { + Table t = + Table.read() + .csv(CsvReadOptions.builder("../data/baseball.csv").build()); + + Table pivot = + t.pivot( + List.of("Team","League"), + "Year", + List.of("RS","RA","W"), + AggregateFunctions.mean); + + assertTrue(pivot.columnNames().contains("Team")); + 
assertTrue(pivot.columnNames().contains("League")); + assertTrue(pivot.columnNames().contains("2001.RS")); + assertTrue(pivot.columnNames().contains("2001.RA")); + assertTrue(pivot.columnNames().contains("2001.W")); + assertEquals(143, pivot.columnCount()); + assertEquals(40, pivot.rowCount()); + } + + @Test + public void pivotMultipleGroup() throws Exception { + Table t = + Table.read() + .csv(CsvReadOptions.builder("../data/baseball.csv").build()); + + Table pivot = + t.pivot( + List.of("Team","League"), + "Year", + List.of("RS"), + AggregateFunctions.mean); + + assertTrue(pivot.columnNames().contains("Team")); + assertTrue(pivot.columnNames().contains("League")); + assertTrue(pivot.columnNames().contains("2001")); + assertTrue(pivot.columnNames().contains("2002")); + assertTrue(pivot.columnNames().contains("2003")); + assertEquals(49, pivot.columnCount()); + assertEquals(40, pivot.rowCount()); + } + + @Test + public void pivotMultipleAggregate() throws Exception { + Table t = + Table.read() + .csv(CsvReadOptions.builder("../data/baseball.csv").build()); + + Table pivot = + t.pivot( + List.of("League"), + "Year", + List.of("RS","RA","W"), + AggregateFunctions.mean); + + assertTrue(!pivot.columnNames().contains("Team")); + assertTrue(pivot.columnNames().contains("League")); + assertTrue(pivot.columnNames().contains("2001.RS")); + assertTrue(pivot.columnNames().contains("2001.RA")); + assertTrue(pivot.columnNames().contains("2001.W")); + assertEquals(142, pivot.columnCount()); + assertEquals(2, pivot.rowCount()); + } + } From 433d595d0f1c0410210defcd28a3746b4500a6b8 Mon Sep 17 00:00:00 2001 From: Matthew Jubb Date: Mon, 5 Sep 2022 16:40:55 -0400 Subject: [PATCH 2/4] javadocs on new methods --- .../tech/tablesaw/aggregate/PivotTable.java | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java b/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java index 
e430187ed..ae86a30eb 100644 --- a/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java +++ b/core/src/main/java/tech/tablesaw/aggregate/PivotTable.java @@ -5,7 +5,6 @@ import java.util.List; import java.util.LinkedList; import java.util.Map; -import java.util.stream.Collector; import java.util.stream.Collectors; import tech.tablesaw.api.CategoricalColumn; @@ -29,31 +28,43 @@ public class PivotTable { /** * Returns a table that is a rotation of the given table pivoted around the key columns, and * filling the output cells using the values calculated by the {@code aggregateFunction} when - * applied to the {@code values column} grouping by the key columns + * applied to the {@code aggregatedColumn} grouping by the key columns + * + * Handles the case whereby there is a single groupingColumn and aggregatedColumn * * @param table The table that provides the data to be pivoted - * @param column1 A "key" categorical column from which the primary grouping is created. There + * @param groupingColumn A "key" categorical column from which the primary grouping is created. 
There * will be one on each row of the result - * @param column2 A second categorical column for which a subtotal is created; this produces n + * @param pivotColumn A second categorical column for which a subtotal is created; this produces n * columns on each row of the result - * @param values A numeric column that provides the values to be summarized + * @param aggregatedColumns A numeric column that provides the values to be summarized * @param aggregateFunction function that defines what operation is performed on the values in the * subgroups * @return A new, pivoted table */ - public static Table pivot( Table table, CategoricalColumn groupingColumn, CategoricalColumn pivotColumn, NumericColumn aggregatedColumns, AggregateFunction aggregateFunction) { - return pivot(table, List.of(groupingColumn), pivotColumn, List.of(aggregatedColumns), aggregateFunction); } - + /** + * Returns a table that is a rotation of the given table pivoted around the key columns, and + * filling the output cells using the values calculated by the {@code aggregateFunction} when + * applied to the {@code aggregatedColumns} grouping by the key columns + * + * Handles the case whereby there may be multiple groupingColumns and/or multiple aggregatedColumns + * @param table The table that provides the data to be pivoted + * @param groupingColumns The "key" categorical columns that define the row groups; each becomes a leading column of the result + * @param pivotColumn The categorical column whose values name the value columns of the result + * @param aggregatedColumns The numeric columns that provide the values to be summarized + * @param aggregateFunction function that defines what operation is performed on the values in the subgroups + * @return A new, pivoted table + */ public static Table pivot( Table table, List> groupingColumns, @@ -119,7 +130,7 @@ public static Table pivot( } } - + return pivotTable; } @@ -130,6 +141,7 @@ private static Map getValueMap( TableSlice slice, AggregateFunction function) { + boolean multiAggregated = aggregatedColumns.size() > 1; Table temp = slice.asTable(); List> allKeyColumns = new LinkedList<>(groupingColumns); allKeyColumns.add(pivotColumn); @@ -140,16 +152,7 @@ private static Map getValueMap( Map valueMap = new HashMap<>(); - - if(aggregatedColumns.size() == 1){ - - NumericColumn nc = summary.numberColumn(summary.columnCount() - 
1); - for (int i = 0; i < summary.rowCount(); i++) { - valueMap.put(String.valueOf(summary.get(i, groupingColumns.size())), nc.getDouble(i)); - } - - } - else{ + if(multiAggregated){ for (int i = 0; i < summary.rowCount(); i++) { for (int k = 0; k < aggregatedColumns.size(); k++) { NumericColumn nc = summary.numberColumn(groupingColumns.size() + k + 1); @@ -157,6 +160,12 @@ private static Map getValueMap( } } } + else{ + NumericColumn nc = summary.numberColumn(summary.columnCount() - 1); + for (int i = 0; i < summary.rowCount(); i++) { + valueMap.put(String.valueOf(summary.get(i, groupingColumns.size())), nc.getDouble(i)); + } + } return valueMap; } From 6564e821357dda974af6674bcbf12ac8a23133f3 Mon Sep 17 00:00:00 2001 From: Matthew Jubb Date: Mon, 5 Sep 2022 16:41:11 -0400 Subject: [PATCH 3/4] javadocs on new methods --- .../main/java/tech/tablesaw/api/Table.java | 48 +++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/tech/tablesaw/api/Table.java b/core/src/main/java/tech/tablesaw/api/Table.java index 91e645c80..d9ab35ecf 100644 --- a/core/src/main/java/tech/tablesaw/api/Table.java +++ b/core/src/main/java/tech/tablesaw/api/Table.java @@ -865,14 +865,44 @@ public Table dropWhere(Selection selection) { } + /** + * Returns a new table, where the first n columns are the groupingColumns. There are then p additional + * columns, one for each combination of a unique value in the pivot column and an aggregatedColumn. The + * values in each of the cells in these new columns are the result of applying the given AggregateFunction + * to the data in each aggregatedColumn, grouped by the values of the groupingColumns and pivotColumn. 
+ * + * If more than one aggregatedColumn is provided then the name of each is appended to each unique value of the pivot + * column in the format "{PivotColumnValue}.{AggregatedColumnName}" + * + * @param groupingColumns the categorical columns to group by; they become the first columns of the result + * @param pivotColumn the categorical column whose unique values name the value columns of the result + * @param aggregatedColumns the numeric columns to which the aggregate function is applied + * @param aggregateFunction the function applied to the values in each subgroup + * @return a new, pivoted table + */ public Table pivot( - List> groupingColumn, + List> groupingColumns, CategoricalColumn pivotColumn, - List> aggregatedColumn, + List> aggregatedColumns, AggregateFunction aggregateFunction) { - return PivotTable.pivot(this, groupingColumn, pivotColumn, aggregatedColumn, aggregateFunction); + return PivotTable.pivot(this, groupingColumns, pivotColumn, aggregatedColumns, aggregateFunction); } + /** + * Returns a new table, where the first n columns are the groupingColumns. There are then p additional + * columns, one for each combination of a unique value in the pivot column and an aggregatedColumn. The + * values in each of the cells in these new columns are the result of applying the given AggregateFunction + * to the data in each aggregatedColumn, grouped by the values of the groupingColumns and pivotColumn. 
+ * + * If more than one aggregatedColumn is provided then the name of each is appended to each unique value of the pivot + * column in the format "{PivotColumnValue}.{AggregatedColumnName}" + * + * @param groupingColumnNames the names of the categorical columns to group by + * @param pivotColumnName the name of the categorical column whose unique values name the value columns + * @param aggregatedColumnNames the names of the numeric columns to which the aggregate function is applied + * @param aggregateFunction the function applied to the values in each subgroup + * @return a new, pivoted table + */ public Table pivot( List groupingColumnNames, String pivotColumnName, @@ -887,9 +917,9 @@ public Table pivot( /** * Returns a pivot on this table, where: The first column contains unique values from the index - * column1 There are n additional columns, one for each unique value in column2 The values in each - * of the cells in these new columns are the result of applying the given AggregateFunction to the - * data in column3, grouped by the values of column1 and column2 + * groupingColumn. There are n additional columns, one for each unique value in the pivotColumn. The + * values in each of the cells in these new columns are the result of applying the given AggregateFunction + * to the data in the aggregatedColumn, grouped by the values of groupingColumn and pivotColumn */ public Table pivot( CategoricalColumn groupingColumn, @@ -902,9 +932,9 @@ public Table pivot( /** * Returns a pivot on this table, where: The first column contains unique values from the index - * column1 There are n additional columns, one for each unique value in column2 The values in each - * of the cells in these new columns are the result of applying the given AggregateFunction to the - * data in column3, grouped by the values of column1 and column2 + * groupingColumn. There are n additional columns, one for each unique value in the pivotColumn. The + * values in each of the cells in these new columns are the result of applying the given AggregateFunction + * to the data in the aggregatedColumn, grouped by the values of groupingColumn and pivotColumn */ public Table pivot( String groupingColumnName, From dd4edc4f8ba4dcc4fd30a5657aee8839184c366f Mon Sep 17 00:00:00 2001 From: Matt 
Jubb Date: Sat, 4 Oct 2025 21:46:57 -0400 Subject: [PATCH 4/4] use existing static import --- core/src/main/java/tech/tablesaw/api/Table.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/tech/tablesaw/api/Table.java b/core/src/main/java/tech/tablesaw/api/Table.java index d9ab35ecf..148db8532 100644 --- a/core/src/main/java/tech/tablesaw/api/Table.java +++ b/core/src/main/java/tech/tablesaw/api/Table.java @@ -909,9 +909,9 @@ public Table pivot( List aggregatedColumnNames, AggregateFunction aggregateFunction) { return pivot( - groupingColumnNames.stream().map(this::categoricalColumn).collect(Collectors.toList()), + groupingColumnNames.stream().map(this::categoricalColumn).collect(toList()), categoricalColumn(pivotColumnName), - aggregatedColumnNames.stream().map(this::numberColumn).collect(Collectors.toList()), + aggregatedColumnNames.stream().map(this::numberColumn).collect(toList()), aggregateFunction); }