Pyspark how to include the row that fails the where condition

Question:

I have a sorted table that looks like the following.

col1 col2
1000 1000
2600 2600
3600 3600
3600 4050
3600 4500

I want to create a flag such that it is true when col1 and col2 are both less than or equal to 4000.
This is easy with

pyspark_df = pyspark_df.withColumn('flag', when((pyspark_df['col1'] <= 4000) & (pyspark_df['col2'] <= 4000), 1).otherwise(0))

However, I also want the first row that fails the condition (in this case row 4) to have the flag set as well. How should I do this?

Asked By: Peter Yang


Answers:

You could create a lag column and then use bitwiseOR between the two columns.

from pyspark.sql.functions import when, lag, col, monotonically_increasing_id
from pyspark.sql.window import Window


df = spark.createDataFrame(
    [[1000, 1000],
     [2600, 2600],
     [3600, 3600],
     [3600, 4500],
     [3600, 4500]],
    ['col1', 'col2']
)

df = df.withColumn('flag', when((df['col1'] <= 4000) & (df['col2'] <= 4000), 1).otherwise(0))

# Spark DataFrames have no inherent order, so capture the current order first
df = df.withColumn('idx', monotonically_increasing_id())

w = Window.orderBy(col('idx'))
df = df.withColumn('lag', lag('flag', 1).over(w))
df = df.fillna(0, subset='lag')  # the first row has no previous value

df = df.withColumn('flag', df.flag.bitwiseOR(df.lag))
df.select('col1','col2','flag').show()

Output

+----+----+----+
|col1|col2|flag|
+----+----+----+
|1000|1000|   1|
|2600|2600|   1|
|3600|3600|   1|
|3600|4500|   1|
|3600|4500|   0|
+----+----+----+
Answered By: Chris