How to add custom methods to the PySpark DataFrame class by inheritance

Question:

I am trying to inherit the DataFrame class and add custom methods as shown below, so that I can chain them fluently and also ensure all methods refer to the same dataframe. I get an exception saying the column is not iterable.

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit

class MyClass(DataFrame):
    def __init__(self, df):
        super().__init__(df._jdf, df.sql_ctx)

    def add_column3(self):
        # Add col3 to the dataframe received
        self._jdf.withColumn("col3", lit(3))
        return self

    def add_column4(self):
        # Add col4 to the dataframe received
        self._jdf.withColumn("col4", lit(4))
        return self

if __name__ == "__main__":
    '''
    Spark context initialization code
    col1 col2
    a    1
    b    2
    '''
    df = spark.createDataFrame([("a", 1), ("b", 2)], ["col1", "col2"])
    myobj = MyClass(df)
    ## Trying to accomplish below, where I can chain MyClass methods & DataFrame methods
    myobj.add_column3().add_column4().drop_columns(["col1"])

'''
Expected output
col2, col3, col4
1, 3, 4
2, 3, 4
'''
Asked By: rajdallas


Answers:

Below is my solution (based on your code).
I don't know if it's best practice, but at least it does what you want correctly. DataFrames are immutable objects, so after adding a new column we create a new object, and we make it a MyClass object rather than a plain DataFrame so that it has both the DataFrame methods and the custom methods.

from pyspark.sql.dataframe import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()


class MyClass(DataFrame):
    def __init__(self, df):
        super().__init__(df._jdf, df.sql_ctx)
        self._df = df

    def add_column3(self):
        # Add col3 to the dataframe received
        newDf = self._df.withColumn("col3", F.lit(3))
        return MyClass(newDf)

    def add_column4(self):
        # Add col4 to the dataframe received
        newDf = self._df.withColumn("col4", F.lit(4))
        return MyClass(newDf)


df = spark.createDataFrame([("a",1), ("b",2)], ["col1","col2"])
myobj = MyClass(df)
myobj.add_column3().add_column4().na.drop().show()

# Result:
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
|   a|   1|   3|   4|
|   b|   2|   3|   4|
+----+----+----+----+
Answered By: ggeop

Actually, you don't need to inherit the DataFrame class in order to add custom methods to DataFrame objects.

In Python, you can add a custom property that wraps your methods, like this:

from functools import wraps

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit

# decorator to attach a function to an attribute
def add_attr(cls):
    def decorator(func):
        @wraps(func)
        def _wrapper(*args, **kwargs):
            f = func(*args, **kwargs)
            return f

        setattr(cls, func.__name__, _wrapper)
        return func

    return decorator

# custom functions
def custom(self):
    @add_attr(custom)
    def add_column3():
        return self.withColumn("col3", lit(3))

    @add_attr(custom)
    def add_column4():
        return self.withColumn("col4", lit(4))

    return custom

# add new property to the Class pyspark.sql.DataFrame
DataFrame.custom = property(custom)

# use it
df.custom.add_column3().show()
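
Since each wrapped method returns a plain DataFrame, chaining two of the custom methods means going back through the custom property between calls. A minimal sketch, reusing the toy dataframe from the question and assuming an active spark session:

df = spark.createDataFrame([("a", 1), ("b", 2)], ["col1", "col2"])
# .custom re-wraps the intermediate DataFrame so add_column4 is available again
df.custom.add_column3().custom.add_column4().show()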
Answered By: blackbishop

I think you are looking for something like this:

class dfc:
  def __init__(self, df):
    self.df = df
    
  def func(self, num):
    self.df = self.df.selectExpr(f"id * {num} AS id")
  
  def func1(self, num1):
    self.df = self.df.selectExpr(f"id * {num1} AS id")
    
  def dfdis(self):
    self.df.show()

In this example, a dataframe is passed to the constructor and is then used by the methods defined inside the class. The state of the dataframe is stored in the instantiated object and updated each time one of those methods is called.

df = spark.range(10)

ob = dfc(df)

ob.func(2)

ob.func(2)

ob.dfdis()
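
If the fluent chaining from the original question is wanted, a small variant of this pattern has each method return self. A minimal sketch (dfc_chainable is just an illustrative name, and spark is the session created earlier):

class dfc_chainable:
  def __init__(self, df):
    self.df = df

  def func(self, num):
    # update the stored dataframe, then return self so calls can be chained
    self.df = self.df.selectExpr(f"id * {num} AS id")
    return self

  def dfdis(self):
    self.df.show()

dfc_chainable(spark.range(10)).func(2).func(3).dfdis()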
Answered By: Hari Gopinath

The answer by blackbishop is worth a look, even if it has no upvotes as of this writing. It seems like a good general approach for extending the Spark DataFrame class, and I presume it would work for other complex objects too. I rewrote it slightly as follows:

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import col
from functools import wraps

# Create a decorator to add a function to a python object
def add_attr(cls):
    def decorator(func):
        @wraps(func)
        def _wrapper(*args, **kwargs):
            f = func(*args, **kwargs)
            return f

        setattr(cls, func.__name__, _wrapper)
        return func

    return decorator

  
# Extensions to the Spark DataFrame class go here
def dataframe_extension(self):
  @add_attr(dataframe_extension)
  def drop_records():
    return(
      self
      .where(~((col('test1') == 'ABC') & (col('test2') =='XYZ')))
      .where(~col('test1').isin(['AAA', 'BBB']))
    )
  return dataframe_extension

DataFrame.dataframe_extension = property(dataframe_extension)
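
A usage sketch, with a hypothetical input dataframe whose column names match those used in drop_records, and assuming an active spark session:

df = spark.createDataFrame(
    [("ABC", "XYZ"), ("AAA", "1"), ("OK", "2")],
    ["test1", "test2"],
)
# only the ("OK", "2") row survives the two filters
df.dataframe_extension.drop_records().show()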
Answered By: mostly definitive

Note: PySpark is deprecating df.sql_ctx in an upcoming version, so this answer is not future-proof.

I like many of the other answers, but there are a few lingering questions in comments. I think they can be addressed as such:

  • we need to think of everything as immutable, so we return the subclass
  • we do not need to call self._jdf anywhere — instead, just use self as if it were a DataFrame (since it is one — this is why we used inheritance!)
  • we need to explicitly construct a new one of our class since returns from self.foo will be of the base DataFrame type
  • I have added a DataFrameExtender subclass that mediates creation of new classes. Subclasses will inherit parent constructors if not overridden, so we can neaten up the DataFrame constructor to take a DataFrame, and add in the capability to store metadata.

We can make a new class for conceptual stages that the data arrives in, and we can sidecar flags that help us identify the state of the data in the dataframe. Here I add a flag when either add column method is called, and I push forward all existing flags. You can do whatever you like.

This setup means that you can create a sequence of DataFrameExtender objects, such as:

  • RawData, which implements .clean() method, returning CleanedData
  • CleanedData, which implements .normalize() method, returning ModelReadyData
  • ModelReadyData, which implements .train(model) and .predict(model), or .summarize() and which is used in a model as a base DataFrame object would be used.

By splitting these methods into different classes, it means that we cannot call .train() on RawData, but we can take a RawData object and chain together .clean().normalize().train(). This is a functional-like approach, but using immutable objects to assist in interpretation.
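
As a rough illustration of that staged approach, here is a skeleton built on the DataFrameExtender base class defined in the code below. The class and method names are only illustrative, and the transformations inside are placeholders:

class RawData(DataFrameExtender):
    def clean(self):
        # placeholder cleaning step: drop rows containing nulls
        return CleanedData(self.na.drop(), cleaned=True, **self.flags)

class CleanedData(DataFrameExtender):
    def normalize(self):
        # placeholder normalization step
        return ModelReadyData(self, normalized=True, **self.flags)

class ModelReadyData(DataFrameExtender):
    def summarize(self):
        # a ModelReadyData can be used anywhere a base DataFrame would be
        return self.describe()

# .summarize() cannot be reached without going through the earlier stages:
# RawData(df).clean().normalize().summarize().show()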

Note that DataFrames in Spark are lazily evaluated, which is great for this approach. All of this code just produces a final unevaluated DataFrame object that contains all of the operations that will be performed. We don’t have to worry about memory or copies or anything.

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit

class DataFrameExtender(DataFrame):
    def __init__(self,df,**kwargs):
        self.flags = kwargs
        super().__init__(df._jdf, df.sql_ctx)

class ColumnAddedData(DataFrameExtender):
    def add_column3(self):
        df_added_column = self.withColumn("col3", lit(3))
        return ColumnAddedData(df_added_column, **{**self.flags, 'with_col3': True})
    def add_column4(self):
        ## Add a bit of complexity: do not call again if we have already called this method
        if not self.flags.get('with_col4'):
            df_added_column = self.withColumn("col4", lit(4))
            return ColumnAddedData(df_added_column, **{**self.flags, 'with_col4': True})
        return self
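
A usage sketch for the classes above, assuming the spark session and lit import from the earlier answers, and a PySpark version where df.sql_ctx still exists (see the note at the top of this answer):

df = spark.createDataFrame([("a", 1), ("b", 2)], ["col1", "col2"])
staged = ColumnAddedData(df).add_column3().add_column4()
print(staged.flags)          # {'with_col3': True, 'with_col4': True}
staged.drop("col1").show()   # inherited DataFrame methods still work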
Answered By: John Haberstroh