Get the column index where the value is maximum per each row in pyspark
Question:
I have a pyspark dataframe such as:
ID | Col1 | Col2 | Col3 | … | ColN |
---|---|---|---|---|---|
1 | 10 | 5 | 21 | … | -9 |
2 | 87 | 1 | 1 | … | 1 |
3 | 1 | 95 | 1 | … | 1 |
How one could create a pyspark dataframe column MAX
that represents the index column where the value is maximum per row such as:
ID | Col1 | Col2 | Col3 | … | ColN | MAX |
---|---|---|---|---|---|---|
1 | 10 | 5 | 21 | … | -9 | 3 |
2 | 87 | 1 | 1 | … | 1 | 1 |
3 | 1 | 95 | 1 | … | 1 | 2 |
Answers:
Create a column with the max in each row
List columns in which the max value can be found
Eliminate the NaNs in the list
Code below
import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql.functions import*

# Row-wise max over the value columns, then the list of column names attaining it:
#   1. "max"    : greatest value across Col1..ColN (ID excluded).
#   2. "maxcol" : for each value column, its name where it equals the row max, else null.
#   3. filter   : drop the nulls so only the max column names remain.
df = (
    df.withColumn(
        "max",
        F.greatest(*[F.col(c) for c in df.columns[1:]]),  # find the max in each row
    ).withColumn(
        "maxcol",
        # df.columns[1:] (not df.columns): an ID that happens to equal the row
        # max must never be reported as a max column.
        array(*[when(col(c) == col("max"), lit(c)) for c in df.columns[1:]]),
    ).withColumn(
        "maxcol",
        expr("filter(maxcol, x -> x is not null)"),  # filter out the nulls
    )
)
# show() returns None, so it must NOT be assigned back to df (the original
# `df = (....show())` left df = None and broke every later use of df).
df.show()
+---+----+----+----+----+---+------+
| ID|Col1|Col2|Col3|ColN|max|maxcol|
+---+----+----+----+----+---+------+
| 1| 10| 5| 21| -9| 21|[Col3]|
| 2| 87| 1| 1| 1| 87|[Col1]|
| 3| 1| 95| 1| 1| 95|[Col2]|
+---+----+----+----+----+---+------+
You could also use pandas_udf though I am not sure of the efficacy
from pyspark.sql.functions import pandas_udf
import pandas as pd
from pyspark.sql.types import *
def max_col(a: pd.DataFrame) -> pd.DataFrame:
    """Append a 'maxcol' column listing the value-column names holding each row's max.

    The first column (ID) is excluded both from the max computation and from the
    candidate columns, so an ID value that happens to equal the row max is never
    reported (the original ``a.isin(...)`` compared ALL columns, including ID).

    Parameters:
        a: frame whose first column is an identifier and whose remaining
           columns are the numeric values to compare.

    Returns:
        A copy of ``a`` with an extra 'maxcol' column: an array of the
        value-column names equal to that row's maximum.
    """
    values = a.iloc[:, 1:]                          # value columns only (skip ID)
    hits = values.eq(values.max(axis=1), axis=0)    # True where cell == its row max
    return a.assign(maxcol=hits.agg(lambda row: row.index[row].values, axis=1))
# Output schema for applyInPandas: the original long-typed columns plus the
# array-of-strings 'maxcol' column that max_col appends.
_long_fields = [
    StructField(name, LongType(), True)
    for name in ('ID', 'Col1', 'Col2', 'Col3', 'ColN')
]
schema = StructType(
    _long_fields + [StructField('maxcol', ArrayType(StringType(), True), False)]
)
df.groupby('ID').applyInPandas(max_col, schema).show()
df.show(5)
Col1 | Col2 | Col3 | Col4 | Col5 |
---|---|---|---|---|
3 | 5 | 7 | 3 | 6 |
7 | 7 | 5 | 8 | 8 |
2 | 2 | 3 | 7 | 7 |
4 | 4 | 8 | 6 | 2 |
7 | 4 | 9 | 2 | 2 |
only showing top 5 rows
# 0-based position of the per-row maximum: gather the row into a temporary
# array column, locate the max with a plain Python UDF, then drop the helper.
max_index = F.udf(lambda row: row.index(max(row)))
df = (
    df.withColumn("data", F.array(df.columns))
      .withColumn("max_index", max_index("data"))
      .drop("data")
)
df.show(5)
Col1 | Col2 | Col3 | Col4 | Col5 | max_index |
---|---|---|---|---|---|
3 | 5 | 7 | 3 | 6 | 2 |
7 | 7 | 5 | 8 | 8 | 3 |
2 | 2 | 3 | 7 | 7 | 3 |
4 | 4 | 8 | 6 | 2 | 2 |
7 | 4 | 9 | 2 | 2 | 2 |
only showing top 5 rows
I have a pyspark dataframe such as:
ID | Col1 | Col2 | Col3 | … | ColN |
---|---|---|---|---|---|
1 | 10 | 5 | 21 | … | -9 |
2 | 87 | 1 | 1 | … | 1 |
3 | 1 | 95 | 1 | … | 1 |
How one could create a pyspark dataframe column MAX
that represents the index column where the value is maximum per row such as:
ID | Col1 | Col2 | Col3 | … | ColN | MAX |
---|---|---|---|---|---|---|
1 | 10 | 5 | 21 | … | -9 | 3 |
2 | 87 | 1 | 1 | … | 1 | 1 |
3 | 1 | 95 | 1 | … | 1 | 2 |
Create a column with the max in each row
List columns in which the max value can be found
Eliminate the NaNs in the list
Code below
import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql.functions import*

# Row-wise max over the value columns, then the list of column names attaining it:
#   1. "max"    : greatest value across Col1..ColN (ID excluded).
#   2. "maxcol" : for each value column, its name where it equals the row max, else null.
#   3. filter   : drop the nulls so only the max column names remain.
df = (
    df.withColumn(
        "max",
        F.greatest(*[F.col(c) for c in df.columns[1:]]),  # find the max in each row
    ).withColumn(
        "maxcol",
        # df.columns[1:] (not df.columns): an ID that happens to equal the row
        # max must never be reported as a max column.
        array(*[when(col(c) == col("max"), lit(c)) for c in df.columns[1:]]),
    ).withColumn(
        "maxcol",
        expr("filter(maxcol, x -> x is not null)"),  # filter out the nulls
    )
)
# show() returns None, so it must NOT be assigned back to df (the original
# `df = (....show())` left df = None and broke every later use of df).
df.show()
+---+----+----+----+----+---+------+
| ID|Col1|Col2|Col3|ColN|max|maxcol|
+---+----+----+----+----+---+------+
| 1| 10| 5| 21| -9| 21|[Col3]|
| 2| 87| 1| 1| 1| 87|[Col1]|
| 3| 1| 95| 1| 1| 95|[Col2]|
+---+----+----+----+----+---+------+
You could also use pandas_udf though I am not sure of the efficacy
from pyspark.sql.functions import pandas_udf
import pandas as pd
from pyspark.sql.types import *
def max_col(a: pd.DataFrame) -> pd.DataFrame:
    """Append a 'maxcol' column listing the value-column names holding each row's max.

    The first column (ID) is excluded both from the max computation and from the
    candidate columns, so an ID value that happens to equal the row max is never
    reported (the original ``a.isin(...)`` compared ALL columns, including ID).

    Parameters:
        a: frame whose first column is an identifier and whose remaining
           columns are the numeric values to compare.

    Returns:
        A copy of ``a`` with an extra 'maxcol' column: an array of the
        value-column names equal to that row's maximum.
    """
    values = a.iloc[:, 1:]                          # value columns only (skip ID)
    hits = values.eq(values.max(axis=1), axis=0)    # True where cell == its row max
    return a.assign(maxcol=hits.agg(lambda row: row.index[row].values, axis=1))
# Output schema for applyInPandas: the original long-typed columns plus the
# array-of-strings 'maxcol' column that max_col appends.
_long_fields = [
    StructField(name, LongType(), True)
    for name in ('ID', 'Col1', 'Col2', 'Col3', 'ColN')
]
schema = StructType(
    _long_fields + [StructField('maxcol', ArrayType(StringType(), True), False)]
)
df.groupby('ID').applyInPandas(max_col, schema).show()
df.show(5)
Col1 | Col2 | Col3 | Col4 | Col5 |
---|---|---|---|---|
3 | 5 | 7 | 3 | 6 |
7 | 7 | 5 | 8 | 8 |
2 | 2 | 3 | 7 | 7 |
4 | 4 | 8 | 6 | 2 |
7 | 4 | 9 | 2 | 2 |
only showing top 5 rows
# 0-based position of the per-row maximum: gather the row into a temporary
# array column, locate the max with a plain Python UDF, then drop the helper.
max_index = F.udf(lambda row: row.index(max(row)))
df = (
    df.withColumn("data", F.array(df.columns))
      .withColumn("max_index", max_index("data"))
      .drop("data")
)
df.show(5)
Col1 | Col2 | Col3 | Col4 | Col5 | max_index |
---|---|---|---|---|---|
3 | 5 | 7 | 3 | 6 | 2 |
7 | 7 | 5 | 8 | 8 | 3 |
2 | 2 | 3 | 7 | 7 | 3 |
4 | 4 | 8 | 6 | 2 | 2 |
7 | 4 | 9 | 2 | 2 | 2 |
only showing top 5 rows