Compare (assert equality of) two complex data structures containing numpy arrays in unittest

Question:

I use Python’s unittest module and want to check if two complex data structures are equal. The objects can be lists of dicts with all sorts of values: numbers, strings, Python containers (lists/tuples/dicts) and numpy arrays. The latter are the reason for asking the question, because I cannot just do

self.assertEqual(big_struct1, big_struct2)

because it produces a

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I imagine that I need to write my own equality test for this. It should work for arbitrary structures. My current idea is a recursive function that:

  • tries direct comparison of the current “node” of arg1 to the corresponding node of arg2;
  • if no exception is raised, moves on (“terminal” nodes/leaves are processed here, too);
  • if ValueError is caught, goes deeper until it finds a numpy.array;
  • compares the arrays (e.g. like this).

What seems a little problematic is keeping track of “corresponding” nodes of two structures, but perhaps zip is all I need here.
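A minimal sketch of that recursive idea, assuming the structures contain only dicts, lists, tuples, numpy arrays and plain leaves. The helper name assert_nested_equal is hypothetical. Instead of catching ValueError to decide where to recurse, it uses explicit isinstance checks and hands array leaves to numpy.testing.assert_array_equal:

```python
import numpy as np


def assert_nested_equal(a, b, path="root"):
    """Raise AssertionError if the nested structures a and b differ.

    Hypothetical helper: dicts, lists and tuples are walked recursively,
    numpy arrays are compared with numpy.testing.assert_array_equal, and
    every other leaf is compared with ==.
    """
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        # Array leaf: numpy compares shape and elements (NaNs in the same
        # positions count as equal).
        np.testing.assert_array_equal(a, b, err_msg=f"mismatch at {path}")
    elif isinstance(a, dict):
        if not isinstance(b, dict) or a.keys() != b.keys():
            raise AssertionError(f"{path}: dict keys differ")
        for key in a:
            assert_nested_equal(a[key], b[key], f"{path}[{key!r}]")
    elif isinstance(a, (list, tuple)):
        if type(a) is not type(b) or len(a) != len(b):
            raise AssertionError(f"{path}: type or length differs")
        # zip pairs up the "corresponding" nodes of the two structures
        for i, (x, y) in enumerate(zip(a, b)):
            assert_nested_equal(x, y, f"{path}[{i}]")
    elif a != b:
        # Any other leaf: numbers, strings, None, ...
        raise AssertionError(f"{path}: {a!r} != {b!r}")
```

Each failure message carries a path such as root[0]['foo'], so it is easy to see which node differed.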

The question is: are there good (simpler) alternatives to this approach? Maybe numpy presents some tools for this? If no alternatives are suggested, I will implement this idea (unless I have a better one) and post as an answer.

P.S. I have a vague feeling that I might have seen a question addressing this problem, but I can’t find it now.

P.P.S. An alternative approach would be a function that traverses the structure and converts all numpy.arrays to lists, but is this any easier to implement? Seems the same to me.
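For comparison, here is a sketch of the conversion approach. It makes the same assumption that only dicts, lists and tuples need to be walked, and the name arrays_to_lists is hypothetical. ndarray.tolist() does the per-array work, after which plain assertEqual can compare the results:

```python
import numpy as np


def arrays_to_lists(obj):
    """Hypothetical helper: return a copy of obj in which every numpy
    array is replaced by nested Python lists (via ndarray.tolist())."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: arrays_to_lists(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [arrays_to_lists(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(arrays_to_lists(item) for item in obj)
    return obj  # numbers, strings and other leaves are kept as they are


# Inside a TestCase one would then write:
# self.assertEqual(arrays_to_lists(big_struct1), arrays_to_lists(big_struct2))
```

This approach loses some information. For example, np.array([1]) and np.array([1.0]) become equal. NaN elements also compare unequal, because tolist() creates separate float objects.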


Edit: Subclassing numpy.ndarray sounds very promising, but obviously I don’t have both sides of the comparison hard-coded into a test. One of them, though, is indeed hardcoded, so I can:

  • populate it with custom subclasses of numpy.array;
  • change isinstance(other, SaneEqualityArray) to isinstance(other, np.ndarray) in jterrace’s answer;
  • always use it as LHS in comparisons.

My questions in this regard are:

  1. Will it work (I mean, it sounds all right to me, but maybe some tricky edge cases will not be handled correctly)? Will my custom object always end up as LHS in the recursive equality checks, as I expect?
  2. Again, are there better ways (given that I get at least one of the structures with real numpy arrays)?

Edit 2: I tried it out, the (seemingly) working implementation is shown in this answer.

Asked By: Lev Levitsky


Answers:

The assertEqual function will invoke the __eq__ method of objects, which should recurse for complex data types. The exception is numpy, which doesn’t have a sane __eq__ method. Using a numpy subclass from this question, you can restore sanity to the equality behavior:

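import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        # Equal only to another SaneEqualityArray of the same shape whose
        # elements all match.
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                bool(numpy.ndarray.__eq__(self, other).all()))

def sane(values):
    # SaneEqualityArray([1, 2]) would call the ndarray constructor, whose first
    # argument is the *shape* (giving an uninitialized 1x2 array); view() keeps the values.
    return numpy.array(values).view(SaneEqualityArray)

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            sane([1, 2]),
            {'foo': {'hey': sane([2, 3])}},
            [{'foo': sane([3, 4]), 'd': {'doo': 3}},
             sane([5, 6]), 34]
        ]
        for t in tests:
            # deepcopy keeps the subclass, so both sides use the sane __eq__
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()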

This test passes.

Answered By: jterrace

Would have commented, but it gets too long…

Fun fact: you cannot use == to test whether arrays are the same. I would suggest np.testing.assert_array_equal instead, because:

  1. it checks the shape and every element (and, with strict=True, available since NumPy 1.24, the dtype as well);
  2. it doesn't fail on the neat little math of (float('nan') == float('nan')) == False: NaNs in the same positions count as equal. (Plain Python sequence == has an even more fun way of sometimes ignoring this, because it uses PyObject_RichCompareBool, which first does a quick identity (is) check. That is incorrect for NaNs, but for testing it is of course perfect.)
  3. There is also assert_allclose, because floating-point equality gets very tricky once you do actual calculations. You usually want almost the same values, since the results can become hardware-dependent or even random depending on what you do with them.
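Here is a small illustration of points 2 and 3, using arrays made up for the example. assert_array_equal accepts NaNs in matching positions where elementwise == does not. assert_allclose tolerates the rounding error that an exact comparison rejects:

```python
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

a = np.array([1.0, np.nan, 3.0])
b = a.copy()

# Point 2: NaNs in the same positions do not make the assertion fail ...
assert_array_equal(a, b)
# ... although elementwise == reports NaN != NaN
print((a == b).all())  # False

# Point 3: 0.1 + 0.2 is 0.30000000000000004 in binary floating point
c = np.array([0.1 + 0.2, 1.0])
assert_allclose(c, [0.3, 1.0])  # passes with the default rtol=1e-07
try:
    assert_array_equal(c, [0.3, 1.0])
except AssertionError:
    print("exact comparison fails")
```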

I would almost suggest serializing it with pickle if you want to compare something this insanely nested, but that is overly strict (and point 3 is of course completely broken then). For example, the memory layout of your array does not matter for equality, but it does matter to its serialization.
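To make the memory-layout point concrete, the two arrays below hold the same values, so assert_array_equal passes. Their pickles still differ, because one of them is stored in Fortran (column-major) order:

```python
import pickle

import numpy as np

c_order = np.arange(6).reshape(2, 3)
f_order = np.asfortranarray(c_order)  # same values, column-major memory

np.testing.assert_array_equal(c_order, f_order)  # passes: values match
print(pickle.dumps(c_order) == pickle.dumps(f_order))  # False: layout differs
```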

Answered By: seberg

So the idea illustrated by jterrace seems to work for me with a slight modification:

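import numpy as np

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        # np.asarray strips the subclass, so allclose's internal comparisons
        # cannot call back into this method.
        return (isinstance(other, np.ndarray) and self.shape == other.shape and
                np.allclose(np.asarray(self), np.asarray(other)))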

Like I said, the container with these objects should be on the left side of the equality check. I create SaneEqualityArray objects from existing numpy.ndarrays like this:

SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

in accordance with the ndarray constructor signature:

ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

This class is defined within the test suite and serves for testing purposes only. The RHS of the equality check is an actual object returned by the tested function and contains real numpy.ndarray objects.
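The sketch below shows two related points, assuming NumPy's ndarray.view, which reinterprets an existing array as a subclass without copying. First, view() builds the same SaneEqualityArray without spelling out shape, dtype and buffer, and it also works on a transposed array. Second, Python's documented rule for rich comparisons gives the right operand's reflected method priority when its type is a subclass of the left operand's type. For == the reflection is __eq__ itself, so here the subclass is consulted on either side of ==. The variable names are illustrative:

```python
import numpy as np


class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        # The allclose check from the answer; np.asarray strips the subclass
        # so that allclose's own comparisons cannot re-enter this method.
        return (isinstance(other, np.ndarray) and self.shape == other.shape and
                np.allclose(np.asarray(self), np.asarray(other)))


# Plain ndarray, as a tested function might return it
actual = {"weights": np.array([[0.0, 2.0], [1.0, 3.0]])}
# Hard-coded expectation: view() works directly on the transposed
# (Fortran-ordered) array, no shape/dtype/buffer arguments needed
expected = {"weights": np.array([[0.0, 1.0], [2.0, 3.0]]).T.view(SaneEqualityArray)}

print(expected == actual)  # True
print(actual == expected)  # True: the subclass's __eq__ is tried first
```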

P.S. Thanks to the authors of both answers posted so far; they were both very helpful. If anyone sees any problems with this approach, I’d appreciate your feedback.

Answered By: Lev Levitsky

I would define my own assertNumpyArraysEqual() method that explicitly makes the comparison that you want to use. That way, your production code is unchanged but you can still make reasonable assertions in your unit tests. Make sure to define it in a module that includes __unittest = True so that it will not be included in stack traces:

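import numpy
__unittest = True  # unittest omits this module's frames from failure tracebacks

def assertNumpyArraysEqual(self, other):
    # allclose broadcasts, so compare the shapes explicitly first
    if self.shape != other.shape:
        raise AssertionError("Shapes don't match")
    if not numpy.allclose(self, other):
        raise AssertionError("Elements don't match!")
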
Answered By: dbn

Building on @dbn (with thanks), the following method inserted within the test-case subclass worked well for me:

import unittest

import numpy as np

class MyTestCase(unittest.TestCase):  # any TestCase subclass of yours

    def assertNumpyArraysEqual(self, this, that, msg=''):
        '''
        modified from http://stackoverflow.com/a/15399475/5459638
        '''
        if this.shape != that.shape:
            raise AssertionError("Shapes don't match")
        if not np.allclose(this, that):
            raise AssertionError("Elements don't match!")

I called it as self.assertNumpyArraysEqual(this, that) inside my test case methods, and it worked like a charm.

Answered By: XavierStuvw

Check numpy.testing.assert_almost_equal, which “raises an AssertionError if two items are not equal up to desired precision”, e.g.:

import numpy as np
import numpy.testing as npt
# Raises AssertionError: the values differ by ~6.7e-9 > 1.5e-9 (decimal=7 would pass)
npt.assert_almost_equal(np.array([1.0, 2.3333333333333]),
                        np.array([1.0, 2.33333334]), decimal=9)

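A related function in the same module, numpy.testing.assert_equal, goes one step further for this question. It recurses into dicts, lists and tuples and passes any numpy arrays it meets to assert_array_equal. Here is a sketch with a made-up structure. As far as I can tell from its behaviour, it does not distinguish a list from a tuple:

```python
import copy

import numpy as np
import numpy.testing as npt

expected = [{'foo': np.array([3, 4]), 'd': {'doo': 3}},
            np.array([5.0, np.nan]), 'bar']
actual = copy.deepcopy(expected)

npt.assert_equal(actual, expected)  # passes; NaNs in matching positions are OK

actual[0]['foo'][1] = 5
try:
    npt.assert_equal(actual, expected)
except AssertionError:
    print("mismatch detected")
```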
Answered By: Hanan Shteingart

I’ve run into the same issue and developed a function to compare equality based on creating a fixed hash for the object. This has the added advantage that you can test that an object is as expected by comparing its hash against a fixed hash stored in code.

The code (a stand-alone Python file) is here. There are two functions: fixed_hash_eq, which solves your problem, and compute_fixed_hash, which makes a hash from the structure. Tests are here.

Here’s a test:

import numpy as np
# fixed_hash_eq comes from the stand-alone file linked above

obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)

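The linked code is not reproduced here, so the following is only a hypothetical sketch of the general idea, not the author's implementation. It feeds a canonical byte encoding of the structure into hashlib.sha256. The assumptions are that only dicts, lists, tuples, numeric numpy arrays and leaves with a stable repr occur (object arrays would hash memory addresses). The names structure_digest and digest_eq are made up. Unlike ==, this encoding treats 1 and 1.0 as different:

```python
import hashlib

import numpy as np


def _put(h, data):
    # Length-prefix every chunk so that different structures cannot
    # produce the same byte stream by accident.
    h.update(b"%d:" % len(data))
    h.update(data)


def _feed(h, obj):
    if isinstance(obj, np.ndarray):
        _put(h, b"ndarray")
        _put(h, obj.dtype.str.encode())    # e.g. '<i8'
        _put(h, repr(obj.shape).encode())
        _put(h, obj.tobytes())             # C-order bytes, whatever the layout
    elif isinstance(obj, dict):
        _put(h, b"dict%d" % len(obj))
        for key in sorted(obj, key=repr):  # independent of insertion order
            _feed(h, key)
            _feed(h, obj[key])
    elif isinstance(obj, (list, tuple)):
        _put(h, b"%s%d" % (type(obj).__name__.encode(), len(obj)))
        for item in obj:
            _feed(h, item)
    else:
        # Leaves such as numbers and strings: type name plus repr
        _put(h, type(obj).__name__.encode())
        _put(h, repr(obj).encode())


def structure_digest(obj):
    """Hypothetical: SHA-256 hex digest of a nested structure."""
    h = hashlib.sha256()
    _feed(h, obj)
    return h.hexdigest()


def digest_eq(a, b):
    return structure_digest(a) == structure_digest(b)
```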
Answered By: Peter