Rename pivoted and aggregated column in PySpark Dataframe
Question:
With a dataframe as follows:
from pyspark.sql.functions import avg, first
rdd = sc.parallelize(
    [
        (0, "A", 223, "201603", "PORT"),
        (0, "A", 22, "201602", "PORT"),
        (0, "A", 422, "201601", "DOCK"),
        (1, "B", 3213, "201602", "DOCK"),
        (1, "B", 3213, "201601", "PORT"),
        (2, "C", 2321, "201601", "DOCK"),
    ]
)
df_data = sqlContext.createDataFrame(rdd, ["id", "type", "cost", "date", "ship"])
df_data.show()
I do a pivot on it,
df_data.groupby(df_data.id, df_data.type).pivot("date").agg(avg("cost"), first("ship")).show()
+---+----+----------------+--------------------+----------------+--------------------+----------------+--------------------+
| id|type|201601_avg(cost)|201601_first(ship)()|201602_avg(cost)|201602_first(ship)()|201603_avg(cost)|201603_first(ship)()|
+---+----+----------------+--------------------+----------------+--------------------+----------------+--------------------+
| 2| C| 2321.0| DOCK| null| null| null| null|
| 0| A| 422.0| DOCK| 22.0| PORT| 223.0| PORT|
| 1| B| 3213.0| PORT| 3213.0| DOCK| null| null|
+---+----+----------------+--------------------+----------------+--------------------+----------------+--------------------+
But I get these really complicated names for the columns. Applying alias on the aggregation usually works, but because of the pivot the names are even worse in this case:
+---+----+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+
| id|type|201601_(avg(cost),mode=Complete,isDistinct=false) AS cost#1619|201601_(first(ship)(),mode=Complete,isDistinct=false) AS ship#1620|201602_(avg(cost),mode=Complete,isDistinct=false) AS cost#1619|201602_(first(ship)(),mode=Complete,isDistinct=false) AS ship#1620|201603_(avg(cost),mode=Complete,isDistinct=false) AS cost#1619|201603_(first(ship)(),mode=Complete,isDistinct=false) AS ship#1620|
+---+----+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+
| 2| C| 2321.0| DOCK| null| null| null| null|
| 0| A| 422.0| DOCK| 22.0| PORT| 223.0| PORT|
| 1| B| 3213.0| PORT| 3213.0| DOCK| null| null|
+---+----+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+
Is there a way to rename the column names on the fly on the pivot and aggregation?
Answers:
A simple regular expression should do the trick:
import re
def clean_names(df):
    # Matches <pivot value>_<function>(<column>), optionally followed by "()"
    p = re.compile(r"^(\w+?)_([a-z]+)\((\w+)\)(?:\(\))?")
    return df.toDF(*[p.sub(r"\1_\3", c) for c in df.columns])
pivoted = df_data.groupby(...).pivot(...).agg(...)
clean_names(pivoted).printSchema()
## root
## |-- id: long (nullable = true)
## |-- type: string (nullable = true)
## |-- 201601_cost: double (nullable = true)
## |-- 201601_ship: string (nullable = true)
## |-- 201602_cost: double (nullable = true)
## |-- 201602_ship: string (nullable = true)
## |-- 201603_cost: double (nullable = true)
## |-- 201603_ship: string (nullable = true)
If you want to preserve the function name, change the substitution pattern to, for example, \1_\2_\3.
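To see what the pattern does without a Spark session, here is a minimal sketch that applies it to plain strings shaped like the column names in the question. The names PIVOT_COL, clean_name and keep_function are made up for this illustration and are not part of the answer or of PySpark.

```python
import re

# Same pattern as in the answer: <pivot value>_<function>(<column>),
# optionally followed by a trailing "()".
PIVOT_COL = re.compile(r"^(\w+?)_([a-z]+)\((\w+)\)(?:\(\))?")


def clean_name(name, keep_function=False):
    """Return the cleaned name; names that do not match are returned unchanged."""
    # \1 = pivot value, \2 = aggregate function, \3 = aggregated column
    template = r"\1_\2_\3" if keep_function else r"\1_\3"
    return PIVOT_COL.sub(template, name)


# Column names copied from the pivot output in the question.
columns = ["id", "type", "201601_avg(cost)", "201601_first(ship)()"]

cleaned = [clean_name(c) for c in columns]
with_function = [clean_name(c, keep_function=True) for c in columns]

print(cleaned)
print(with_function)
```

The grouping columns id and type pass through untouched because they contain no underscore followed by a function call, so the pattern never matches them.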
You can alias the aggregations directly:
pivoted = (
    df_data
    .groupby(df_data.id, df_data.type)
    .pivot("date")
    .agg(
        avg("cost").alias("cost"),
        first("ship").alias("ship"),
    )
)
pivoted.printSchema()
##root
##|-- id: long (nullable = true)
##|-- type: string (nullable = true)
##|-- 201601_cost: double (nullable = true)
##|-- 201601_ship: string (nullable = true)
##|-- 201602_cost: double (nullable = true)
##|-- 201602_ship: string (nullable = true)
##|-- 201603_cost: double (nullable = true)
##|-- 201603_ship: string (nullable = true)
A simple approach is to apply alias after the aggregate function.
I start with the df_data Spark DataFrame you created.
df_data.groupby(df_data.id, df_data.type).pivot("date").agg(avg("cost").alias("avg_cost"), first("ship").alias("first_ship")).show()
+---+----+---------------+-----------------+---------------+-----------------+---------------+-----------------+
| id|type|201601_avg_cost|201601_first_ship|201602_avg_cost|201602_first_ship|201603_avg_cost|201603_first_ship|
+---+----+---------------+-----------------+---------------+-----------------+---------------+-----------------+
| 1| B| 3213.0| PORT| 3213.0| DOCK| null| null|
| 2| C| 2321.0| DOCK| null| null| null| null|
| 0| A| 422.0| DOCK| 22.0| PORT| 223.0| PORT|
+---+----+---------------+-----------------+---------------+-----------------+---------------+-----------------+
The column names take the form “original_column_name_aliased_column_name”, with the two parts joined by an underscore “_”. In your case, original_column_name is the pivot value 201601 and aliased_column_name is avg_cost, so the column name is 201601_avg_cost.
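The naming rule can be sketched in plain Python, which is handy when you want to build the list of resulting column names up front (for a select, say). The helper expected_pivot_columns is hypothetical, written only for this illustration, and the ordering it produces simply mirrors the output table shown above.

```python
from itertools import product


def expected_pivot_columns(group_cols, pivot_values, aliases):
    """Grouping columns first, then every pivot value paired with every alias."""
    pivoted = [f"{value}_{alias}" for value, alias in product(pivot_values, aliases)]
    return list(group_cols) + pivoted


names = expected_pivot_columns(
    ["id", "type"],
    ["201601", "201602", "201603"],
    ["avg_cost", "first_ship"],
)
print(names)
```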
Wrote an easy and fast function to do this. Enjoy! 🙂
# This function renames the ugly default column names of pivot tables
def rename_pivot_cols(rename_df, remove_agg):
    """Change Spark pivot table's default ugly column names with ease.

    Option 1: remove_agg = True: `2_sum(sum_amt)` --> `sum_amt_2`.
    Option 2: remove_agg = False: `2_sum(sum_amt)` --> `sum_sum_amt_2`
    """
    for column in rename_df.columns:
        start_index = column.find('(')
        end_index = column.find(')')
        # Leave columns that are not `<pivot>_<agg>(<col>)` (e.g. grouping keys) untouched
        if start_index <= 0 or end_index <= start_index or '_' not in column[:start_index]:
            continue
        # Split at the first underscore so pivot values longer than one character also work
        pivot_value, _, agg_name = column[:start_index].partition('_')
        inner = column[start_index + 1:end_index]
        new_column = inner if remove_agg else agg_name + '_' + inner
        rename_df = rename_df.withColumnRenamed(column, new_column + '_' + pivot_value)
    return rename_df
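You can also change the values of the pivot column before the pivot, since those values end up in the column names (here F is pyspark.sql.functions and ship is the column being pivoted):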
.withColumn("ship", F.concat(F.lit("ship_"), "ship"))