I'm trying to assert that a post_save signal receiver is called when an instance of my Client model is saved.
The signal receiver looks as follows:
# reports/signals.py
#receiver(post_save, sender=Client)
def create_client_draft(sender, instance=None, created=False, **kwargs):
"""Guarantees a DraftSchedule exists for each Client post save"""
print('called') # Log to stdout when called
if created and not kwargs.get('raw', False):
DraftSchedule.objects.get_or_create(client=instance)
I've set up a test that looks like this
#pytest.mark.django_db
#patch('reports.signals.create_client_draft')
def test_auto_create_draftschedule_on_client_creation(mock_signal):
client = mixer.blend(Client) # Creates a Client with random data
assert mock_signal.call_count == 1
I would expect this test to pass since the called print statement appears in captured stdout when the test is run.
However, the test runner seems to think my mock function was never called at all.
mock_signal = <MagicMock name='create_client_draft' id='139903470431088'>
#pytest.mark.django_db
#patch('reports.signals.create_client_draft')
def test_auto_create_draftschedule_on_client_creation(mock_signal):
client = mixer.blend(Client)
> assert mock_signal.call_count == 1
E AssertionError: assert 0 == 1
E + where 0 = <MagicMock name='create_client_draft' id='139903470431088'>.call_count
reports/tests/test_signals.py:36: AssertionError
---------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------------------------------------------------------------------
called
The print statement seems to suggest that the function was called during the test, whereas the test assertion suggests otherwise. Am I missing something obvious here with the mocking library?
Patching mock objects only works for callers that look up the method at run time. Signal handlers are held in a table, so they don't look up your mocked version.
It's a bit hacky, but you could have your signal handler call a helper function. Then the helper function could be mocked.
# reports/signals.py
@receiver(post_save, sender=Client)
def create_client_draft_handler(sender, instance=None, created=False, **kwargs):
    """Receive ``post_save`` for ``Client`` and delegate to the helper.

    The signal machinery holds its own reference to this function, so
    patching this name does not change what gets called.  The helper is
    looked up by name on every call, which is what makes it patchable.
    """
    create_client_draft(sender, instance, created, **kwargs)


def create_client_draft(sender, instance=None, created=False, **kwargs):
    """Guarantees a DraftSchedule exists for each Client post save

    This function can be mocked, because it's called by name.
    """
    print('called')  # Log to stdout when called
    # Nothing is created for updates, or when the sender flags the save as raw.
    if created and not kwargs.get('raw', False):
        DraftSchedule.objects.get_or_create(client=instance)
Related
I have a module loaders with class "Loader" with class method "load". During test, I want to append some additional steps to "Loader.load" to account for test specific data post-processing, so essentially overriding it. How do I properly do that?
I tried creating a mock class which inherits Loader and use monkeypatch.setattr("loaders.Loader", mock_loader), but this only works when I run one single test, but not when I run all tests.
loaders.py
class Loader:
def load():
# do something
return data
test.py
from loaders import Loader
class MockLoader(Loader):
def load():
data = super().load()
# do something to data
return data
def test_loader_special1(monkeypatch):
monkeypatch.setattr("loaders.Loader", MockLoader)
#run test logic 1
def test_loader_special2(monkeypatch):
monkeypatch.setattr("loaders.Loader", MockLoader)
#run test logic 2
Use patch or patch.object, either via the built-in library unittest.mock or the external library pytest-mock:
patch() acts as a function decorator, class decorator or a context
manager. Inside the body of the function or with statement, the target
is patched with a new object.
Where it is explicitly documented that the patch is only applicable per test:
When the function/with statement exits the patch is undone.
The patch will be used to wrap around your real implementation so that you can perform any necessary steps before and/or after calling it.
loaders.py
class Loader:
    """Minimal loader used to demonstrate patching ``load``."""

    def load(self):
        """Announce the call on stdout and return the string ``"real"``."""
        print("Real load() called")
        return "real"
test_loaders.py
import pytest
from unittest.mock import patch
from loaders import Loader
@pytest.fixture
def mock_load(mocker):
    """Wrap ``Loader.load`` so that its result gets ``" and mock"`` appended.

    The real method is captured before patching and is still executed; the
    wrapper only post-processes the value it returns.
    """
    real_func = Loader.load

    def mock_func(self, *args, **kwargs):
        print("Mock load() called")
        data = real_func(self, *args, **kwargs)
        data += " and mock"
        return data

    # Option 1: Using pytest-mock + new
    mocker.patch.object(Loader, 'load', new=mock_func)

    # Alternative ways of doing Option 1. All would just work the same.
    #
    # Option 2: Using pytest-mock + side_effect
    #     mocker.patch.object(Loader, 'load', side_effect=mock_func, autospec=True)
    #
    # Option 3: Using unittest + new
    #     with patch.object(Loader, 'load', new=mock_func):
    #         yield
    #
    # Option 4: Using unittest + side_effect
    #     with patch.object(Loader, 'load', side_effect=mock_func, autospec=True):
    #         yield
def test_loader_special1(mock_load):
data = Loader().load()
print(f"{data=}")
assert data == "real and mock"
def test_loader_special2(mock_load):
data = Loader().load()
print(f"{data=}")
assert data == "real and mock"
def test_loader_special3():
data = Loader().load()
print(f"{data=}")
assert data == "real"
def test_loader_special4(mock_load):
data = Loader().load()
print(f"{data=}")
assert data == "real and mock"
Output:
$ pytest test_loaders.py -q -rP
.... [100%]
================================================================================================= PASSES ==================================================================================================
__________________________________________________________________________________________ test_loader_special1 ___________________________________________________________________________________________
------------------------------------------------------------------------------------------ Captured stdout call -------------------------------------------------------------------------------------------
Mock load() called
Real load() called
data='real and mock'
__________________________________________________________________________________________ test_loader_special2 ___________________________________________________________________________________________
------------------------------------------------------------------------------------------ Captured stdout call -------------------------------------------------------------------------------------------
Mock load() called
Real load() called
data='real and mock'
__________________________________________________________________________________________ test_loader_special3 ___________________________________________________________________________________________
------------------------------------------------------------------------------------------ Captured stdout call -------------------------------------------------------------------------------------------
Real load() called
data='real'
__________________________________________________________________________________________ test_loader_special4 ___________________________________________________________________________________________
------------------------------------------------------------------------------------------ Captured stdout call -------------------------------------------------------------------------------------------
Mock load() called
Real load() called
data='real and mock'
4 passed in 0.01s
Modifying the answer https://stackoverflow.com/a/68675071/14536215 to use monkeypatch:
@pytest.fixture
def mock_load(monkeypatch):
    """Replace ``Loader.load`` with a wrapper that appends ``" and mock"``.

    The original method is captured first and still runs inside the wrapper.
    """
    real_func = Loader.load

    def mock_func(self, *args, **kwargs):
        print("Mock load() called")
        data = real_func(self, *args, **kwargs)
        data += " and mock"
        return data

    monkeypatch.setattr(Loader, "load", mock_func)
You can then mark the fixture to be loaded for all tests in the testing module with:
#pytest.fixture(autouse=True)
or mark the tests to use specified fixtures with:
#pytest.mark.usefixtures("mock_load")
def test_loader_special1():
...
Edit:
If you want to mock the whole class, you need to mock it before you import the class for it to take effect, or you could just import the module so you don't have the import statements sprinkled around:
import loaders
class MockLoader(loaders.Loader):
def load(self):
data = super().load()
data += " and mock"
return data
def test_loader_special1(monkeypatch):
monkeypatch.setattr("loaders.Loader", MockLoader)
data = loaders.Loader().load()
print(f"{data=}")
assert data == "real and mock"
I have a class that handles the API calls to a server. Certain methods within the class require the user to be logged in. Since it is possible for the session to run out, I need some functionality that re-logins the user once the session timed out. My idea was to use a decorator. If I try it like this
class Outer_Class():
class login_required():
def __init__(self, decorated_func):
self.decorated_func = decorated_func
def __call__(self, *args, **kwargs):
try:
response = self.decorated_func(*args, **kwargs)
except:
print('Session probably timed out. Logging in again ...')
args[0]._login()
response = self.decorated_func(*args, **kwargs)
return response
def __init__(self):
self.logged_in = False
self.url = 'something'
self._login()
def _login(self):
print(f'Logging in on {self.url}!')
self.logged_in = True
#this method requires the user to be logged in
#login_required
def do_something(self, param_1):
print('Doing something important with param_1')
if (): #..this fails
raise Exception()
I get an error. AttributeError: 'str' object has no attribute '_login'
Why do I not get a reference to the Outer_Class-instance handed over via *args? Is there another way to get a reference to the instance?
Found this answer How to get instance given a method of the instance? , but the decorated_function doesn't seem to have a reference to its own instance.
It works fine when I'm using a decorator function outside of the class. This solves the problem, but I'd like to know if it is possible to solve it this way.
The problem is that the magic of passing the object as the first hidden parameter only works for a non static method. As your decorator returns a custom callable object which is not a function, it never receives the calling object which is just lost in the call. So when you try to call the decorated function, you only pass it param_1 in the position of self. You get a first exception do_something() missing 1 required positional argument: 'param_1', fall into the except block and get your error.
You can still tie the decorator to the class, but it must be a function to have self magic work:
class Outer_Class():
    def login_required(decorated_func):
        """Decorate a method so it is retried once after logging in again.

        The wrapper is a plain function, so it is bound like any other method
        and receives the instance as ``self``.  If the wrapped method raises,
        ``self._login()`` is called and the method is invoked a second time;
        an exception from that second attempt propagates to the caller.
        """
        import functools  # local import keeps the snippet self-contained

        @functools.wraps(decorated_func)
        def inner(self, *args, **kwargs):
            print("decorated called")
            try:
                response = decorated_func(self, *args, **kwargs)
            except Exception:
                # Exception (not a bare except) so that KeyboardInterrupt and
                # SystemExit are never mistaken for an expired session.
                print('Session probably timed out. Logging in again ...')
                self._login()
                response = decorated_func(self, *args, **kwargs)
            return response
        return inner

    ...

    #this method requires the user to be logged in
    @login_required
    def do_something(self, param_1):
        print('Doing something important with param_1', param_1)
        if (False): #..this fails
            raise Exception()
You can then successfully do:
>>> a = Outer_Class()
Logging in on something!
>>> a.do_something("foo")
decorated called
Doing something important with param_1 foo
You have the command of
args[0]._login()
in the except. Since args[0] is a string and it doesn't have a _login method, you get the error message mentioned in the question.
I have a signal_handler connected through a decorator, something like this very simple one:
#receiver(post_save, sender=User,
dispatch_uid='myfile.signal_handler_post_save_user')
def signal_handler_post_save_user(sender, *args, **kwargs):
# do stuff
What I want to do is to mock it with the mock library http://www.voidspace.org.uk/python/mock/ in a test, to check how many times django calls it. My code at the moment is something like:
def test_cache():
with mock.patch('myapp.myfile.signal_handler_post_save_user') as mocked_handler:
# do stuff that will call the post_save of User
self.assert_equal(mocked_handler.call_count, 1)
The problem here is that the original signal handler is called even if mocked, most likely because the @receiver decorator is storing a copy of the signal handler somewhere, so I'm mocking the wrong code.
So the question: how do I mock my signal handler to make my test work?
Note that if I change my signal handler to:
def _support_function(*args, **kwargs):
# do stuff
#receiver(post_save, sender=User,
dispatch_uid='myfile.signal_handler_post_save_user')
def signal_handler_post_save_user(sender, *args, **kwargs):
_support_function(*args, **kwargs)
and I mock _support_function instead, everything works as expected.
Possibly a better idea is to mock out the functionality inside the signal handler rather than the handler itself. Using the OP's code:
#receiver(post_save, sender=User, dispatch_uid='myfile.signal_handler_post_save_user')
def signal_handler_post_save_user(sender, *args, **kwargs):
do_stuff() # <-- mock this
def do_stuff():
... do stuff in here
Then mock do_stuff:
with mock.patch('myapp.myfile.do_stuff') as mocked_handler:
self.assert_equal(mocked_handler.call_count, 1)
So, I ended up with a kind-of solution: mocking a signal handler simply means to connect the mock itself to the signal, so this exactly is what I did:
def test_cache():
with mock.patch('myapp.myfile.signal_handler_post_save_user', autospec=True) as mocked_handler:
post_save.connect(mocked_handler, sender=User, dispatch_uid='test_cache_mocked_handler')
# do stuff that will call the post_save of User
self.assertEquals(mocked_handler.call_count, 1) # standard django
# self.assert_equal(mocked_handler.call_count, 1) # when using django-nose
Notice that autospec=True in mock.patch is required in order to make post_save.connect to correctly work on a MagicMock, otherwise django will raise some exceptions and the connection will fail.
You can mock a django signal by mocking the ModelSignal class at django.db.models.signals.py like this:
#patch("django.db.models.signals.ModelSignal.send")
def test_overwhelming(self, mocker_signal):
obj = Object()
That should do the trick. Note that this will mock ALL signals no matter which object you are using.
If by any chance you use the mocker library instead, it can be done like this:
from mocker import Mocker, ARGS, KWARGS
def test_overwhelming(self):
mocker = Mocker()
# mock the post save signal
msave = mocker.replace("django.db.models.signals")
msave.post_save.send(KWARGS)
mocker.count(0, None)
with mocker:
obj = Object()
It's more lines but it works pretty well too :)
Take a look at mock_django. It has support for signals.
https://github.com/dcramer/mock-django/blob/master/tests/mock_django/signals/tests.py
In django 1.9 you can mock all receivers with something like this
# replace actual receivers with mocks
mocked_receivers = []
for i, receiver in enumerate(your_signal.receivers):
mock_receiver = Mock()
your_signal.receivers[i] = (receiver[0], mock_receiver)
mocked_receivers.append(mock_receiver)
... # whatever your test does
# ensure that mocked receivers have been called as expected
for mocked_receiver in mocked_receivers:
assert mocked_receiver.call_count == 1
mocked_receiver.assert_called_with(*your_args, sender="your_sender", signal=your_signal, **your_kwargs)
This replaces all receivers with mocks, eg ones you've registered, ones pluggable apps have registered and ones that django itself has registered. Don't be surprised if you use this on post_save and things start breaking.
You may want to inspect the receiver to determine if you actually want to mock it.
There is a way to mock django signals with a small class.
You should keep in mind that this would only mock the function as a django signal handler and not the original function; for example, if an m2m_changed signal triggers a call to a function that calls your handler directly, mock.call_count would not be incremented. You would need a separate mock to keep track of those calls.
Here is the class in question:
class LocalDjangoSignalsMock():
    """Context manager that swaps registered signal receivers for MagicMocks.

    Only the entries in each signal's ``receivers`` list are replaced, so a
    handler that is called directly (not through the signal) still runs the
    real function and is not counted by its mock.
    """

    def __init__(self, to_mock, django_signals=None):
        """
        Replaces registered django signals with MagicMocks

        :param to_mock: list of signal handlers to mock
        :param django_signals: signals whose ``receivers`` lists are patched;
            defaults to ``signals.post_save`` and ``signals.m2m_changed``
        """
        self.mocks = {handler: MagicMock() for handler in to_mock}
        self.reverse_mocks = {magicmock: mocked
                              for mocked, magicmock in self.mocks.items()}
        if django_signals is None:
            django_signals = [signals.post_save, signals.m2m_changed]
        self._signals = list(django_signals)
        self.registered_receivers = [signal.receivers
                                     for signal in self._signals]
        # mock -> the receiver reference it displaced, so that the exact
        # original object (weak or strong) can be put back on exit.
        self._displaced = {}

    @staticmethod
    def _resolve(reference):
        """Return the callable behind a receiver entry (None if it is gone).

        Weak references are dereferenced; anything else is taken to be the
        receiver itself and is returned without being called.
        """
        if isinstance(reference, weakref.ReferenceType):
            return reference()
        return reference

    def _invalidate_caches(self):
        """Empty each signal's ``sender_receivers_cache`` when it has one.

        NOTE(review): assumes this attribute is the per-sender receiver cache
        that Django's ``Signal`` consults before reading ``receivers`` --
        confirm against the Django version in use.
        """
        for signal in self._signals:
            cache = getattr(signal, 'sender_receivers_cache', None)
            if cache is not None:
                cache.clear()

    def _apply_mocks(self):
        """Replace every receiver listed in ``to_mock`` with its MagicMock."""
        for receivers in self.registered_receivers:
            for index, entry in enumerate(receivers):
                reference = entry[1]
                target = self._resolve(reference)
                if target in self.mocks:
                    replacement = self.mocks[target]
                    self._displaced[replacement] = reference
                    # Keep any fields beyond (key, receiver) untouched.
                    receivers[index] = (
                        (entry[0], replacement) + tuple(entry[2:]))

    def _reverse_mocks(self):
        """Put the original receivers back in place of this object's mocks."""
        for receivers in self.registered_receivers:
            for index, entry in enumerate(receivers):
                current = entry[1]
                if not isinstance(current, MagicMock):
                    continue
                if current not in self.reverse_mocks:
                    # A mock installed by someone else; leave it alone.
                    continue
                original = self._displaced.get(current)
                if original is None:
                    original = weakref.ref(self.reverse_mocks[current])
                receivers[index] = (
                    (entry[0], original) + tuple(entry[2:]))

    def __enter__(self):
        self._apply_mocks()
        self._invalidate_caches()
        return self.mocks

    def __exit__(self, *args):
        self._reverse_mocks()
        self._invalidate_caches()
Example usage
to_mock = [my_handler]
with LocalDjangoSignalsMock(to_mock) as mocks:
my_trigger()
for mocked in to_mock:
assert(mocks[mocked].call_count)
# 'function {0} was called {1}'.format(
# mocked, mocked.call_count)
As you mentioned,
mock.patch('myapp.myfile._support_function') is correct but mock.patch('myapp.myfile.signal_handler_post_save_user') is wrong.
I think the reason is:
When your tests are initialised, some file imports the Python module that implements the signal handler, and the @receiver decorator then creates a new signal connection.
In the test, mock.patch('myapp.myfile._support_function') will create another signal connection, so the original signal handler is called even if mocked.
Try to disconnect the signal connection before mock.patch('myapp.myfile._support_function'), like
post_save.disconnect(signal_handler_post_save_user)
with mock.patch("review.signals. signal_handler_post_save_user", autospec=True) as handler:
#do stuff
I'm creating a task (by subclassing celery.task.Task) that creates a connection to Twitter's streaming API. For the Twitter API calls, I am using tweepy. As I've read from the celery-documentation, 'a task is not instantiated for every request, but is registered in the task registry as a global instance.' I was expecting that whenever I call apply_async (or delay) for the task, I will be accessing the task that was originally instantiated but that doesn't happen. Instead, a new instance of the custom task class is created. I need to be able to access the original custom task since this is the only way I can terminate the original connection created by the tweepy API call.
Here's some piece of code if this would help:
from celery import registry
from celery.task import Task
class FollowAllTwitterIDs(Task):
def __init__(self):
# requirements for creation of the customstream
# goes here. The CustomStream class is a subclass
# of tweepy.streaming.Stream class
self._customstream = CustomStream(*args, **kwargs)
#property
def customstream(self):
if self._customstream:
# terminate existing connection to Twitter
self._customstream.running = False
self._customstream = CustomStream(*args, **kwargs)
def run(self):
self._to_follow_ids = function_that_gets_list_of_ids_to_be_followed()
self.customstream.filter(follow=self._to_follow_ids, async=False)
follow_all_twitterids = registry.tasks[FollowAllTwitterIDs.name]
And for the Django view
def connect_to_twitter(request):
if request.method == 'POST':
do_stuff_here()
.
.
.
follow_all_twitterids.apply_async(args=[], kwargs={})
return
Any help would be appreciated. :D
EDIT:
For additional context for the question, the CustomStream object creates an httplib.HTTPSConnection instance whenever the filter() method is called. This connection needs to be closed whenever there is another attempt to create one. The connection is closed by setting customstream.running to False.
The task should only be instantiated once, if you think it is not for some reason,
I suggest you add a
print("INSTANTIATE")
import traceback
traceback.print_stack()
to the Task.__init__ method, so you could tell where this would be happening.
I think your task could be better expressed like this:
from celery.task import Task, task
class TwitterTask(Task):
    """Abstract base task that owns one lazily created ``CustomStream``.

    The stream is created on first access to ``stream`` and is flagged as
    no longer running once the task body has finished, whether it returned
    or raised.
    """
    _stream = None
    abstract = True

    def __call__(self, *args, **kwargs):
        try:
            # The task body reaches the stream through ``self.stream``; it is
            # not passed as an argument (the task function takes none).
            return super(TwitterTask, self).__call__(*args, **kwargs)
        finally:
            if self._stream:
                self._stream.running = False

    @property
    def stream(self):
        """Return the task's stream, creating it on first use."""
        if self._stream is None:
            self._stream = CustomStream()
        return self._stream


@task(base=TwitterTask)
def follow_all_ids():
    """Follow every id returned by ``get_list_of_ids_to_follow()``."""
    ids = get_list_of_ids_to_follow()
    # ``async`` is a reserved word since Python 3.7, so the explicit
    # ``async=False`` is dropped.
    # NOTE(review): assumes blocking is filter()'s default -- confirm for the
    # installed tweepy version.
    follow_all_ids.stream.filter(follow=ids)
I've been using testbed, webtest, and nose to test my Python GAE app, and it is a great setup. I'm now implementing something similar to Nick's great example of using the deferred library, but I can't figure out a good way to test the parts of the code triggered by DeadlineExceededError.
Since this is in the context of a taskqueue, it would be painful to construct a test that took more than 10 minutes to run. Is there a way to temporarily set the taskqueue time limit to a few seconds for the purpose of testing? Or perhaps some other way to elegantly test the execution of code in the except DeadlineExceededError block?
Abstract the "GAE context" for your code. In production, provide the real "GAE implementation"; for testing, provide a mock one that will raise the DeadlineExceededError. The test should not depend on any timeout and should be fast.
Sample abstraction (just glue):
class AbstractGAETaskContext(object):
    """Seam between task code and the GAE runtime (just glue).

    Every method is a no-op here; a test implementation overrides them.
    """

    def task_expired(self):
        """Hook for the task to call while it runs; does nothing here."""
        pass  # this will throw exception in mock impl

    # Original spelling, kept so existing callers keep working.
    task_spired = task_expired

    # here you define any method that you call into GAE, to be mocked
    def defered(self, *args, **kwargs):
        """Placeholder for a call into GAE; accepts anything, does nothing."""
        pass
If you don't like abstraction, you can do monkey patching for testing only, also you need to define the task_expired function to be your hook for testing.
task_expired should be called during your task implementation function.
*UPDATED* This is the 3rd solution:
First I want to mention that Nick's sample implementation is not so great: the Mapper class has too many responsibilities (deferring, querying data, updating in batch), and this makes the test hard to write, since a lot of mocks need to be defined. So I extracted the deferring responsibility into a separate class. You only want to test the deferring mechanism; what actually happens (the update, query, etc.) should be handled in other tests.
Here is the deferring class; it no longer depends on GAE:
class DeferredCall(object):
    """Run a long operation and re-schedule it for as long as it times out."""

    def __init__(self, deferred):
        """Store the scheduler.

        :param deferred: callable invoked as
            ``deferred(callback, *args, **kwargs)`` to schedule a later call
        """
        self.deferred = deferred

    def run(self, long_execution_call, context, *args, **kwargs):
        """Call ``long_execution_call`` once and re-schedule if it timed out.

        :param long_execution_call: called as
            ``long_execution_call(context, *args, **kwargs)``; must return a
            ``(next_context, timed_out)`` tuple saying whether the operation
            was cut short and the context where it was abandoned
        :param context: where to start (or resume) the operation
        """
        next_context, timeouted = long_execution_call(context, *args, **kwargs)
        if timeouted:
            # Pass the operation itself along too: the scheduled call lands in
            # run(), which expects it as its first argument.
            self.deferred(
                self.run, long_execution_call, next_context, *args, **kwargs)
Here is the test module:
class Test(unittest.TestCase):
    """Checks when ``DeferredCall.run`` schedules a follow-up call."""

    def test_defer(self):
        """An interrupted operation schedules exactly one deferred call."""
        calls = []

        def mock_deferrer(callback, *args, **kwargs):
            calls.append((callback, args, kwargs))

        # Plain function: run() calls it with the context only.
        def interrupted(context):
            return "new_context", True

        d = DeferredCall(mock_deferrer)
        d.run(interrupted, "init_context")
        self.assertEqual(1, len(calls), 'a deferred call should be')

    def test_no_defer(self):
        """A completed operation schedules nothing."""
        calls = []

        def mock_deferrer(callback, *args, **kwargs):
            calls.append((callback, args, kwargs))

        def completed(context):
            return None, False

        d = DeferredCall(mock_deferrer)
        d.run(completed, "init_context")
        self.assertEqual(0, len(calls), 'no deferred call should be')
Here is how Nick's Mapper implementation will look:
class Mapper:
...
def _continue(self, start_key, batch_size):
... # here is same code, nothing was changed
except DeadlineExceededError:
# Write any unfinished updates to the datastore.
self._batch_write()
# Queue a new task to pick up where we left off.
##deferred.defer(self._continue, start_key, batch_size)
return start_key, True ## make compatible with DeferredCall
self.finish()
return None, False ## make it comaptible with DeferredCall
runner = _continue
Code where you register the long running task; this only depend on the GAE deferred lib.
import DeferredCall
import PersonMapper # this inherits the Mapper
from google.appengine.ext import deferred
mapper = PersonMapper()
DeferredCall(deferred).run(mapper.run)