Python method that returns instance of class or subclass while keeping subclass attributes

Question:

I’m writing a Python class A with a method square() that returns a new instance of that class with its first attribute squared. For example:

class A:
    def __init__(self, x):
        self.x = x

    def square(self):
        return self.__class__(self.x**2)

I would like to use this method in a subclass B so that it returns an instance of B with x squared but all additional attributes of B unchanged (i.e. taken from the instance). I can get it to work by overriding square() like this:

class B(A):
    def __init__(self, x, y):
        super(B, self).__init__(x)
        self.y = y

    def square(self):
        return self.__class__(self.x**2, self.y)

If I don’t override the square() method, this small example fails, because the constructor of B also requires a value for y:

#test.py

class A:
    def __init__(self, x):
        self.x = x 

    def square(self):
        return self.__class__(self.x**2)

class B(A):
    def __init__(self, x, y):
        super(B, self).__init__(x)
        self.y = y 

    #def square(self):
    #    return self.__class__(self.x**2, self.y)

a = A(3)
a2 = a.square()
print(a2.x)
b = B(4, 5)
b2 = b.square()
print(b2.x, b2.y)

$ python test.py
9
Traceback (most recent call last):
  File "test.py", line 20, in <module>
    b2 = b.square()
  File "test.py", line 6, in square
    return self.__class__(self.x**2)
TypeError: __init__() takes exactly 3 arguments (2 given)

Overriding the method once isn’t a problem. But A potentially has multiple methods similar to square(), and there might be more sub(sub)classes. If possible, I would like to avoid overriding all those methods in all those subclasses.

So my question is this:
Can I somehow implement square() in A so that it returns a new instance of the current subclass, with x squared and every other constructor argument taken unchanged from self? Or do I have to override square() in each subclass?

Thanks in advance!

Asked By: nieswand


Answers:

I’d suggest implementing a __copy__() method (and possibly __deepcopy__() as well) on both classes.

Then square() can be a simple method:

from copy import copy

def square(self):
    # Copy first so the subclass and its extra attributes are preserved.
    newObj = copy(self)
    newObj.x = self.x ** 2
    return newObj

This works with inheritance as long as every child class implements __copy__ correctly.

EDIT: fixed typo with call to copy()

Full working example:

#test.py

from copy import copy


class A:
    def __init__(self, x):
        self.x = x 

    def square(self):
        # copy() dispatches to the __copy__ of the runtime class (A or B).
        newObj = copy(self)
        newObj.x = self.x ** 2
        return newObj

    def __copy__(self):
        return A(self.x)

class B(A):
    def __init__(self, x, y):
        super(B, self).__init__(x)
        self.y = y 

    def __copy__(self):
        return B(self.x, self.y)

a = A(3)
a2 = a.square()
print(a2.x)
b = B(4, 5)
b2 = b.square()
print(b2.x, b2.y)
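
For plain classes like these, copy() also works without a custom __copy__: the default shallow copy creates a new instance of the same class and copies the instance's __dict__ without calling __init__. Since A may grow several methods like square(), the copy-and-modify step can be factored into one helper. This is only a minimal sketch; the names _replace and cube are hypothetical and not part of the answer above:

```python
from copy import copy


class A:
    def __init__(self, x):
        self.x = x

    def _replace(self, **changes):
        # Hypothetical helper: shallow-copy self, then overwrite the given attributes.
        # The default copy() keeps the runtime class (A, B, ...) and skips __init__.
        new_obj = copy(self)
        for name, value in changes.items():
            setattr(new_obj, name, value)
        return new_obj

    def square(self):
        return self._replace(x=self.x ** 2)

    def cube(self):
        return self._replace(x=self.x ** 3)


class B(A):
    def __init__(self, x, y):
        super().__init__(x)
        self.y = y


b2 = B(4, 5).square()
print(type(b2).__name__, b2.x, b2.y)  # B 16 5
```

Keep in mind that a shallow copy shares mutable attributes (lists, dicts) between the original and the new object; use deepcopy() if that is a problem.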
Answered By: matszwecja

Check whether the object has a y attribute, then call the constructor with the matching arguments:

class A:
    x: int
    def __init__(self, x):
        self.x = x

    def square(self):
        if hasattr(self, 'y'):
            return self.__class__(self.x ** 2, self.y)
        
        return self.__class__(self.x**2)


class B(A):
    y: int
    def __init__(self, x, y):
        super(B, self).__init__(x)
        self.y = y

    # def square(self):
    #     return self.__class__(self.x**2, self.y)
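
Note that checking for y inside A means A has to know about the attributes of each subclass. One possible way around that, sketched here with a hypothetical hook named _init_kwargs (an assumption for illustration, not part of the code above), is to let each class report its own constructor arguments so that square() only lives in A:

```python
class A:
    def __init__(self, x):
        self.x = x

    def _init_kwargs(self):
        # Hypothetical hook: keyword arguments needed to rebuild this instance.
        return {"x": self.x}

    def square(self):
        kwargs = self._init_kwargs()
        kwargs["x"] = self.x ** 2
        # type(self) is the runtime class, so calling square() on a B returns a B.
        return type(self)(**kwargs)


class B(A):
    def __init__(self, x, y):
        super().__init__(x)
        self.y = y

    def _init_kwargs(self):
        # Extend the parent's arguments instead of repeating them.
        kwargs = super()._init_kwargs()
        kwargs["y"] = self.y
        return kwargs


b2 = B(4, 5).square()
print(b2.x, b2.y)  # 16 5
```

Each subclass then overrides one small hook rather than every method like square().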
Answered By: MBarsi