How can I compute a row-wise rank in a PySpark DataFrame?
Question:
I have a dataset for which I want to find the rank per row. Here is a toy example in pandas:
import pandas as pd
df = pd.DataFrame({"ID":[1,2,3,4], "a":[2,7,9,10],
"b":[6,7,4,2], "c":[3,4,8,5]})
print(df)
# ID a b c
# 0 1 2 6 3
# 1 2 7 7 4
# 2 3 9 4 8
# 3 4 10 2 5
df[["a","b","c"]] = df[["a","b","c"]].rank(method="min",
ascending=False,
axis=1).astype("int")
print(df)
# ID a b c
# 0 1 3 1 2
# 1 2 1 1 3
# 2 3 1 3 2
# 3 4 1 3 2
However, since I couldn't find an equivalent of axis=1 in PySpark, I haven't been able to translate this. My dataset has 60 million rows and 40 columns, so the solution needs to be efficient (e.g., I cannot loop over the columns).
Answer:
You can collect all columns (except ID) into a new array column (called arr below), sort that array in descending order, and then take each value's position in the sorted array as its rank:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = (SparkSession.builder.master("local[1]")
         .appName("HadiJahanshahi").enableHiveSupport().getOrCreate())

df = spark.createDataFrame([(1, 2, 6, 3),
                            (2, 7, 7, 4),
                            (3, 9, 4, 8),
                            (4, 10, 2, 5)],
                           ["ID", "a", "b", "c"])

cols = df.columns
cols.remove('ID')

# For each value column, its rank is its position in the descending-sorted array.
expr = [F.col('ID')] + [F.expr(f'array_position(arr, {c})').alias(c) for c in cols]

df.withColumn('arr', F.sort_array(F.array(cols), False)) \
  .select(expr) \
  .show()
Result:
+---+---+---+---+
| ID| a| b| c|
+---+---+---+---+
| 1| 3| 1| 2|
| 2| 1| 1| 3|
| 3| 1| 3| 2|
| 4| 1| 3| 2|
+---+---+---+---+
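Note that array_position returns the position of the first occurrence in the sorted array, so tied values share the lowest position, which matches pandas' method="min" (see ID 2 above). For the full 40-column dataset the same pattern can be wrapped in a small helper; the sketch below (row_wise_rank is a hypothetical name, not a built-in) builds the expressions once from df.columns, so there is no per-row Python loop:

from pyspark.sql import functions as F

def row_wise_rank(df, id_col='ID', ascending=False):
    # Hypothetical helper: rank every non-ID column within its own row.
    value_cols = [c for c in df.columns if c != id_col]
    # Sorted array of the row's values (descending by default, as in the answer above).
    arr = F.sort_array(F.array(*value_cols), asc=ascending)
    # Each value's position in the sorted array is its rank; ties get the first position.
    rank_exprs = [F.expr(f'array_position(arr, {c})').alias(c) for c in value_cols]
    return df.withColumn('arr', arr).select(F.col(id_col), *rank_exprs)

row_wise_rank(df).show()  # same output as above

Since everything here is a per-row expression (no window or join), Spark evaluates it in a single pass without shuffling, which should stay cheap even at 60 million rows.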