PySpark: how to include the first row that fails the condition


I have a sorted table that looks like the following.

col1 col2
1000 1000
2600 2600
3600 3600
3600 4050
3600 4500

I want to create a flag that is 1 when col1 and col2 are both at most 4000.
This is easy with:

from pyspark.sql.functions import when
pyspark_df = pyspark_df.withColumn('flag', when((pyspark_df['col1'] <= 4000) & (pyspark_df['col2'] <= 4000), 1).otherwise(0))

However, I also want the first row that fails the condition (row 4 in this case) to have the flag set. How should I do this?

Asked By: Peter Yang



You could create a lag column holding the previous row's flag and then combine it with the original flag using bitwiseOR.

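from pyspark.sql import SparkSession
from pyspark.sql.functions import when, lag, col, monotonically_increasing_id
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

# Sample data chosen to match the output shown below (assumption: the original call was cut off here)
df = spark.createDataFrame(
    [(1000, 1000), (2600, 2600), (3600, 3600), (3600, 4500), (3600, 4500)],
    ['col1', 'col2'])

# Base flag: 1 when both columns are within the limit
df = df.withColumn('flag', when((df['col1'] <= 4000) & (df['col2'] <= 4000), 1).otherwise(0))

# Increasing row index used to order the window
df = df.withColumn('idx', monotonically_increasing_id())

# Previous row's flag; the first row has no predecessor, so fill its null with 0
w = Window.orderBy(col('idx'))
df = df.withColumn('lag', lag('flag', 1).over(w))
df = df.fillna(0, subset=['lag'])

# A row is flagged if it passes the condition or if the row before it passed
df = df.withColumn('flag', df.flag.bitwiseOR(df.lag))
df.select('col1', 'col2', 'flag').show()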


+----+----+----+
|col1|col2|flag|
+----+----+----+
|1000|1000|   1|
|2600|2600|   1|
|3600|3600|   1|
|3600|4500|   1|
|3600|4500|   0|
+----+----+----+

Answered By: Chris
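
To see why the lag-then-OR approach gives the desired result, the same logic can be traced in plain Python without Spark. This is only an illustrative sketch: the rows and the 4000 limit are taken from the question, while the helper name flag_with_first_failure is hypothetical and not part of PySpark or of the answer above.

```python
# Plain-Python trace of the lag-then-OR idea (no Spark needed).
# flag_with_first_failure is a hypothetical helper used only for illustration.

def flag_with_first_failure(rows, limit=4000):
    # Step 1: base flag, 1 when both values are within the limit.
    flags = [1 if c1 <= limit and c2 <= limit else 0 for c1, c2 in rows]
    # Step 2: shift the flags down by one row (the "lag");
    # the first row has no predecessor, so it gets 0.
    lagged = [0] + flags[:-1]
    # Step 3: bitwise OR of each row's flag with the previous row's flag.
    return [f | p for f, p in zip(flags, lagged)]


# Rows from the question, in their sorted order.
rows = [(1000, 1000), (2600, 2600), (3600, 3600), (3600, 4050), (3600, 4500)]
result = flag_with_first_failure(rows)
print(result)  # [1, 1, 1, 1, 0]
```

Note that this flags every failing row that directly follows a passing row, so it marks only the first failure when the data is sorted such that all passing rows come first, as in the question.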
