Split rows into train and test sets based on user id in PySpark

Question:

I have a PySpark dataframe containing multiple rows for each user:

userId action time
1 buy 8 AM
1 buy 9 AM
1 sell 2 PM
1 sell 3 PM
2 sell 10 AM
2 buy 11 AM
2 sell 2 PM
2 sell 3 PM

My goal is to split this dataset into a training set and a test set in such a way that, for each userId, N% of the rows go to the training set and the remaining (100-N)% go to the test set. For example, with N=75, the training set will be

userId action time
1 buy 8 AM
1 buy 9 AM
1 sell 2 PM
2 sell 10 AM
2 buy 11 AM
2 sell 2 PM

and the test set will be

userId action time
1 sell 3 PM
2 sell 3 PM

Any suggestions? Rows are ordered by the time column, and I don't think Spark's randomSplit will help, as I cannot stratify the split on specific columns.

Asked By: mht


Answers:

You can use ntile:

from pyspark.sql.functions import col, expr

# ntile(4) buckets each user's rows into 4 equal groups; order by a sortable
# timestamp column (assumed here) so the last bucket holds the latest rows
ds = ds.withColumn("tile", expr("ntile(4) over (partition by id order by time)"))

The rows where tile=4 are your test set, and tile<4 your train set:

train = ds.filter(col("tile") < 4)
test = ds.filter(col("tile") == 4)

test.show()
+---+------+----+----+
| id|action|time|tile|
+---+------+----+----+
|  1|  sell|3 PM|   4|
|  2|  sell|3 PM|   4|
+---+------+----+----+

train.show()
+---+------+-----+----+
| id|action| time|tile|
+---+------+-----+----+
|  1|   buy| 8 AM|   1|
|  1|   buy| 9 AM|   2|
|  1|  sell| 2 PM|   3|
|  2|  sell|10 AM|   1|
|  2|   buy|11 AM|   2|
|  2|  sell| 2 PM|   3|
+---+------+-----+----+
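If N doesn't correspond to a whole number of equal tiles, percent_rank gives finer control. A minimal sketch under the same assumptions (a sortable time column; 0.75 is just the illustrative threshold):

from pyspark.sql.functions import col, expr

# percent_rank runs from 0.0 (earliest row) to 1.0 (latest row) per partition
ds = ds.withColumn("pct", expr("percent_rank() over (partition by id order by time)"))

train = ds.filter(col("pct") < 0.75)   # roughly the earliest 75% per user
test = ds.filter(col("pct") >= 0.75)   # roughly the latest 25% per user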

Good luck!

Answered By: vilalabinot

We had a similar requirement and solved it in the following way:

data = [
  (1, "buy"),
  (1, "buy"),
  (1, "sell"),
  (1, "sell"),
  (2, "sell"),
  (2, "buy"),
  (2, "sell"),
  (2, "sell"),
]

df = spark.createDataFrame(data, ["userId", "action"])

Use a Window function to assign serial row numbers. Also compute the record count per userId; this will be used to work out the percentage of records to filter.

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# In real data, order by the time column so row numbers follow event order
# within each user; userId is used here only because the sample data has
# no time column
window = Window.partitionBy(df["userId"]).orderBy(df["userId"])
df_count = df.groupBy("userId").count().withColumnRenamed("userId", "userId_grp")
df = df.join(df_count, col("userId") == col("userId_grp"), "left").drop("userId_grp")
df = df.select("userId", "action", "count", row_number().over(window).alias("row_number"))

df.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
|     1|   buy|    4|         1|
|     1|   buy|    4|         2|
|     1|  sell|    4|         3|
|     1|  sell|    4|         4|
|     2|  sell|    4|         1|
|     2|   buy|    4|         2|
|     2|  sell|    4|         3|
|     2|  sell|    4|         4|
+------+------+-----+----------+

Filter the training records by the required percentage:

n = 75
df_train = df.filter(col("row_number") <= col("count") * n / 100)
df_train.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
|     1|   buy|    4|         1|
|     1|   buy|    4|         2|
|     1|  sell|    4|         3|
|     2|  sell|    4|         1|
|     2|   buy|    4|         2|
|     2|  sell|    4|         3|
+------+------+-----+----------+

And the remaining records go to the test set:

df_test = df.alias("df").join(
    df_train.alias("tr"),
    (col("df.userId") == col("tr.userId")) & (col("df.row_number") == col("tr.row_number")),
    "leftanti",
)
df_test.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
|     1|  sell|    4|         4|
|     2|  sell|    4|         4|
+------+------+-----+----------+
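As a side note, the groupBy/join bookkeeping above can be folded into window aggregates, which also lets the test set be selected with a plain filter instead of a leftanti join. A minimal sketch under the same assumptions (row order within each user is arbitrary here, as in the original; order by a timestamp column in real data):

from pyspark.sql.window import Window
from pyspark.sql.functions import col, count, row_number

n = 75

# row_number needs an ordering; userId is used only because the sample
# data has no time column
order_window = Window.partitionBy("userId").orderBy("userId")
# an unordered window aggregate yields the full per-partition count
count_window = Window.partitionBy("userId")

df2 = (df.withColumn("row_number", row_number().over(order_window))
         .withColumn("count", count("*").over(count_window)))

df_train = df2.filter(col("row_number") <= col("count") * n / 100)
# row_number is unique per user, so a plain filter replaces the leftanti join
df_test = df2.filter(col("row_number") > col("count") * n / 100)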
Answered By: Azhar Khan