Why use a superclass's __init__ to change it into a subclass?

Question:

I’m working on replicating the SHAP package algorithm – an explainability algorithm for machine learning. I’ve been reading through the author’s code, and I’ve come across a pattern I’ve never seen before.

The author has created a superclass called Explainer, which is a common interface for all the different model-specific implementations of the algorithm. The Explainer’s __init__ method accepts a string for the algorithm type and, if called directly, switches the instance to the corresponding subclass. It does this using multiple versions of the following pattern:

if algorithm == "exact":
    self.__class__ = explainers.Exact
    explainers.Exact.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)

I understand that this code changes the instance's class from the superclass to one of its subclasses, and then initialises it by calling the subclass's __init__ with self. But why would you do this?

Asked By: Connor

||

Answers:

This is a non-standard and awkward way of implementing a factory. The idea is that, although the base class contains state and functionality that are useful for implementing derived classes, it should not be instantiated directly. The full SHAP code contains logic that checks whether the base class __init__ is being called "directly" (the instance's class is still the base class) or via super from a derived class. In the former case, it checks a parameter and chooses an appropriate derived class. (That derived class will, of course, end up calling back into this __init__ via super, but by then the instance's __class__ has already been changed, so the check fails and there is no unbounded recursion.)

To clarify, although this is not standard, it does work:

class Base:
    def __init__(self, *, value=None, kind=None):
        if self.__class__ is Base:
            if kind == 'derived':
                self.__class__ = Derived
                Derived.__init__(self, value)
            else:
                raise ValueError("invalid 'kind'; cannot create Base instances explicitly")

class Derived(Base):
    def __init__(self, value):
        super().__init__()
        self.value = value
    def method(self):
        return 'derived method not defined in base'

Testing it:

>>> Base()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 8, in __init__
ValueError: invalid 'kind'; cannot create Base instances explicitly
>>> Base(value=1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 8, in __init__
ValueError: invalid 'kind'; cannot create Base instances explicitly
>>> Base(value=1, kind='derived')
<__main__.Derived object at 0x7f94fe025790>
>>> Base(value=1, kind='derived').method()
'derived method not defined in base'
>>> Base(value=1, kind='derived').value
1
>>> Derived(2)
<__main__.Derived object at 0x7f94fcc2aa00>
>>> Derived(2).method()
'derived method not defined in base'
>>> Derived(2).value
2

Setting the __class__ attribute allows the factory-created instance to access the derived method, and calling Derived.__init__ gives it a per-instance value attribute. The order of those steps matters, though: __class__ must be reassigned first. Derived.__init__ calls super().__init__(), and zero-argument super() requires self to already be an instance of Derived; otherwise it raises TypeError. Alternatively, it would work (although it would look strange) to call self.__init__(value), but again only after changing the __class__, since that is what makes ordinary method lookup find Derived.__init__.
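For reference, here is a sketch of that self.__init__(value) alternative. It uses the same classes as the example above; the only change is the call inside Base.__init__, which now relies on normal method lookup after the class swap instead of naming Derived.__init__ explicitly.

```python
class Base:
    def __init__(self, *, value=None, kind=None):
        if self.__class__ is Base:
            if kind == 'derived':
                # Reassign the class first: from here on, method lookup
                # on self starts at Derived.
                self.__class__ = Derived
                # Resolves to Derived.__init__ because of the swap above.
                self.__init__(value)
            else:
                raise ValueError("invalid 'kind'; cannot create Base instances explicitly")


class Derived(Base):
    def __init__(self, value):
        # Re-enters Base.__init__, but self.__class__ is now Derived,
        # so the dispatch branch is skipped and there is no recursion.
        super().__init__()
        self.value = value

    def method(self):
        return 'derived method not defined in base'
```

Behaviour is the same as before: Base(value=1, kind='derived') returns a Derived instance whose value is 1, and Base(value=1) still raises ValueError.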


A more Pythonic way to implement this is to use the standard library abc module to mark the base class as "abstract", and to use a named method as the factory. For example, decorating the base class __init__ with @abstractmethod prevents the base class from being instantiated directly, and also prevents instantiating any derived class that does not override __init__. Derived classes that do override it can still call super().__init__ without error, because an abstract method is still an ordinary function that can be called explicitly. For the factory, we can use a method decorated with @staticmethod in the base class (or just an ordinary function; but using @staticmethod effectively "namespaces" the factory). It can, for example, use a string name to choose a derived class and instantiate it.

A minimal example:

from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def __init__(self):
        pass
    @staticmethod
    def create(kind):
        # TODO: add more derived classes to the mapping
        return {'derived': Derived}[kind]()

class Derived(Base):
    def __init__(self):
        super().__init__()

# TODO: implement additional derived classes

Testing it:

>>> Base()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Can't instantiate abstract class Base with abstract methods __init__
>>> Derived()
<__main__.Derived object at 0x7f94fe025310>
>>> Base.create('derived')
<__main__.Derived object at 0x7f94fe025910>
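The factory can be extended in the same style to forward constructor arguments and to fail with a clearer message for an unknown name. The sketch below assumes two hypothetical derived classes, Exact and Permutation (named after SHAP's algorithm strings purely for illustration; they are not SHAP's real classes), plus a hypothetical model argument and n_samples parameter. It also includes a derived class that does not override __init__, to show that it stays abstract.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def __init__(self, model):
        # Shared state that every derived class gets via super().__init__().
        self.model = model

    @staticmethod
    def create(kind, *args, **kwargs):
        # Hypothetical mapping from algorithm name to derived class.
        registry = {'exact': Exact, 'permutation': Permutation}
        try:
            cls = registry[kind]
        except KeyError:
            raise ValueError(
                f"unknown kind {kind!r}; expected one of {sorted(registry)}"
            ) from None
        # Forward all remaining arguments to the chosen class.
        return cls(*args, **kwargs)


class Exact(Base):
    def __init__(self, model):
        super().__init__(model)


class Permutation(Base):
    def __init__(self, model, n_samples=10):
        super().__init__(model)
        self.n_samples = n_samples


class Incomplete(Base):
    # Does not override __init__, so it is still abstract.
    pass
```

With this, Base.create('permutation', model, n_samples=5) returns a Permutation instance, Base.create('nope', model) raises ValueError, and both Base(model) and Incomplete(model) raise TypeError because __init__ is still abstract for them.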

Answered By: Karl Knechtel