Simple way to do multiple dispatch in Python? (No external libraries or class building?)

Question:

I’m writing a throwaway script to compute some analytical solutions to a few simulations I’m running.

I would like to implement a function that computes the right answer based on the inputs it receives. So, for instance, say I have the following math equation:

tmax = (s1 - s2) / 2 = q * (a^2 / (a^2 - b^2))

It seems simple to me that I should be able to do something like:

def tmax(s1, s2):
    return (s1 - s2) / 2

def tmax(a, b, q):
    return q * (a**2 / (a**2 - b**2))

I may have gotten too used to writing in Julia, but I really don’t want to complicate this script more than I need to.

Asked By: dylanjm

||

Answers:

You can do this using an optional argument:

def tmax_2(s1, s2):
    return (s1 - s2) / 2

def tmax_3(a, b, q):
    return q * (a**2 / (a**2 - b**2))

def tmax(a, b, c=None):
    if c is None:
        return tmax_2(a, b)
    else:
        return tmax_3(a, b, c)
Answered By: Fengyang Wang

In statically typed languages like C++, you can overload functions based on the types (and number) of the input parameters, but that’s not really possible in Python. A name can refer to only one function at a time, so a later def with the same name simply replaces the earlier one.
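A quick check, reusing the question’s two definitions unchanged, shows the second def replacing the first. This is just an illustration of the rebinding behaviour, not a recommended pattern:

```python
def tmax(s1, s2):
    return (s1 - s2) / 2


# This second `def` rebinds the name `tmax`; the two-argument version is gone.
def tmax(a, b, q):
    return q * (a**2 / (a**2 - b**2))


print(tmax(5, 3, 2))  # 3.125

try:
    tmax(10, 4)  # intended to call the two-argument version
except TypeError as err:
    # TypeError: tmax() missing 1 required positional argument: 'q'
    print("TypeError:", err)
```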

What you can do is use a default argument to select one of two code paths within that function, something like:

def tmax(p1, p2, p3=None):
    # Two-argument variant has p3 as None.

    if p3 is None:
        return (p1 - p2) / 2

    # Otherwise, we have three arguments.

    return (p1 * p1 / (p1 * p1 - p2 * p2)) * p3

If you’re wondering why I’ve changed the squaring operations from n ** 2 to n * n, it’s because the latter is faster (or it was at some point in the past, at least for small integral powers like 2; this is probably still the case, but you may want to confirm).
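To confirm on your own interpreter, a minimal sketch using the standard timeit module. The value of n and the run count are arbitrary choices; timings depend on your machine and Python version, so only the ratio is worth looking at and no particular result is implied:

```python
import timeit

# timeit runs `setup` and the statement inside one function body, so `n`
# is a local variable here, just like the arguments of tmax().
SETUP = "n = 12345.678"
RUNS = 500_000

mul_time = timeit.timeit("n * n", setup=SETUP, number=RUNS)
pow_time = timeit.timeit("n ** 2", setup=SETUP, number=RUNS)

print(f"n * n : {mul_time:.4f} s for {RUNS:,} runs")
print(f"n ** 2: {pow_time:.4f} s for {RUNS:,} runs")
print(f"ratio (pow / mul): {pow_time / mul_time:.2f}")
```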

A possible case where g1 ** 2 may be faster than g1 * g1 is when g1 is a global rather than a local: g1 * g1 has to look the name up twice, and a LOAD_GLOBAL is slower for the Python VM than a LOAD_FAST, whereas g1 ** 2 looks it up only once. This doesn’t apply to the code posted, since function arguments are always locals.
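The double lookup can be seen in the bytecode. A minimal sketch using the standard dis module, with a hypothetical module-level variable g1; the exact instructions differ between Python versions, but the number of LOAD_GLOBAL instructions fetching g1 should be two versus one:

```python
import dis

g1 = 3.0  # hypothetical module-level (global) variable


def square_mul():
    return g1 * g1  # the global name is looked up twice


def square_pow():
    return g1 ** 2  # the global name is looked up once


def count_global_loads(func, name):
    """Count LOAD_GLOBAL instructions in func's bytecode that fetch `name`."""
    return sum(
        1
        for ins in dis.get_instructions(func)
        if ins.opname == "LOAD_GLOBAL" and ins.argval == name
    )


print(count_global_loads(square_mul, "g1"))  # 2
print(count_global_loads(square_pow, "g1"))  # 1
```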

Answered By: paxdiablo

Just thought I’d offer two more options.

Multiple Dispatch Type Checking

Python’s standard typing module provides an @overload decorator.

It has no effect at runtime, but your IDE and static type checkers (if you elect to use them) will flag calls that don’t match any of the declared signatures.

Fundamentally your implementation will be the same nasty series of hacks Python wants you to use, but you’ll get better editor and type-checker support.

This SO post explains it better; I’ve modified its code example to show multiple parameters:

# << Beginning of additional stuff >>
from typing import overload  # note: the `X | Y` unions below need Python 3.10+


@overload
def hello(s: int) -> str:
    ...


@overload
def hello(s: str) -> str:
    ...


@overload
def hello(s: int, b: int | float | str) -> str:
    ...
# << End of additional stuff >>

# Janky python overload
def hello(s, b=None):
    if b is None:
        if isinstance(s, int):
            return "s is an integer!"
        if isinstance(s, str):
            return "s is a string!"
    if b is not None:
        if isinstance(s, int) and isinstance(b, int | float | str):
            return "s is an integer & b is an int / float / string!"

    raise ValueError('You must pass either int or str')


print(hello(1))           # s is an integer!
print(hello("Blah"))      # s is a string!
print(hello(11, 1))       # s is an integer & b is an int / float / string!
print(hello(11, "Blah"))  # s is an integer & b is an int / float / string!

My IDE (PyCharm) underlines the offending argument with a regular warning:

print(hello("Blah", "Blah"))  
# >> ValueError: You must pass either int or str
# PyCharm warns "Blah" w/ "Expected type 'int', got 'str' instead"

print(hello(1, [0, 1]))  
# >> ValueError: You must pass either int or str
# PyCharm warns [0, 1] w/ "Expected type 'int | float | str', got 'list[int]' instead"

print(hello(1, 1) + 1)  
# >> TypeError: can only concatenate str (not "int") to str
# PyCharm warns "+ 1" w/ "Expected type 'str', got 'int' instead"

This is the most direct answer to the post.
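Applied to the original tmax question, the same pattern might look like the minimal sketch below. This adaptation is mine, not from the SO post; it assumes Python 3.8+ for the positional-only `/` marker, which keeps the single implementation compatible with both overloads even though their parameter names differ:

```python
from typing import Optional, overload


# Signatures for the IDE / type checker only; they do nothing at runtime.
@overload
def tmax(s1: float, s2: float, /) -> float: ...
@overload
def tmax(a: float, b: float, q: float, /) -> float: ...


# The one real implementation: choose the formula by whether q was passed.
def tmax(x: float, y: float, q: Optional[float] = None, /) -> float:
    if q is None:
        return (x - y) / 2                 # tmax = (s1 - s2) / 2
    return q * (x**2 / (x**2 - y**2))      # tmax = q * a^2 / (a^2 - b^2)


print(tmax(10, 4))    # 3.0
print(tmax(5, 3, 2))  # 3.125
```

A type checker should then reject something like tmax(1.0) or tmax(1, 2, 3, 4) before the script ever runs.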

SINGLE DISPATCH:

If you only need single dispatch for functions or methods, have a look at the relatively recent @singledispatch (Python 3.4+) and @singledispatchmethod (Python 3.8+) decorators in functools.

Functions:

from functools import singledispatch


@singledispatch
def coolAdd(a, b):
    raise NotImplementedError('Unsupported type')

@coolAdd.register(int)
@coolAdd.register(float)
def _(a, b):
    print(a + b)

@coolAdd.register(str)
def _(a, b):
    print((a + " " + b).upper())


coolAdd(1, 2)                     # 3
coolAdd(0.1, 0.2)                 # 0.30000000000000004
coolAdd('Python', 'Programming')  # PYTHON PROGRAMMING
coolAdd(b"hi", b"hello")          # NotImplementedError: Unsupported type

Since Python 3.11, register() also accepts union types such as int | float in the first parameter’s annotation, which reads more easily than stacking one decorator per type.

Methods:

from functools import singledispatchmethod


class CoolClassAdd:

    @singledispatchmethod
    def addMethod(self, arg1, arg2):
        raise NotImplementedError('Unsupported type')

    @addMethod.register(int)
    @addMethod.register(float)
    def _(self, arg1, arg2):
        print(f"Numbers = {arg1 + arg2}")

    @addMethod.register(str)
    def _(self, arg1, arg2):
        print(f"Strings = {arg1} {arg2.upper()}")


c = CoolClassAdd()
c.addMethod(1, 2)           # Numbers = 3
c.addMethod(0.1, 0.2)       # Numbers = 0.30000000000000004
c.addMethod(0.1, 2)         # Numbers = 2.1
c.addMethod("hi", "hello")  # Strings = hi HELLO

Static & class methods are also supported (and most bugs are resolved as of 3.9.7).
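A minimal sketch of both, assuming a recent Python (3.10+ to stay clear of the older bugs) and a hypothetical Formatter class. @singledispatchmethod goes outermost, above @staticmethod or @classmethod, and each registered implementation is wrapped the same way:

```python
from functools import singledispatchmethod


class Formatter:
    # Static method: dispatches on the type of `arg`.
    @singledispatchmethod
    @staticmethod
    def describe(arg):
        raise NotImplementedError(f"Unsupported type: {type(arg).__name__}")

    @describe.register(int)
    @staticmethod
    def _(arg):
        return f"int {arg}"

    # Class method: dispatches on the first argument after `cls`.
    @singledispatchmethod
    @classmethod
    def make(cls, arg):
        raise NotImplementedError(f"Unsupported type: {type(arg).__name__}")

    @make.register(str)
    @classmethod
    def _(cls, arg):
        return f"{cls.__name__} from str {arg!r}"


print(Formatter.describe(3))  # int 3
print(Formatter.make("x"))    # Formatter from str 'x'
```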

However, beware! Dispatch only looks at the type of the first argument (the first one after self or cls, for methods) when choosing which function/method to use.

c.addMethod(1, "hello")     
# >> TypeError: unsupported operand type(s) for +: 'int' and 'str'

Of course, this would normally call for advanced error handling OR implementing multiple dispatch, and now we’re back to where we started!
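If you do want to stay within the question’s constraints (standard library only, no classes), a minimal sketch of true multiple dispatch is a dict keyed on the types of all arguments. The names _add_impls, add_for and add are hypothetical, and this sketch matches exact types only (no subclass or MRO lookup), so for example a bool argument would not fall back to an int implementation:

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: tuple of argument types -> implementation.
_add_impls: Dict[Tuple[type, ...], Callable[..., Any]] = {}


def add_for(*types: type) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that registers func for exactly this tuple of argument types."""
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        _add_impls[types] = func
        return func
    return decorator


def add(*args: Any) -> Any:
    # Dispatch on the exact types of *all* arguments; the tuple length also
    # encodes how many arguments were passed.
    key = tuple(type(arg) for arg in args)
    try:
        impl = _add_impls[key]
    except KeyError:
        names = ", ".join(t.__name__ for t in key)
        raise TypeError(f"no add() implementation for ({names})") from None
    return impl(*args)


@add_for(int, int)
def _(a: int, b: int) -> int:
    return a + b


@add_for(int, str)
def _(a: int, b: str) -> str:
    return str(a) + b


@add_for(str, str)
def _(a: str, b: str) -> str:
    return (a + " " + b).upper()


print(add(1, 2))                     # 3
print(add(1, "hello"))               # 1hello
print(add("Python", "Programming"))  # PYTHON PROGRAMMING
```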

Answered By: Ed Shelton