Get the column index where the value is maximum in each row in PySpark


I have a pyspark dataframe such as:

ID Col1 Col2 Col3 ColN
1 10 5 21 -9
2 87 1 1 1
3 1 95 1 1

How can one create a PySpark DataFrame column MAX that holds the index of the column where the value is maximum in each row, as in:

ID Col1 Col2 Col3 ColN MAX
1 10 5 21 -9 3
2 87 1 1 1 1
3 1 95 1 1 2
Asked By: Gustavomoty



Create a column with the max in each row

List columns in which the max value can be found

Eliminate the nulls in the list

Code below

import pyspark.sql.functions as F

value_cols = df.columns[1:]  # every column except ID
df = (df
      .withColumn('max', F.greatest(*[F.col(c) for c in value_cols]))  # find the max in each row
      # name of each column equal to the max, null otherwise
      .withColumn('maxcol', F.array(*[F.when(F.col(c) == F.col('max'), F.lit(c)) for c in value_cols]))
      .withColumn('maxcol', F.expr("filter(maxcol, x -> x is not null)")))  # filter out the nulls
df.show()

| ID|Col1|Col2|Col3|ColN|max|maxcol|
|  1|  10|   5|  21|  -9| 21|[Col3]|
|  2|  87|   1|   1|   1| 87|[Col1]|
|  3|   1|  95|   1|   1| 95|[Col2]|
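To see what the three steps do without a Spark session, here is a plain-Python sketch that mirrors them on the question's rows. The helper name `max_columns` is hypothetical, and the sketch assumes there are no nulls (F.greatest skips nulls, Python's max does not). Because every column equal to the maximum gets its name, a tie keeps all matching columns, just as the array/filter approach does.

```python
def max_columns(row, names):
    """Hypothetical mirror of the Spark steps: greatest -> compare each column -> drop nulls."""
    row_max = max(row)  # like F.greatest over the value columns (assumes no nulls)
    # like F.array(F.when(col == max, lit(name))): the name where equal, else None
    candidates = [name if value == row_max else None for name, value in zip(names, row)]
    # like filter(maxcol, x -> x is not null)
    return [name for name in candidates if name is not None]

names = ["Col1", "Col2", "Col3", "ColN"]
rows = {1: [10, 5, 21, -9], 2: [87, 1, 1, 1], 3: [1, 95, 1, 1]}
for row_id, values in rows.items():
    print(row_id, max_columns(values, names))  # 1 ['Col3'], 2 ['Col1'], 3 ['Col2']

# a tie keeps every column holding the maximum
print(max_columns([7, 7, 5, 8, 8], ["Col1", "Col2", "Col3", "Col4", "Col5"]))  # ['Col4', 'Col5']
```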

You could also use applyInPandas (a pandas function API), though I am not sure how efficient it is:
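import pandas as pd
from pyspark.sql.types import StructType, StructField, StringType

value_cols = df.columns[1:]  # every column except ID

def max_col(a: pd.DataFrame) -> pd.DataFrame:
    # idxmax(axis=1) returns the label of the column holding each row's maximum
    return a.assign(maxcol=a[value_cols].idxmax(axis=1))

# output schema: the input columns plus a string column for the column name
schema = StructType(df.schema.fields + [StructField('maxcol', StringType())])
df.groupby('ID').applyInPandas(max_col, schema).show()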
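The function passed to applyInPandas is ordinary pandas code, so its logic can be checked locally without Spark. A minimal sketch, assuming pandas is installed and using `DataFrame.idxmax(axis=1)` to find the column label of each row's maximum. Note that on a tie idxmax returns only the first matching column, unlike the array/filter approach, which keeps all of them.

```python
import pandas as pd

# the question's sample rows as a plain pandas DataFrame
pdf = pd.DataFrame({"ID": [1, 2, 3],
                    "Col1": [10, 87, 1], "Col2": [5, 1, 95],
                    "Col3": [21, 1, 1], "ColN": [-9, 1, 1]})
value_cols = ["Col1", "Col2", "Col3", "ColN"]

# label of the first column holding each row's maximum
pdf = pdf.assign(maxcol=pdf[value_cols].idxmax(axis=1))
print(pdf)

# on a tie only the first matching label is returned
tie = pd.DataFrame([[7, 7, 5, 8, 8]],
                   columns=["Col1", "Col2", "Col3", "Col4", "Col5"]).idxmax(axis=1).iloc[0]
print(tie)  # Col4
```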
Answered By: wwnde

Col1 Col2 Col3 Col4 Col5
3 5 7 3 6
7 7 5 8 8
2 2 3 7 7
4 4 8 6 2
7 4 9 2 2

only showing top 5 rows

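import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType

df = df.withColumn("data", F.array(df.columns))  # pack the row's values into one array
# 0-based position of the first maximum; IntegerType avoids the default StringType result
max_index = F.udf(lambda x: x.index(max(x)), IntegerType())
df = df.withColumn("max_index", max_index("data"))
df = df.drop("data")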
Col1 Col2 Col3 Col4 Col5 max_index
3 5 7 3 6 2
7 7 5 8 8 3
2 2 3 7 7 3
4 4 8 6 2 2
7 4 9 2 2 2

only showing top 5 rows
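The UDF's lambda is plain Python, so its behaviour can be checked directly. A small sketch, where `max_index` is a local stand-in for the lambda: `list.index` returns the 0-based position of the first maximum, which is why the output above shows 2 for Col3, whereas the question's MAX column is 1-based.

```python
def max_index(values):
    """Stand-in for the UDF's lambda: 0-based position of the first maximum."""
    return values.index(max(values))

# the five rows shown above
rows = [[3, 5, 7, 3, 6], [7, 7, 5, 8, 8], [2, 2, 3, 7, 7], [4, 4, 8, 6, 2], [7, 4, 9, 2, 2]]
print([max_index(r) for r in rows])  # [2, 3, 3, 2, 2]

# the question's value columns (ID excluded); adding 1 gives the 1-based MAX
question_rows = [[10, 5, 21, -9], [87, 1, 1, 1], [1, 95, 1, 1]]
print([max_index(r) + 1 for r in question_rows])  # [3, 1, 2]
```

To reproduce the question's MAX column in Spark, the lambda could return `x.index(max(x)) + 1`, with the array built from the value columns only so that ID is not counted.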

Answered By: Akhileshwar Sharma