Compare (assert equality of) two complex data structures containing numpy arrays in unittest
Question:
I use Python’s `unittest` module and want to check whether two complex data structures are equal. The objects can be lists of dicts with all sorts of values: numbers, strings, Python containers (lists/tuples/dicts) and `numpy` arrays. The arrays are the reason I am asking, because I cannot just do

```python
self.assertEqual(big_struct1, big_struct2)
```

because it produces a

```
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
I imagine that I need to write my own equality test for this. It should work for arbitrary structures. My current idea is a recursive function that:

- tries direct comparison of the current “node” of `arg1` to the corresponding node of `arg2`;
- if no exception is raised, moves on (“terminal” nodes/leaves are processed here, too);
- if a `ValueError` is caught, goes deeper until it finds a `numpy.array`;
- compares the arrays (e.g. like this).
What seems a little problematic is keeping track of “corresponding” nodes of the two structures, but perhaps `zip` is all I need here.
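A minimal sketch of the recursive idea, with one deliberate change: instead of trying `==` and catching `ValueError`, it dispatches on the type of each node, so it does not depend on where (or whether) the exception is raised. The name `assert_nested_equal` is hypothetical, and the sketch assumes that dicts must have identical keys, that lists and tuples are compared position by position (via `zip`, after a length check), and that arrays are delegated to `np.testing.assert_array_equal`.

```python
import numpy as np


def assert_nested_equal(a, b, path="root"):
    """Recursively compare two nested structures; raise AssertionError on mismatch."""
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        # Leaf arrays: let NumPy compare shapes and values.
        np.testing.assert_array_equal(a, b, err_msg=f"mismatch at {path}")
    elif isinstance(a, dict) and isinstance(b, dict):
        if set(a) != set(b):
            raise AssertionError(f"different keys at {path}: {set(a) ^ set(b)}")
        for key in a:
            assert_nested_equal(a[key], b[key], f"{path}[{key!r}]")
    elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        if type(a) is not type(b) or len(a) != len(b):
            raise AssertionError(f"different sequence type or length at {path}")
        # zip pairs up the "corresponding" nodes of the two structures.
        for i, (x, y) in enumerate(zip(a, b)):
            assert_nested_equal(x, y, f"{path}[{i}]")
    elif a != b:
        # Plain leaves (numbers, strings, ...) use ordinary equality.
        raise AssertionError(f"{a!r} != {b!r} at {path}")
```

The `path` argument only exists to make failure messages point at the offending node, e.g. `root[1]['k']`.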
The question is: are there good (simpler) alternatives to this approach? Maybe `numpy` provides some tools for this? If no alternatives are suggested, I will implement this idea (unless I have a better one) and post it as an answer.
P.S. I have a vague feeling that I might have seen a question addressing this problem, but I can’t find it now.
P.P.S. An alternative approach would be a function that traverses the structure and converts all `numpy.array`s to lists, but is this any easier to implement? It seems the same to me.
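For comparison, the P.P.S. alternative can be sketched in a few lines: a hypothetical `to_builtin` helper that walks the structure, replaces every array with `ndarray.tolist()`, and leaves the final comparison to plain `==` (so `assertEqual` works unchanged). It is indeed about the same amount of work. Note that the converted form no longer carries the dtype, and NaNs still compare unequal because `tolist()` creates distinct float objects.

```python
import numpy as np


def to_builtin(obj):
    """Return a copy of obj in which every numpy array is replaced by nested lists."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: to_builtin(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Keep the container type so that lists and tuples stay distinguishable
        # (assumes plain list/tuple, not e.g. namedtuples).
        return type(obj)(to_builtin(item) for item in obj)
    return obj


struct1 = [{'a': np.arange(3)}, (1, 'x')]
struct2 = [{'a': np.array([0, 1, 2])}, (1, 'x')]
# After conversion, ordinary == (and hence assertEqual) works.
print(to_builtin(struct1) == to_builtin(struct2))  # True
```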
Edit: Subclassing `numpy.ndarray` sounds very promising, but obviously I don’t have both sides of the comparison hard-coded into a test. One of them, though, is indeed hard-coded, so I can:

- populate it with custom subclasses of `numpy.array`;
- change `isinstance(other, SaneEqualityArray)` to `isinstance(other, np.ndarray)` in jterrace’s answer;
- always use it as the LHS in comparisons.
My questions in this regard are:

- Will it work? (It sounds all right to me, but maybe some tricky edge cases will not be handled correctly.) Will my custom object always end up as the LHS in the recursive equality checks, as I expect?
- Again, are there better ways, given that I get at least one of the structures with real `numpy` arrays?
Edit 2: I tried it out; the (seemingly) working implementation is shown in this answer.
Answers:
The `assertEqual` function will invoke the `__eq__` method of objects, which should recurse for complex data types. The exception is numpy, which doesn’t have a sane `__eq__` method: it compares elementwise and returns an array instead of a single boolean. Using a numpy subclass from this question, you can restore sanity to the equality behavior:
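```python
import copy
import numpy
import unittest


class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        # Equal only to another SaneEqualityArray of the same shape whose
        # elements all compare equal (ndarray.__eq__ itself is elementwise).
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())


def sane(values):
    # Wrap real data via view(); SaneEqualityArray([1, 2]) would go through
    # the ndarray(shape, ...) constructor and give an uninitialized 1x2 array.
    return numpy.array(values).view(SaneEqualityArray)


class TestAsserts(unittest.TestCase):
    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            sane([1, 2]),
            {'foo': {'hey': sane([2, 3])}},
            [{'foo': sane([3, 4]), 'd': {'doo': 3}},
             sane([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))


if __name__ == '__main__':
    unittest.main()
```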
This test passes.
Would have commented, but it gets too long…

Fun fact: you cannot use `==` to test whether arrays are the same. I would suggest you use `np.testing.assert_array_equal` instead:

- it checks the shape (and the dtype as well if you pass `strict=True`, available since NumPy 1.24);
- it doesn’t fail on the neat little math of `(float('nan') == float('nan')) == False`: NaNs in the same positions count as equal. (The normal Python sequence `==` has an even more fun way of ignoring this sometimes, because it uses `PyObject_RichCompareBool`, which does a quick `is` check first. That check is incorrect for NaNs, but for testing it is of course perfect.)
- There is also `assert_allclose`, because floating-point equality can get very tricky if you do actual calculations: you usually want almost-equal values, since the results can be hardware-dependent or possibly random depending on what you do with them.

I would almost suggest serializing the structure with pickle if you want something this insanely nested, but that is overly strict (and the tolerance-based comparison of point 3 is of course impossible then). For example, the memory layout of your array does not matter for equality, but it matters to its serialization.
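A small illustration of the NaN points above, using only standard Python and `numpy.testing`: NaN never equals itself, list `==` short-circuits on identity, and `assert_array_equal`/`assert_allclose` treat NaNs in matching positions as equal.

```python
import numpy as np

nan = float('nan')

# Plain == follows IEEE 754: NaN never equals NaN.
print(nan == nan)                        # False
# List == checks identity first, so the very same NaN object
# "equals" itself inside a container...
print([nan] == [nan])                    # True
# ...but two distinct NaN objects do not.
print([float('nan')] == [float('nan')])  # False

a = np.array([1.0, np.nan])
b = np.array([1.0, np.nan])
print((a == b).all())                    # False
# assert_array_equal treats NaNs in the same positions as equal: no error.
np.testing.assert_array_equal(a, b)
# assert_allclose uses a relative tolerance (default rtol=1e-7) and,
# by default, also treats matching NaNs as equal: no error.
np.testing.assert_allclose(a, b + 1e-9)
```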
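So the idea illustrated by jterrace seems to work for me with a slight modification:

```python
import numpy as np


class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        # Accept any ndarray on the other side, compare shapes, then values
        # with a tolerance. np.asarray() strips the subclass so that
        # np.allclose cannot call back into this __eq__ (depending on the
        # NumPy version, it may use == internally).
        return (isinstance(other, np.ndarray) and self.shape == other.shape and
                np.allclose(np.asarray(self), np.asarray(other)))
```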
Like I said, the container with these objects should be on the left side of the equality check. I create `SaneEqualityArray` objects from existing `numpy.ndarray`s like this:

```python
SaneEqualityArray(my_array.shape, my_array.dtype, my_array)
```

in accordance with the `ndarray` constructor signature:

```python
ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)
```

This class is defined within the test suite and serves for testing purposes only. The RHS of the equality check is an actual object returned by the tested function and contains real `numpy.ndarray` objects.
P.S. Thanks to the authors of both answers posted so far; both were very helpful. If anyone sees any problems with this approach, I’d appreciate your feedback.
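On the question of which side the wrapped array must be on: Python’s data model gives a subclass priority in rich comparisons. If the right operand’s type is a subclass of the left operand’s type, the right operand’s reflected method (for `==` that is again `__eq__`) is tried first. So, as a sketch assuming the `isinstance(other, np.ndarray)` variant above, the check should also work with the wrapper on the right, including inside lists and dicts. The sketch also uses `ndarray.view()`, which wraps an existing array without spelling out shape, dtype and buffer, and does not need a contiguous buffer.

```python
import numpy as np


class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        # Same rule as above; asarray() avoids recursing through np.allclose.
        return (isinstance(other, np.ndarray) and self.shape == other.shape and
                np.allclose(np.asarray(self), np.asarray(other)))


real = np.array([1.0, 2.0, 3.0])
# .view() reinterprets existing data as the subclass (no copy).
expected = np.array([1.0, 2.0, 3.0]).view(SaneEqualityArray)

# Wrapper on the left: SaneEqualityArray.__eq__ runs directly.
print([expected] == [real])             # True
# Wrapper on the right: Python tries the subclass's __eq__ first.
print([real] == [expected])             # True
print({'k': real} == {'k': expected})   # True
```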
I would define my own `assertNumpyArraysEqual()` method that explicitly makes the comparison you want to use. That way, your production code is unchanged, but you can still make reasonable assertions in your unit tests. Make sure to define it in a module that includes `__unittest = True`, so that its frames are left out of the tracebacks unittest prints for failures:

```python
import numpy

# unittest hides frames from modules whose globals contain __unittest.
__unittest = True


def assertNumpyArraysEqual(self, other):
    if self.shape != other.shape:
        raise AssertionError("Shapes don't match")
    if not numpy.allclose(self, other):
        raise AssertionError("Elements don't match!")
```
Building on @dbw’s answer (with thanks), the following method, inserted within the test-case subclass, worked well for me:

```python
# Assumes `import numpy as np` at module level.
def assertNumpyArraysEqual(self, this, that, msg=''):
    '''
    modified from http://stackoverflow.com/a/15399475/5459638
    '''
    # Note: msg is accepted but not used here.
    if this.shape != that.shape:
        raise AssertionError("Shapes don't match")
    if not np.allclose(this, that):
        raise AssertionError("Elements don't match!")
```

I call it as `self.assertNumpyArraysEqual(this, that)` inside my test-case methods, and it works like a charm.
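A variant of the two helpers above, sketched as a mixin (the name `NumpyAssertMixin` is hypothetical) that actually uses `msg`. It delegates the value check to `np.testing.assert_allclose`, which reports which elements differ, and passes `rtol=1e-5, atol=1e-8` explicitly because those are `np.allclose`’s defaults, while `assert_allclose` defaults to the stricter `rtol=1e-7, atol=0`.

```python
import unittest

import numpy as np


class NumpyAssertMixin:
    """Hypothetical mixin that adds an array assertion to a TestCase."""

    def assertNumpyArraysEqual(self, this, that, msg=''):
        # Shape check first, as in the helpers above.
        if np.shape(this) != np.shape(that):
            self.fail(f"Shapes don't match: {np.shape(this)} != {np.shape(that)} {msg}")
        # Same tolerances as np.allclose's defaults; assert_allclose
        # reports the mismatching elements and appends msg to its message.
        np.testing.assert_allclose(this, that, rtol=1e-5, atol=1e-8,
                                   err_msg=msg)


class ExampleTest(NumpyAssertMixin, unittest.TestCase):
    def test_close_arrays(self):
        self.assertNumpyArraysEqual(np.array([1.0, 2.0]),
                                    np.array([1.0, 2.0 + 1e-12]))
```

One behavioral difference to keep in mind: unlike `np.allclose`, `assert_allclose` counts NaNs in matching positions as equal unless you pass `equal_nan=False`.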
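Check `numpy.testing.assert_almost_equal`, which “raises an AssertionError if two items are not equal up to desired precision”. For example, the following call raises, because the second elements differ by about 6.7e-9, which is too much for 9 decimals:

```python
import numpy as np
import numpy.testing as npt

npt.assert_almost_equal(np.array([1.0, 2.3333333333333]),
                        np.array([1.0, 2.33333334]), decimal=9)
```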
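To see where the threshold lies: according to the NumPy documentation, `assert_almost_equal` requires `abs(desired - actual) < 1.5 * 10**(-decimal)` elementwise. A difference of about 6.7e-9 therefore fails at `decimal=9` (limit 1.5e-9) but passes at `decimal=7` (limit 1.5e-7):

```python
import numpy as np
import numpy.testing as npt

actual = np.array([1.0, 2.3333333333333])
desired = np.array([1.0, 2.33333334])

# Passes: the largest difference (~6.7e-9) is below 1.5e-7.
npt.assert_almost_equal(actual, desired, decimal=7)

# Fails: ~6.7e-9 is not below 1.5e-9.
try:
    npt.assert_almost_equal(actual, desired, decimal=9)
    failed_at_9 = False
except AssertionError:
    failed_at_9 = True
print(failed_at_9)  # True
```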
I’ve run into the same issue and developed a function that compares equality by creating a fixed hash for the object. This has the added advantage that you can test that an object is as expected by comparing its hash against a fixed hash stored in code.

The code (a stand-alone Python file) is here. There are two functions: `fixed_hash_eq`, which solves your problem, and `compute_fixed_hash`, which makes a hash from the structure. Tests are here.

Here’s a test:

```python
import numpy as np
# fixed_hash_eq comes from the stand-alone file mentioned above.

obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)
```
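The linked file isn’t reproduced here, but the general idea can be sketched with `hashlib`: build a canonical byte stream for the structure (arrays contribute their dtype, shape and raw bytes), hash it, and compare digests. The names `fixed_hash` and `hash_eq` are hypothetical and are not the functions from that file; the sketch assumes numeric (non-object) arrays, since `tobytes()` of an object array would hash pointers.

```python
import hashlib

import numpy as np


def _put(h, data):
    # Length-prefix every chunk so that different structures cannot
    # produce the same concatenated byte stream.
    h.update(len(data).to_bytes(8, "little"))
    h.update(data)


def _feed(h, obj):
    """Feed a canonical byte representation of obj into the hash object h."""
    if isinstance(obj, np.ndarray):
        # dtype and shape are part of the identity; tobytes() defaults to
        # C order, so the memory layout of the array does not matter.
        _put(h, b"ndarray")
        _put(h, str(obj.dtype).encode())
        _put(h, repr(obj.shape).encode())
        _put(h, obj.tobytes())
    elif isinstance(obj, dict):
        _put(h, b"dict")
        _put(h, str(len(obj)).encode())
        # Sort keys by repr so that insertion order does not matter.
        for key in sorted(obj, key=repr):
            _feed(h, key)
            _feed(h, obj[key])
    elif isinstance(obj, (list, tuple)):
        _put(h, type(obj).__name__.encode())
        _put(h, str(len(obj)).encode())
        for item in obj:
            _feed(h, item)
    else:
        # Numbers, strings, None, ...: type name plus repr.
        _put(h, type(obj).__name__.encode())
        _put(h, repr(obj).encode())


def fixed_hash(obj):
    """Return the SHA-256 hex digest of obj's canonical byte stream."""
    h = hashlib.sha256()
    _feed(h, obj)
    return h.hexdigest()


def hash_eq(a, b):
    return fixed_hash(a) == fixed_hash(b)
```

This is stricter than `==` (for instance `1` and `1.0` hash differently), and because dtype names and byte order enter the digest, a hash stored in code is only as portable as those are across platforms.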