How to perform approximate structural pattern matching for floats and complex numbers
Question:
I’ve read about and understand floating point round-off issues such as:
>>> sum([0.1] * 10) == 1.0
False
>>> 1.1 + 2.2 == 3.3
False
>>> from math import radians, sin, sqrt
>>> sin(radians(45)) == sqrt(2) / 2
False
I also know how to work around these issues with math.isclose() and cmath.isclose().
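For reference, here is a minimal sketch of that workaround applied to the three comparisons above. It uses only the standard `math` and `cmath` modules with their default tolerances (`rel_tol=1e-09`, `abs_tol=0.0`):

```python
import cmath
import math

# The three failing comparisons, redone with a relative tolerance.
print(math.isclose(sum([0.1] * 10), 1.0))                           # True
print(math.isclose(1.1 + 2.2, 3.3))                                 # True
print(math.isclose(math.sin(math.radians(45)), math.sqrt(2) / 2))   # True

# cmath.isclose() takes the same keyword arguments but accepts complex values.
print(cmath.isclose(1.1j + 2.2j, 3.3j))                             # True
```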
The question is how to apply those workarounds to Python’s match/case statement. I would like this to work:
match 1.1 + 2.2:
    case 3.3:
        print('hit!')  # currently, this doesn't match
Answers:
The key to the solution is to build a wrapper that overrides the __eq__ method and replaces exact equality with an approximate comparison:
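from cmath import isclose

class Approximately(complex):
    """A complex number whose == comparison uses cmath.isclose()."""

    def __new__(cls, x, /, **kwargs):
        result = complex.__new__(cls, x)
        # Keyword arguments (rel_tol, abs_tol) are forwarded to isclose()
        result.kwargs = kwargs
        return result

    def __eq__(self, other):
        try:
            return isclose(self, other, **self.kwargs)
        except TypeError:
            # Not a number: let Python try the other operand's __eq__
            return NotImplemented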
It creates approximate equality tests for both float values and complex values:
>>> Approximately(1.1 + 2.2) == 3.3
True
>>> Approximately(1.1 + 2.2, abs_tol=0.2) == 3.4
True
>>> Approximately(1.1j + 2.2j) == 0.0 + 3.3j
True
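Returning `NotImplemented` from the `except` branch is what makes the wrapper behave well with other operand types. The sketch below repeats the wrapper so it runs on its own and checks two consequences: a plain float on the left still compares approximately (because `float.__eq__` returns `NotImplemented` for a complex operand, Python falls back to the reflected `Approximately.__eq__`), and a non-number simply compares unequal instead of raising.

```python
from cmath import isclose

class Approximately(complex):
    # Same wrapper as above: == delegates to cmath.isclose().
    def __new__(cls, x, /, **kwargs):
        result = complex.__new__(cls, x)
        result.kwargs = kwargs
        return result

    def __eq__(self, other):
        try:
            return isclose(self, other, **self.kwargs)
        except TypeError:
            return NotImplemented

# Reflected comparison: float.__eq__ gives up, Approximately.__eq__ answers.
print(3.3 == Approximately(1.1 + 2.2))      # True

# isclose() raises TypeError for a str, both sides return NotImplemented,
# and == falls back to an identity check.
print(Approximately(1.0) == "1.0")          # False
```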
Here is how to use it in a match/case statement:
from math import radians, sin

for x in [sum([0.1] * 10), 1.1 + 2.2, sin(radians(45))]:
    match Approximately(x):
        case 1.0:
            print(x, 'sums to about 1.0')
        case 3.3:
            print(x, 'sums to about 3.3')
        case 0.7071067811865475:
            print(x, 'is close to sqrt(2) / 2')
        case _:
            print('Mismatch')
This outputs:
0.9999999999999999 sums to about 1.0
3.3000000000000003 sums to about 3.3
0.7071067811865475 is close to sqrt(2) / 2
Raymond’s answer is very fancy and ergonomic, but it seems like a lot of magic for something that could be much simpler. A more minimal version is to capture the calculated value and explicitly check, in a guard, whether it is "close", e.g.:
import math

match 1.1 + 2.2:
    case x if math.isclose(x, 3.3):
        print(f"{x} is close to 3.3")
    case x:
        print(f"{x} wasn't close")
I’d also suggest only using cmath.isclose() where/when you actually need it; using appropriate types lets you ensure your code is doing what you expect.
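To illustrate the point about types: `cmath.isclose()` happily compares complex values, while `math.isclose()` only accepts real numbers and raises `TypeError` if a complex value sneaks in. A minimal sketch:

```python
import cmath
import math

z = 1.1j + 2.2j  # a complex result

# cmath.isclose() accepts complex (and real) arguments.
print(cmath.isclose(z, 3.3j))  # True

# math.isclose() rejects complex arguments, so an unexpected complex value
# is reported instead of being compared silently.
try:
    math.isclose(z, 3.3j)
except TypeError as exc:
    print("TypeError:", exc)
```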
The above example is just the minimum code needed to demonstrate the matching and, as pointed out in the comments, could be more easily implemented with a traditional if statement. At the risk of derailing the original question, here is a somewhat more complete example:
import math
from dataclasses import dataclass

@dataclass
class Square:
    size: float

@dataclass
class Rectangle:
    width: float
    height: float

def classify(obj: Square | Rectangle) -> str:
    match obj:
        case Square(size=x) if math.isclose(x, 1):
            return "~unit square"
        case Square(size=x):
            return f"square, size={x}"
        case Rectangle(width=w, height=h) if math.isclose(w, h):
            return "~square rectangle"
        case Rectangle(width=w, height=h):
            return f"rectangle, width={w}, height={h}"

almost_one = 1 + 1e-10
print(classify(Square(almost_one)))      # ~unit square
print(classify(Rectangle(1, almost_one)))  # ~square rectangle
print(classify(Rectangle(1, 2)))         # rectangle, width=1, height=2
I’m not sure I’d actually use a match statement here, but it is hopefully more representative!