unittest.mock: asserting partial match for method argument

Question:

Rubyist writing Python here. I’ve got some code that looks kinda like this:

result = database.Query('complicated sql with an id: %s' % id)

database.Query is mocked out, and I want to test that the ID gets injected in correctly without hardcoding the entire SQL statement into my test. In Ruby/RR, I would have done this:

mock(database).query(/#{id}/)

But I can’t see a way to set up a ‘selective mock’ like that in unittest.mock, at least without some hairy side_effect logic. So I tried using the regexp in the assertion instead:

with patch(database) as MockDatabase:
  instance = MockDatabase.return_value
  ...
  instance.Query.assert_called_once_with(re.compile("%s" % id))

But that doesn’t work either. This approach does work, but it’s ugly:

with patch(database) as MockDatabase:
  instance = MockDatabase.return_value
  ...
  self.assertIn(id, instance.Query.call_args[0][0])
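The re.compile attempt fails because mock compares the expected arguments with the recorded ones using ==, and a compiled pattern object never compares equal to a string. A minimal check (the SQL text and the id 42 are made up for illustration):

```python
import re
from unittest.mock import Mock

query = Mock()
query("complicated sql with an id: 42")

pattern = re.compile("42")
# A compiled pattern is never equal to a string, so this prints False...
print(pattern == "complicated sql with an id: 42")

# ...and therefore the assertion raises AssertionError.
try:
    query.assert_called_once_with(pattern)
except AssertionError:
    print("assertion failed as expected")

# Searching the recorded argument explicitly does work.
print(pattern.search(query.call_args[0][0]) is not None)
```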

Better ideas?

Asked By: lambshaanxy


Answers:

from unittest import mock  # Python 3.3+; the standalone mock package offers the same API

class AnyStringWith(str):
    def __eq__(self, other):
        return self in other

...
result = database.Query('complicated sql with an id: %s' % id)
database.Query.assert_called_once_with(AnyStringWith(id))
...
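A self-contained version of the same idea, with unittest.mock.Mock standing in for the real database object (the query text and id are illustrative). Because AnyStringWith subclasses str, the constructor applies str() to its argument, so an integer id works too. The isinstance guard is an addition, not part of the answer above:

```python
from unittest.mock import Mock

class AnyStringWith(str):
    """Compares equal to any string that contains this one."""
    def __eq__(self, other):
        # Guard against non-string arguments, where `in` would raise TypeError.
        return isinstance(other, str) and self in other

database = Mock()  # stands in for the real database object
record_id = 6
database.Query("complicated sql with an id: %s" % record_id)

# Passes: the recorded argument contains "6".
database.Query.assert_called_once_with(AnyStringWith(record_id))

# Fails: the recorded argument does not contain "id: 7".
try:
    database.Query.assert_called_once_with(AnyStringWith("id: 7"))
except AssertionError:
    print("no match, as expected")
```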

Alternatively, replace the method with a function that asserts upfront that its argument contains the expected string:

def arg_should_contain(x):
    def wrapper(arg):
        assert str(x) in arg, "'%s' does not contain '%s'" % (arg, x)
    return wrapper

...
database.Query = arg_should_contain(id)
result = database.Query('complicated sql with an id: %s' % id)
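One drawback of replacing the method with a plain function is that nothing records whether it was called at all. A variation (not part of the original answer) is to pass the same checker as the side_effect of a Mock, which keeps the upfront check and the call bookkeeping. Note that plain assert statements are skipped under python -O:

```python
from unittest.mock import Mock

def arg_should_contain(x):
    """Return a checker that asserts its single argument contains str(x)."""
    def wrapper(arg):
        assert str(x) in arg, "'%s' does not contain '%s'" % (arg, x)
    return wrapper

# The checker runs on every call, and the Mock still records each call.
query = Mock(side_effect=arg_should_contain(6))
query("complicated sql with an id: 6")  # passes silently
query.assert_called_once()

try:
    query("complicated sql with an id: 7")
except AssertionError as exc:
    print(exc)  # 'complicated sql with an id: 7' does not contain '6'
```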

UPDATE

Using libraries like callee, you don’t need to implement AnyStringWith.

from callee import Contains

database.Query.assert_called_once_with(Contains(id))

https://callee.readthedocs.io/en/latest/reference/operators.html#callee.operators.Contains

Answered By: falsetru

I always write my unit tests so they reflect the ‘real world’. I don’t really know what you want to test beyond whether the ID gets injected correctly.

I don’t know what the database.Query is supposed to do, but I guess it’s supposed to create a query object you can call or pass to a connection later?

The best way to test this is with a real-world example. Doing something simple like checking whether the id occurs in the query is too error-prone. I often see people wanting to do magic stuff in their unit tests, and this always leads to problems. Keep your unit tests simple and static. In your case you could do:

import unittest

import database  # the module under test

class QueryTest(unittest.TestCase):
    def test_insert_id_simple(self):
        expected = 'a simple query with an id: 2'
        query = database.Query('a simple query with an id: %s' % 2)
        self.assertEqual(query, expected)

    def test_insert_id_complex(self):
        expected = 'some complex query with an id: 6'
        query = database.Query('some complex query with an id: %s' % 6)
        self.assertEqual(query, expected)

If database.Query directly executes a query in the database, you might want to consider naming it something like database.query or database.execute instead. The capital Q in Query implies that you create an object; an all-lowercase name implies that you call a function. It’s more a naming convention and my opinion, but I’m just throwing it out there. 😉

If database.Query queries the database directly, it’s best to patch the method it calls. For example, if it looks like this:

def Query(self, query):
    self.executeSQL(query)
    return query

You can use mock.patch to prevent the unit test from going to the database:

@mock.patch('database.executeSQL')
def test_insert_id_simple(self, mck):
    expected = 'a simple query with an id: 2'
    query = database.Query('a simple query with an id: %s' % 2)
    self.assertEqual(query, expected)
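A self-contained sketch of that idea, assuming a hypothetical Database class whose Query method delegates to executeSQL. Since executeSQL is a method in this sketch, mock.patch.object patches it on the class; as a bonus, the mock records the exact SQL it received. Run it with python -m unittest:

```python
import unittest
from unittest import mock

class Database:
    """Hypothetical stand-in for the real database wrapper."""
    def executeSQL(self, query):
        raise RuntimeError("would hit a real database")

    def Query(self, query):
        self.executeSQL(query)
        return query

class QueryTest(unittest.TestCase):
    @mock.patch.object(Database, "executeSQL")
    def test_insert_id_simple(self, mock_execute):
        expected = "a simple query with an id: 2"
        query = Database().Query("a simple query with an id: %s" % 2)
        self.assertEqual(query, expected)
        # The patched method received exactly the formatted SQL.
        mock_execute.assert_called_once_with(expected)
```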

As an extra tip, try to use the str.format method (or, on Python 3.6+, f-strings) rather than % formatting. See this question for more info.

I also cannot help but feel that testing string formatting is redundant. If 'test %s' % 'test' doesn’t work, something is wrong with Python itself. It would only make sense if you wanted to test custom query building, e.g. that inserted strings are quoted, numbers aren’t, special characters are escaped, etc.

Answered By: siebz0r

You can just use unittest.mock.ANY 🙂

from unittest.mock import Mock, ANY

def foo(some_string):
    print(some_string)

foo = Mock()  # replace the real function with a mock
foo("bla")
foo.assert_called_with(ANY)  # passes: ANY compares equal to any argument

As described here –
https://docs.python.org/3/library/unittest.mock.html#any
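Note that ANY matches every value, so it only helps with arguments you don't care about at all; it cannot check that the SQL contains the id. It is most useful alongside concrete values, as in this hypothetical parameterized call:

```python
from unittest.mock import Mock, ANY

execute = Mock()
execute("SELECT * FROM users WHERE id = %s", (6,), timeout=30)

# ANY stands in for the SQL text; the parameters and the keyword must still match.
execute.assert_called_once_with(ANY, (6,), timeout=30)

# ANY also matches values that are plainly wrong, so this passes as well.
execute.assert_called_once_with(ANY, ANY, timeout=ANY)
```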

Answered By: Kfir Eisner

You can use match_equality from the PyHamcrest library to wrap the matches_regexp matcher from the same library:

from hamcrest import matches_regexp
from hamcrest.library.integration import match_equality

with patch(database) as MockDatabase:
  instance = MockDatabase.return_value
  ...
  expected_arg = matches_regexp(id)
  instance.Query.assert_called_once_with(match_equality(expected_arg))

This approach is also mentioned in Python’s unittest.mock documentation:

As of version 1.5, the Python testing library PyHamcrest provides similar functionality, that may be useful here, in the form of its equality matcher (hamcrest.library.integration.match_equality).

If you don’t want to use PyHamcrest, the documentation linked above also shows how to write a custom matcher by defining a class with an __eq__ method (as suggested in falsetru’s answer):

class Matcher:
    def __init__(self, compare, expected):
        self.compare = compare
        self.expected = expected

    def __eq__(self, actual):
        return self.compare(self.expected, actual)

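# compare(expected, actual) and Foo come from the documentation example;
# mock is the Mock object under test
match_foo = Matcher(compare, Foo(1, 2))
mock.assert_called_with(match_foo)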

You could replace the call to self.compare here with your own regex matching, and return False if there is no match or raise an AssertionError with a descriptive error message of your choice.
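A sketch of that replacement: the Matcher class is the one from the documentation, while matches_regex is a hypothetical comparison function written for this example. The __repr__ is an addition so that a failed assertion prints something readable:

```python
import re
from unittest.mock import Mock

class Matcher:
    def __init__(self, compare, expected):
        self.compare = compare
        self.expected = expected

    def __eq__(self, actual):
        return self.compare(self.expected, actual)

    def __repr__(self):
        # Shown in the AssertionError message when the match fails.
        return "Matcher(%r)" % (self.expected,)

def matches_regex(pattern, actual):
    """Return True if actual is a string containing a match for pattern."""
    return isinstance(actual, str) and re.search(pattern, actual) is not None

database = Mock()
database.Query("complicated sql with an id: 42")

# Passes: r"\bid: 42\b" is found in the recorded argument.
database.Query.assert_called_once_with(Matcher(matches_regex, r"\bid: 42\b"))
```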

Answered By: saaskis

The chosen answer is absolutely wonderful.

However, the original question seemed to want to match on the basis of a regex. I offer the following, which I would never have been able to devise without falsetru’s chosen answer:

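import re

class AnyStringWithRegex(str):
    # str is immutable, so extra constructor options must be taken in __new__
    def __new__(cls, pattern, case_insensitive=True):
        obj = super().__new__(cls, pattern)
        obj.case_insensitive = case_insensitive
        return obj

    def __eq__(self, other):
        if not isinstance(other, str):
            return False
        flags = re.DOTALL | (re.IGNORECASE if self.case_insensitive else 0)
        return re.search(str(self), other, flags) is not None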

No doubt many variations on this theme are possible. This compares two objects on the basis of specified attributes:

class AnyEquivalent():
    # compares two objects on basis of specified attributes
    def __init__(self, compared_object, *attrs):
        self.compared_object = compared_object
        self.attrs = attrs
        
    def __eq__(self, other):
        equal_objects = True
        for attr in self.attrs:
            if hasattr(other, attr):
                if getattr(self.compared_object, attr) != getattr(other, attr):
                    equal_objects = False
                    break
            else:
                equal_objects = False
                break
        return equal_objects

For example, this fails even when the file is correct (slightly confusingly, because the error message shows identical str(f) output for both values). The reason is that the two file objects are different instances, and file objects compare equal only to themselves:

f = open(FILENAME, 'w')
mock_run.assert_called_once_with(['pip', 'freeze'], stdout=f)

But this passes (explicitly comparing only on the basis of the values of the specified 3 attributes):

f = open(FILENAME, 'w')
mock_run.assert_called_once_with(['pip', 'freeze'], stdout=AnyEquivalent(f, 'name', 'mode', 'encoding'))
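The same check can be tried without touching the file system. This sketch uses a more compact, equivalent form of AnyEquivalent and types.SimpleNamespace objects as hypothetical stand-ins for the two file objects; they agree on name and mode but differ in another attribute:

```python
from types import SimpleNamespace
from unittest.mock import Mock

class AnyEquivalent:
    """Compares equal to any object whose listed attributes match compared_object's."""
    def __init__(self, compared_object, *attrs):
        self.compared_object = compared_object
        self.attrs = attrs

    def __eq__(self, other):
        return all(
            hasattr(other, attr)
            and getattr(self.compared_object, attr) == getattr(other, attr)
            for attr in self.attrs
        )

# Two distinct objects that agree on name and mode but not on position.
actual = SimpleNamespace(name="out.txt", mode="w", position=10)
expected = SimpleNamespace(name="out.txt", mode="w", position=0)

run = Mock()
run(["pip", "freeze"], stdout=actual)

# Passes: only name and mode are compared.
run.assert_called_once_with(["pip", "freeze"], stdout=AnyEquivalent(expected, "name", "mode"))
```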

Answered By: mike rodent