Serializing an object in __main__ with pickle or dill - python

I have a pickling problem. I want to serialize a function in my main script, then load it and run it in another script. To demonstrate this, I've made 2 scripts:
Attempt 1: The naive way:
dill_pickle_script_1.py
import pickle
import time
def my_func(a, b):
    time.sleep(0.1)  # The purpose of this will become evident at the end
    return a+b
if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        pickle.dump(my_func, f)
dill_pickle_script_2.py
import pickle
if __name__ == '__main__':
    with open('testfile.pkl', 'rb') as f:
        func = pickle.load(f)
    assert func(1, 2)==3
Problem: when I run script 2, I get AttributeError: 'module' object has no attribute 'my_func'. I understand why: when my_func is serialized in script 1, it belongs to the __main__ module, and dill_pickle_script_2 can't know that __main__ there referred to the namespace of dill_pickle_script_1, so it cannot find the reference.
Attempt 2: Inserting an absolute import
I fix the problem by adding a little hack - I add an absolute import to my_func in dill_pickle_script_1 before pickling it.
dill_pickle_script_1.py
import pickle
import time
def my_func(a, b):
    time.sleep(0.1)
    return a+b
if __name__ == '__main__':
    from dill_pickle_script_1 import my_func  # Added absolute import
    with open('testfile.pkl', 'wb') as f:
        pickle.dump(my_func, f)
Now it works! However, I'd like to avoid having to do this hack every time. (Also, I want my pickling to be done inside some other module, which wouldn't know which module my_func came from.)
Attempt 3: Dill
I heard that the package dill lets you serialize things in main and load them elsewhere, so I tried that:
dill_pickle_script_1.py
import dill
import time
def my_func(a, b):
    time.sleep(0.1)
    return a+b
if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        dill.dump(my_func, f)
dill_pickle_script_2.py
import dill
if __name__ == '__main__':
    with open('testfile.pkl', 'rb') as f:
        func = dill.load(f)
    assert func(1, 2)==3
Now, however, I have another problem: When running dill_pickle_script_2.py, I get a NameError: global name 'time' is not defined. It seems that dill did not realize that my_func referenced the time module and has to import it on load.
My question
How can I serialize an object in main, and load it again in another script so that all the imports used by that object are also loaded, without doing the nasty little hack in Attempt 2?

Well, I found a solution. It is a horrible but tidy kludge and not guaranteed to work in all cases. Any suggestions for improvement are welcome. The solution involves replacing the main reference with an absolute module reference in the pickle string, using the following helper functions:
import os
import pickle
import sys
def pickle_dumps_without_main_refs(obj):
    """
    Yeah this is horrible, but it allows you to pickle an object in the main module so that it can be reloaded in another
    module.
    :param obj: The object to pickle.
    :return: The pickled bytes, with __main__ replaced by the script's import path.
    """
    currently_run_file = sys.argv[0]
    module_path = file_path_to_absolute_module(currently_run_file)
    # Protocol 0 is a text format without length prefixes, so a byte-level replace
    # keeps the stream valid; newer protocols length-prefix strings and could be corrupted.
    pickle_str = pickle.dumps(obj, protocol=0)
    pickle_str = pickle_str.replace(b'__main__', module_path.encode())  # Hack!
    return pickle_str
def pickle_dump_without_main_refs(obj, file_obj):
    string = pickle_dumps_without_main_refs(obj)
    file_obj.write(string)
def file_path_to_absolute_module(file_path):
    """
    Given a file path, return an import path.
    :param file_path: A file path.
    :return: The dotted module path, including any enclosing packages.
    """
    assert os.path.exists(file_path)
    file_loc, ext = os.path.splitext(file_path)
    assert ext in ('.py', '.pyc')
    directory, module = os.path.split(file_loc)
    module_path = [module]
    while True:
        if os.path.exists(os.path.join(directory, '__init__.py')):
            directory, package = os.path.split(directory)
            module_path.append(package)
        else:
            break
    path = '.'.join(module_path[::-1])
    return path
Now, I can simply change dill_pickle_script_1.py to say
import time
from artemis.remote.child_processes import pickle_dump_without_main_refs
def my_func(a, b):
    time.sleep(0.1)
    return a+b
if __name__ == '__main__':
    with open('testfile.pkl', 'wb') as f:
        pickle_dump_without_main_refs(my_func, f)
And then dill_pickle_script_2.py works!

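You can use dill.dump with recurse=True, or set dill.settings["recurse"] = True. With recursion enabled, dill traces only the globals the function refers to (here, the time module) and pickles them along with the function: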
In file A:
import time
import dill
def my_func(a, b):
    time.sleep(0.1)
    return a + b
with open("tmp.pkl", "wb") as f:
    dill.dump(my_func, f, recurse=True)
In file B:
import dill
with open("tmp.pkl", "rb") as f:
    my_func = dill.load(f)

Here's another solution that modifies the serialization so that it will deserialize without any special measures. You could argue it is less hacky than Peter's solution.
Instead of hacking the output from pickle.dumps(), this subclasses Pickler to modify the way it pickles objects that refer back to __main__. This does mean that the fast (C implementation) pickler can't be used, so there is a performance penalty with this method. It also overrides the save_pers() method of Pickler, which isn't intended to be overridden. So this could break in a future version of Python (unlikely though).
import importlib
import inspect
import io
import pickle
import sys
from pathlib import Path
def get_function_module_str(func):
    """Returns a dotted module string suitable for importlib.import_module() from a
    function reference.
    """
    source_file = Path(inspect.getsourcefile(func))
    # (Doesn't work with built-in functions)
    if not source_file.is_absolute():
        rel_path = source_file
    else:
        # It's an absolute path so find the longest entry in sys.path that shares a
        # common prefix and remove the prefix.
        for path_str in sorted(sys.path, key=len, reverse=True):
            try:
                rel_path = source_file.relative_to(Path(path_str))
                break
            except ValueError:
                pass
        else:
            raise ValueError(f"{source_file!r} is not on the Python path")
    # Replace path separators with dots.
    modules_str = ".".join(p for p in rel_path.with_suffix("").parts if p != "__init__")
    return modules_str, func.__name__
class ResolveMainPickler(pickle._Pickler):
    """Subclass of Pickler that replaces __main__ references with the actual module
    name."""
    def persistent_id(self, obj):
        """Override to see if this object is defined in "__main__" and if so to replace
        __main__ with the actual module name."""
        if getattr(obj, "__module__", None) == "__main__":
            module_str, obj_name = get_function_module_str(obj)
            obj_ref = getattr(importlib.import_module(module_str), obj_name)
            return obj_ref
        return None
    def save_pers(self, pid):
        """Override the function to save a persistent ID so that it saves it as a
        normal reference. So it can be unpickled with no special arrangements.
        """
        self.save(pid, save_persistent_id=False)
with io.BytesIO() as pickled:
    pickler = ResolveMainPickler(pickled)
    pickler.dump(obj)  # obj: the object you want to pickle
    print(pickled.getvalue())
If you already know the name of the __main__ module then you could dispense with get_function_module_str() and just supply the name directly.
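On Python 3.8 and later, a similar result can probably be had without subclassing the private pickle._Pickler or overriding save_pers(): the public Pickler.reducer_override() hook lets the regular (C-accelerated) pickler replace a top-level __main__ function or class with a reduction that, on load, evaluates getattr(importlib.import_module(name), qualname). This is a swapped-in technique, not part of the answer above. MainResolvingPickler, dumps_resolving_main and the way the script's module name is guessed are hypothetical and assumptions; it only handles top-level functions and classes, and the loading side must be able to import the script under that name (for example because its directory is on sys.path). Importing it runs its top-level code, but not its if __name__ == '__main__': block.

```python
import importlib
import io
import pickle
import sys
import types
from pathlib import Path


def _guess_main_module_name():
    """Guess the importable name of the running script (an assumption)."""
    main = sys.modules["__main__"]
    spec = getattr(main, "__spec__", None)
    if spec is not None:               # started as `python -m package.module`
        return spec.name
    return Path(main.__file__).stem    # started as `python script.py`


class _ModuleRef:
    """Placeholder that is pickled as a call to importlib.import_module(name)."""

    def __init__(self, name):
        self.name = name


class MainResolvingPickler(pickle.Pickler):
    """Stores top-level functions/classes from __main__ under a real module name."""

    def __init__(self, file, main_module_name=None, **kwargs):
        super().__init__(file, **kwargs)
        self.main_module_name = main_module_name

    def reducer_override(self, obj):
        if (isinstance(obj, (types.FunctionType, type))
                and getattr(obj, "__module__", None) == "__main__"
                and obj.__qualname__.isidentifier()):  # skip lambdas, nested, methods
            name = self.main_module_name or _guess_main_module_name()
            # On load this becomes getattr(importlib.import_module(name), qualname).
            return getattr, (_ModuleRef(name), obj.__qualname__)
        if isinstance(obj, _ModuleRef):
            return importlib.import_module, (obj.name,)
        return NotImplemented  # everything else is pickled as usual


def dumps_resolving_main(obj, main_module_name=None, protocol=None):
    buffer = io.BytesIO()
    MainResolvingPickler(buffer, main_module_name=main_module_name,
                         protocol=protocol).dump(obj)
    return buffer.getvalue()
```

With this, dill_pickle_script_1.py could write f.write(dumps_resolving_main(my_func)) inside its if __name__ == '__main__': block, and dill_pickle_script_2.py could read it back with plain pickle.load(f) from a file opened in 'rb' mode. Because the script is imported as a module on load, its import time line runs again, which sidesteps the NameError from Attempt 3.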

Related

RecursionError when attempting to mock built-in open in python [duplicate]

How do I test the following code with unittest.mock:
def testme(filepath):
    with open(filepath) as f:
        return f.read()
Python 3
Patch builtins.open and use mock_open, which is part of the mock framework. patch used as a context manager returns the object used to replace the patched one:
from unittest.mock import patch, mock_open
with patch("builtins.open", mock_open(read_data="data")) as mock_file:
    assert open("path/to/open").read() == "data"
    mock_file.assert_called_with("path/to/open")
If you want to use patch as a decorator, using mock_open()'s result as the new= argument to patch can be a little bit weird. Instead, use patch's new_callable= argument and remember that every extra argument that patch doesn't use will be passed to the new_callable function, as described in the patch documentation:
patch() takes arbitrary keyword arguments. These will be passed to the Mock (or new_callable) on construction.
@patch("builtins.open", new_callable=mock_open, read_data="data")
def test_patch(mock_file):
    assert open("path/to/open").read() == "data"
    mock_file.assert_called_with("path/to/open")
Remember that in this case patch will pass the mocked object as an argument to your test function.
Python 2
You need to patch __builtin__.open instead of builtins.open, and mock is not part of unittest, so you need to pip install it and import it separately:
from mock import patch, mock_open
with patch("__builtin__.open", mock_open(read_data="data")) as mock_file:
assert open("path/to/open").read() == "data"
mock_file.assert_called_with("path/to/open")
The way to do this changed in mock 0.7.0, which finally supports mocking the Python protocol methods (magic methods), in particular through MagicMock:
http://www.voidspace.org.uk/python/mock/magicmock.html
An example of mocking open as a context manager (from the examples page in the mock documentation):
>>> open_name = '%s.open' % __name__
>>> with patch(open_name, create=True) as mock_open:
...     mock_open.return_value = MagicMock(spec=file)
...
...     with open('/some/path', 'w') as f:
...         f.write('something')
...
<mock.Mock object at 0x...>
>>> file_handle = mock_open.return_value.__enter__.return_value
>>> file_handle.write.assert_called_with('something')
With the latest versions of mock, you can use the really useful mock_open helper:
mock_open(mock=None, read_data=None)
A helper function to create a mock to replace the use of open. It works for open called directly or used as a context manager.
The mock argument is the mock object to configure. If None (the default) then a MagicMock will be created for you, with the API limited to methods or attributes available on standard file handles.
read_data is a string for the read method of the file handle to return. This is an empty string by default.
>>> from mock import mock_open, patch
>>> m = mock_open()
>>> with patch('{}.open'.format(__name__), m, create=True):
...     with open('foo', 'w') as h:
...         h.write('some stuff')
>>> m.assert_called_once_with('foo', 'w')
>>> handle = m()
>>> handle.write.assert_called_once_with('some stuff')
To use mock_open for a simple file read() (the original mock_open snippet already given on this page is geared more for write):
import mock
my_text = "some text to return when read() is called on the file object"
mocked_open_function = mock.mock_open(read_data=my_text)
with mock.patch("__builtin__.open", mocked_open_function):
    with open("any_string") as f:
        print f.read()
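Note that, as per the docs for mock_open, this is specifically for read(), so it won't work with common patterns such as for line in f. (That applies to the mock version used here; since Python 3.8, unittest.mock's mock_open also supports iterating over the file handle.)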
Uses python 2.6.6 / mock 1.0.1
The top answer is useful but I expanded on it a bit.
If you want to set the value of your file object (the f in as f) based on the arguments passed to open() here's one way to do it:
def save_arg_return_data(*args, **kwargs):
    mm = MagicMock(spec=file)
    mm.__enter__.return_value = do_something_with_data(*args, **kwargs)
    return mm
m = MagicMock()
m.side_effect = save_arg_return_data
# if your open() call is in the file mymodule.animals
# use mymodule.animals as name_of_called_file
open_name = '%s.open' % name_of_called_file
with patch(open_name, m, create=True):
    pass  # do testing here
Basically, open() returns an object, and the with statement calls __enter__() on that object.
To mock properly, we must mock open() to return a mock object. That mock object should then mock the __enter__() call on it (MagicMock will do this for us) to return the mock data/file object we want (hence mm.__enter__.return_value). Doing this with 2 mocks the way above allows us to capture the arguments passed to open() and pass them to our do_something_with_data method.
I passed an entire mock file as a string to open() and my do_something_with_data looked like this:
def do_something_with_data(*args, **kwargs):
    return args[0].split("\n")
This transforms the string into a list so you can do the following as you would with a normal file:
for line in file:
    pass  # do action
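In Python 3 the same idea of choosing the file contents from the arguments passed to open() can be sketched with side_effect and one mock_open per path, without building the file mock by hand. open_with_contents and count_lines are hypothetical names, raising FileNotFoundError for unknown paths is a design choice rather than mock_open behaviour, and iterating over the handle needs Python 3.8 or later:

```python
from unittest.mock import mock_open, patch


def open_with_contents(files):
    """Return a side_effect for a patched open() that serves one fake file per path.

    files maps path -> text; unknown paths raise FileNotFoundError.
    """
    def fake_open(path, *args, **kwargs):
        if path not in files:
            raise FileNotFoundError(path)
        # A fresh mock_open per call: read(), readlines() and iteration all
        # yield files[path], and the handle also works in a with block.
        return mock_open(read_data=files[path])(path, *args, **kwargs)
    return fake_open


def count_lines(path):
    with open(path) as f:
        return sum(1 for _ in f)
```

For example, inside with patch("builtins.open", side_effect=open_with_contents({"a.txt": "one\ntwo\n"})):, calling count_lines("a.txt") returns 2, and the patched mock still records every open() call for assertions.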
I might be a bit late to the game, but this worked for me when calling open in another module without having to create a new file.
test.py
import unittest
import __builtin__
from mock import Mock, patch, mock_open, call
from MyObj import MyObj
class TestObj(unittest.TestCase):
    open_ = mock_open()
    with patch.object(__builtin__, "open", open_):
        ref = MyObj()
        ref.save("myfile.txt")
    assert open_.call_args_list == [call("myfile.txt", "wb")]
MyObj.py
class MyObj(object):
    def save(self, filename):
        with open(filename, "wb") as f:
            f.write("sample text")
By patching the open function inside the __builtin__ module to my mock_open(), I can mock writing to a file without creating one.
Note: If you are using a module that uses cython, or your program depends on cython in any way, you will need to import cython's __builtin__ module by including import __builtin__ at the top of your file. You will not be able to mock the universal __builtin__ if you are using cython.
If you don't need the mocked file object inside the test, you can simply decorate the test method:
@patch('builtins.open', mock_open(read_data="data"))
def test_testme():
    result = testme("any/path")
    assert result == "data"
To patch the built-in open() function with unittest (this worked for me when patching the reading of a JSON config):
class ObjectUnderTest:
    def __init__(self, filename: str):
        with open(filename, 'r') as f:
            dict_content = json.load(f)
The mocked object is the io.TextIOWrapper object returned by the open() function:
@patch("<src.where.object.is.used>.open",
       return_value=io.TextIOWrapper(io.BufferedReader(io.BytesIO(b'{"test_key": "test_value"}'))))
def test_object_function_under_test(self, mocker):
I'm using pytest in my case, and the good news is that in Python 3 the unittest library can also be imported and used without issue.
Here is my approach. First, I create a conftest.py file with reusable pytest fixture(s):
from functools import cache
from unittest.mock import MagicMock, mock_open
import pytest
from pytest_mock import MockerFixture
class FileMock(MagicMock):
    def __init__(self, mocker: MagicMock = None, **kwargs):
        super().__init__(**kwargs)
        if mocker:
            self.__dict__ = mocker.__dict__
            # configure mock object to replace the use of open(...)
            # note: this is useful in scenarios where data is written out
            _ = mock_open(mock=self)
    @property
    def read_data(self):
        return self.side_effect
    @read_data.setter
    def read_data(self, mock_data: str):
        """set mock data to be returned when `open(...).read()` is called."""
        self.side_effect = mock_open(read_data=mock_data)
    @property
    @cache
    def write_calls(self):
        """a list of calls made to `open().write(...)`"""
        handle = self.return_value
        write: MagicMock = handle.write
        return write.call_args_list
    @property
    def write_lines(self) -> str:
        """a list of written lines (as a string)"""
        return ''.join([c[0][0] for c in self.write_calls])
@pytest.fixture
def mock_file_open(mocker: MockerFixture) -> FileMock:
    return FileMock(mocker.patch('builtins.open'))
I decided to make read_data a property in order to be more Pythonic. It can be assigned in a test function with whatever data open() needs to return.
In my test file, named something like test_it_works.py, I have the following test cases to confirm the intended functionality:
from unittest.mock import call
def test_mock_file_open_and_read(mock_file_open):
    mock_file_open.read_data = 'hello\nworld!'
    with open('/my/file/here', 'r') as in_file:
        assert in_file.readlines() == ['hello\n', 'world!']
    mock_file_open.assert_called_with('/my/file/here', 'r')
def test_mock_file_open_and_write(mock_file_open):
    with open('/out/file/here', 'w') as f:
        f.write('hello\n')
        f.write('world!\n')
        f.write('--> testing 123 :-)')
    mock_file_open.assert_called_with('/out/file/here', 'w')
    assert call('world!\n') in mock_file_open.write_calls
    assert mock_file_open.write_lines == """\
hello
world!
--> testing 123 :-)
""".rstrip()
Check out the gist here.
Sourced from a GitHub snippet that patches read and write functionality in Python.
The source link is over here
import configparser
import pytest
simpleconfig = """[section]\nkey = value\n\n"""
def test_monkeypatch_open_read(mockopen):
    filename = 'somefile.txt'
    mockopen.write(filename, simpleconfig)
    parser = configparser.ConfigParser()
    parser.read(filename)
    assert parser.sections() == ['section']
def test_monkeypatch_open_write(mockopen):
    parser = configparser.ConfigParser()
    parser.add_section('section')
    parser.set('section', 'key', 'value')
    filename = 'somefile.txt'
    parser.write(open(filename, 'wb'))
    assert mockopen.read(filename) == simpleconfig
Simple @patch with assert
If you want to use @patch: here open() is called inside the handler and the file is read.
@patch("builtins.open", new_callable=mock_open, read_data="data")
def test_lambda_handler(self, mock_open_file):
    lambda_handler(event, {})

Keeping track of when Python modules are imported

Does the interpreter somehow keep a timestamp of when a module is imported? Or is there an easy way of hooking into the import machinery to do this?
The scenario is a long-running Python process that at various points imports user-provided modules. I would like the process to be able to check "should I restart to load the latest code changes?" by checking the module file's timestamps against the time the module was imported.
Here's a way to automatically have an attribute (named _loadtime in the example code below) added to modules when they're imported. The code is based on Recipe 10.12 titled "Patching Modules on Import" in the book Python Cookbook, by David Beazley and Brian Jones, O'Reilly, 2013, which shows a technique that I adapted to do what you want.
For testing purposes I created this trivial target_module.py file:
print('in target_module')
Here's the example code:
import importlib
import sys
import time
class PostImportFinder:
    def __init__(self):
        self._skip = set()  # To prevent recursion.
    def find_module(self, fullname, path=None):
        if fullname in self._skip:  # Prevent recursion
            return None
        self._skip.add(fullname)
        return PostImportLoader(self)
class PostImportLoader:
    def __init__(self, finder):
        self._finder = finder
    def load_module(self, fullname):
        importlib.import_module(fullname)
        module = sys.modules[fullname]
        # Add a custom attribute to the module object.
        module._loadtime = time.time()
        self._finder._skip.remove(fullname)
        return module
sys.meta_path.insert(0, PostImportFinder())
if __name__ == '__main__':
    import time
    try:
        print('importing target_module')
        import target_module
    except Exception as e:
        print('Import failed:', e)
        raise
    loadtime = time.localtime(target_module._loadtime)
    print('module loadtime: {} ({})'.format(
        target_module._loadtime,
        time.strftime('%Y-%b-%d %H:%M:%S', loadtime)))
Sample output:
importing target_module
in target_module
module loadtime: 1604683023.2491636 (2020-Nov-06 09:17:03)
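The recipe above uses the legacy find_module()/load_module() finder protocol. The import system has preferred find_spec() since Python 3.4 and, as far as I know, stopped falling back to find_module() in Python 3.12, so on current versions a finder like the one above is silently skipped. Below is a sketch of the same idea against find_spec()/exec_module(). LoadTimeFinder and _TimestampLoader are hypothetical names: the finder asks the other meta path finders for the real spec and wraps its loader so that _loadtime is set after the module body has run. Wrapping loaders like this is an assumption that may confuse code which inspects module.__loader__ directly.

```python
import importlib.abc
import sys
import time


class _TimestampLoader(importlib.abc.Loader):
    """Wraps a real loader and stamps `_loadtime` after the module body has run."""

    def __init__(self, wrapped):
        self._wrapped = wrapped

    def create_module(self, spec):
        create = getattr(self._wrapped, "create_module", None)
        return create(spec) if create is not None else None  # None = default module

    def exec_module(self, module):
        self._wrapped.exec_module(module)
        module._loadtime = time.time()

    def __getattr__(self, name):
        # Everything else (get_source, is_package, ...) goes to the real loader.
        return getattr(self._wrapped, name)


class LoadTimeFinder(importlib.abc.MetaPathFinder):
    """Asks the other meta path finders for a spec and wraps its loader."""

    def find_spec(self, fullname, path, target=None):
        for finder in sys.meta_path:
            if finder is self or not hasattr(finder, "find_spec"):
                continue
            spec = finder.find_spec(fullname, path, target)
            if spec is not None:
                if hasattr(spec.loader, "exec_module"):
                    spec.loader = _TimestampLoader(spec.loader)
                return spec
        return None  # let the rest of the import machinery handle it


sys.meta_path.insert(0, LoadTimeFinder())
```

With this installed, import target_module followed by target_module._loadtime should behave as in the sample output above. importlib.reload() goes through find_spec() again, so it should refresh the timestamp as well.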
I don't think there's any way to get around how hacky this is, but how about something like this every time you import? (I don't know exactly how you're importing):
import importlib
import time
from types import ModuleType
# create a dictionary to keep track
# filter globals to exclude dunder names and anything that isn't a module
MODULE_TIMES = {k: None for k, v in globals().items() if not k.startswith("__") and not k.endswith("__") and type(v) == ModuleType}
for module_name in user_module_list:  # user_module_list: names of the user-provided modules
    MODULE_TIMES[module_name] = time.time()
    # eval() cannot execute an import statement, so use importlib instead
    importlib.import_module(module_name)
And then you can reference this dictionary in a similar way later.
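Whichever way the import times are recorded, the "should I restart?" check from the question then comes down to comparing each module file's modification time with its recorded import time. A minimal sketch, assuming a dict that maps module names to time.time() values (for example MODULE_TIMES above); modules_changed_since_import is a hypothetical helper:

```python
import os
import sys


def modules_changed_since_import(load_times):
    """Return the names of modules whose source file changed after they were imported.

    load_times maps module name -> time.time() recorded at import (an assumed
    format). Modules not in sys.modules, without a __file__, or without a
    recorded time are skipped.
    """
    stale = []
    for name, loaded_at in load_times.items():
        module = sys.modules.get(name)
        path = getattr(module, "__file__", None)
        if loaded_at is None or not path or not os.path.exists(path):
            continue
        if os.path.getmtime(path) > loaded_at:
            stale.append(name)
    return stale
```

With the _loadtime attribute from the previous answer, the dict could be built as {name: getattr(m, "_loadtime", None) for name, m in list(sys.modules.items())}.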

In Python, how do I get the list of classes defined within a particular file?

If a file myfile.py contains:
class A(object):
# Some implementation
class B (object):
# Some implementation
How can I define a method so that, given myfile.py, it returns [A, B]?
Here, the returned values for A and B can be either the name of the classes or the type of the classes.
(i.e. type(A) = type(str) or type(A) = type(type))
You can get both:
import importlib, inspect
for name, cls in inspect.getmembers(importlib.import_module("myfile"), inspect.isclass):
    print(name, cls)
You may additionally want to check:
if cls.__module__ == 'myfile'
In case it helps someone else, here is the final solution that I used. This method returns all classes defined in a particular package.
I keep all of the subclasses of X in a particular folder (package), and using this method I can load all the subclasses of X, even if they haven't been imported yet. (If they haven't been imported yet, they are not accessible via __all__; otherwise things would have been much easier.)
import importlib, os, inspect
def get_modules_in_package(package_name: str):
    files = os.listdir(package_name)
    for file in files:
        if file not in ['__init__.py', '__pycache__']:
            if file[-3:] != '.py':
                continue
            file_name = file[:-3]
            module_name = package_name + '.' + file_name
            for name, cls in inspect.getmembers(importlib.import_module(module_name), inspect.isclass):
                if cls.__module__ == module_name:
                    yield cls
It's a bit long-winded, but you first need to load the file as a module, then inspect its attributes to see which are classes:
import inspect
import importlib.util
# Load the module from file
spec = importlib.util.spec_from_file_location("foo", "foo.py")
foo = importlib.util.module_from_spec(spec)
spec.loader.exec_module(foo)
# Return a list of all attributes of foo which are classes
[x for x in dir(foo) if inspect.isclass(getattr(foo, x))]
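The snippet above lists every class in the module namespace, including ones the file merely imported. Combining it with the __module__ check from the earlier answer gives only the classes defined in the file itself. A sketch; classes_defined_in_file and its default module name are assumptions, and note that the file's top-level code still runs:

```python
import importlib.util
import inspect


def classes_defined_in_file(path, module_name="loaded_from_file"):
    """Load `path` as a module and return the classes defined in it, sorted by name."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    # Imported classes keep their original __module__, so they are filtered out.
    return [cls for _, cls in inspect.getmembers(module, inspect.isclass)
            if cls.__module__ == module.__name__]
```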
Just building on the answers above.
If you need a list of the classes defined within the module (file), i.e. not just those present in the module namespace, and you want the list within that module, i.e. using reflection, then the below will work under both __name__ == __main__ and __name__ == <module> cases.
import sys, inspect
# You can pass a lambda function as the predicate for getmembers()
[(name, cls) for name, cls in inspect.getmembers(sys.modules[__name__], lambda x: inspect.isclass(x) and (x.__module__ == __name__))]
In my very specific use case of registering classes to a calling framework, I used as follows:
def register():
myLogger.info(f'Registering classes defined in module {__name__}')
for name, cls in inspect.getmembers(sys.modules[__name__], lambda x: inspect.isclass(x) and (x.__module__ == __name__)):
myLogger.debug(f'Registering class {cls} with name {name}')
<framework>.register_class(cls)
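If running the file is undesirable (for example, because it does work at import time), the class names can also be read from its syntax tree without executing anything. This swaps in the ast module instead of inspect; class_names_in_file is a hypothetical name, it only sees classes written at the top level of the file, and it returns names rather than class objects:

```python
import ast


def class_names_in_file(path):
    """Return the names of top-level classes in a source file without running it."""
    with open(path) as f:
        tree = ast.parse(f.read(), filename=path)
    # Only direct children of the module; nested classes are ignored.
    return [node.name for node in tree.body if isinstance(node, ast.ClassDef)]
```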

Import only functions from a python file

I have many Python files (submission1.py, submission2.py, ... , submissionN.py) in the following format,
#submission1.py
def fun():
    print('some fancy function')
fun()
I want to write a tester to test these submissions (they are actually homework assignments that I am grading). I have a tester for fun() which is able to test the function itself. However, my problem is that when I import submission.py, it runs fun(), since the file calls it at the end.
I know that using if __name__ == "__main__": is the correct way of handling this issue; however, our submissions do not have it, since we did not teach it.
So, my question is, is there any way that I can import only fun() from the submission.py files without running the rest of the python file?
For simple scripts with just functions the following will work:
submission1.py:
def fun(x):
    print(x)
fun("foo")
def fun2(x):
    print(x)
fun2("bar")
print("debug print")
You can remove everything except the FunctionDef nodes and then recompile:
import ast
import sys
import types
with open("submission1.py") as f:
    p = ast.parse(f.read())
for node in p.body[:]:
    if not isinstance(node, ast.FunctionDef):
        p.body.remove(node)
module = types.ModuleType("mod")
code = compile(p, "mod.py", 'exec')
sys.modules["mod"] = module
exec(code, module.__dict__)
import mod
mod.fun("calling fun")
mod.fun2("calling fun2")
Output:
calling fun
calling fun2
The module body contains two Expr nodes and one Print node (in Python 2; in Python 3 a print call is also an Expr node), which we remove in the loop, keeping just the FunctionDef nodes.
[<_ast.FunctionDef object at 0x7fa33357f610>, <_ast.Expr object at 0x7fa330298a90>,
<_ast.FunctionDef object at 0x7fa330298b90>, <_ast.Expr object at 0x7fa330298cd0>,
<_ast.Print object at 0x7fa330298dd0>]
So after the loop, our body only contains the functions:
[<_ast.FunctionDef object at 0x7f49a786a610>, <_ast.FunctionDef object at 0x7f49a4583b90>]
This also drops lines where the functions are called inside print, which is quite likely if the students were calling their functions from an IDE and the functions have return statements. To keep any imports as well, you can also keep the ast.Import and ast.ImportFrom nodes:
submission.py:
from math import *
import datetime
def fun(x):
    print(x)
fun("foo")
def fun2(x):
    return x
def get_date():
    print(pi)
    return datetime.datetime.now()
fun2("bar")
print("debug print")
print(fun2("hello world"))
print(get_date())
Compile then import:
for node in p.body[:]:
    if not isinstance(node, (ast.FunctionDef, ast.Import, ast.ImportFrom)):
        p.body.remove(node)
.....
import mod
mod.fun("calling fun")
print(mod.fun2("calling fun2"))
print(mod.get_date())
Output:
calling fun
calling fun2
3.14159265359
2015-05-09 12:29:02.472329
Lastly, if you have some variables declared that you need to use, you can keep them using ast.Assign:
submission.py:
from math import *
import datetime
AREA = 25
WIDTH = 35
def fun(x):
    print(x)
fun("foo")
def fun2(x):
    return x
def get_date():
    print(pi)
    return datetime.datetime.now()
fun2("bar")
print("debug print")
print(fun2("hello world"))
print(get_date())
Add ast.Assign:
for node in p.body[:]:
    if not isinstance(node, (ast.FunctionDef,
                             ast.Import, ast.ImportFrom, ast.Assign)):
        p.body.remove(node)
....
Output:
calling fun
calling fun2
3.14159265359
2015-05-09 12:34:18.015799
25
35
So it really all depends on how your modules are structured and what they should contain as to what you remove. If there are literally only functions then the first example will do what you want. If there are other parts that need to be kept it is just a matter of adding them to the isinstance check.
The listing of all the abstract grammar definitions is in the cpython source under Parser/Python.asdl.
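The steps above can be wrapped into one reusable helper for grading many files. The following is a Python 3 sketch; load_definitions_only and KEEP are hypothetical names, and keeping ast.ClassDef and ast.AsyncFunctionDef on top of the node types used above is an assumption about what the submissions contain (add ast.Assign to KEEP if module-level variables are needed, as described above):

```python
import ast
import types

# Top-level statements to keep; everything else (calls, prints, raises,
# loops, ...) is dropped before the file is compiled.
KEEP = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef,
        ast.Import, ast.ImportFrom)


def load_definitions_only(path, module_name="submission"):
    """Build a module object from `path` containing only its definitions and imports."""
    with open(path) as f:
        tree = ast.parse(f.read(), filename=path)
    tree.body = [node for node in tree.body if isinstance(node, KEEP)]
    module = types.ModuleType(module_name)
    module.__file__ = path
    exec(compile(tree, path, "exec"), module.__dict__)
    return module
```

Each submission can then be loaded with mod = load_definitions_only("submission1.py") and tested through mod.fun(...), without its trailing fun() call ever running.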
You could use sys.settrace() to catch function definitions.
Whenever your fun() is defined, you save it somewhere, and you place a stub into the module you are importing, so that it won't get executed.
Assuming that fun() gets defined only once, this code should do the trick:
import sys
fun = None
def stub(*args, **kwargs):
    pass
def wait_for_fun(frame, event, arg):
    global fun
    if frame.f_code.co_filename == '/path/to/module.py':
        if 'fun' in frame.f_globals:
            # The function has just been defined. Save it.
            fun = frame.f_globals['fun']
            # And replace it with our stub.
            frame.f_globals['fun'] = stub
            # Stop tracing the module execution.
            return None
        return wait_for_fun
sys.settrace(wait_for_fun)
import my_module
sys.settrace(None)  # stop tracing once the module has been imported
# Now fun() is available and we can test it.
fun(1, 2, 3)
# We can also put it again inside the module.
# This is important if other functions in the module need it.
my_module.fun = fun
This code can be improved in many ways, but it does its job.
Maybe, if you just want to import the fun() function from submission.py, try:
from submission import fun
To call fun through the module instead, import the submission module and qualify the name:
import submission
submission.fun()
Or, if you want a shorter name when calling fun(), give it an alias:
from submission import fun as FUN
FUN()

Asserting execution order in python unittest

I have a function that creates a temporary directory, switches to that temporary directory, performs some work, and then switches back to the original directory. I am trying to write a unit test that tests this. I don't have a problem verifying that the current directory was changed to the temp dir and changed back again, but I'm having a problem verifying that the important stuff took place in between those calls.
My original idea was to abstract the function into three sub functions so that I could test the call order. I can replace each of the three sub functions with mocks to verify that they are called -- however, I am still presented with the issue of verifying the order. On a mock I can use assert_has_calls, but upon what object do I call that function?
Here is the class I'm trying to test:
import shutil
import os
import subprocess
import tempfile
import pkg_resources
class Converter:
    def __init__(self, encoded_file_path):
        self.encoded_file_path = encoded_file_path
        self.unencoded_file_path = None
        self.original_path = None
        self.temp_dir = None
    def change_to_temp_dir(self):
        self.original_path = os.getcwd()
        self.temp_dir = tempfile.mkdtemp()
        os.chdir(self.temp_dir)
    def change_to_original_dir(self):
        os.chdir(self.original_path)
        shutil.rmtree(self.temp_dir)
    def do_stuff(self):
        pass
    def run(self):
        self.change_to_temp_dir()
        self.do_stuff()
        self.change_to_original_dir()
This is as far as I got writing the test case:
def test_converter(self, pkg_resources, tempfile, subprocess, os, shutil):
    encoded_file_path = Mock()
    converter = Converter(encoded_file_path)
    converter.change_to_temp_dir = Mock()
    converter.do_stuff = Mock()
    converter.change_to_original_dir = Mock()
    assert converter.encoded_file_path == encoded_file_path
    assert converter.unencoded_file_path is None
    converter.run()
Now that I have each function mocked, I can verify THAT they were called, but not in what ORDER. How do I go about doing this?
One workaround would be to create a separate mock object, attach the methods to it, and use assert_has_calls() to check the call order:
from unittest.mock import Mock, call
converter = Converter(encoded_file_path)
converter.change_to_temp_dir = Mock()
converter.do_stuff = Mock()
converter.change_to_original_dir = Mock()
m = Mock()
m.configure_mock(first=converter.change_to_temp_dir,
                 second=converter.do_stuff,
                 third=converter.change_to_original_dir)
converter.run()
m.assert_has_calls([call.first(), call.second(), call.third()])
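unittest.mock also documents a way to do this with mocks you already have: attach each one to a parent with attach_mock() and compare the parent's mock_calls, which records the calls of all attached children in order. A sketch against a stripped-down stand-in for Converter (the directory handling is left out so it runs anywhere; record_run_calls is a hypothetical helper):

```python
from unittest.mock import Mock, call


class Converter:
    """Stand-in with the same run() sequence as the class in the question."""

    def change_to_temp_dir(self):
        raise NotImplementedError

    def do_stuff(self):
        raise NotImplementedError

    def change_to_original_dir(self):
        raise NotImplementedError

    def run(self):
        self.change_to_temp_dir()
        self.do_stuff()
        self.change_to_original_dir()


def record_run_calls(converter):
    """Replace the three steps with mocks attached to one parent, then call run()."""
    manager = Mock()
    for name in ("change_to_temp_dir", "do_stuff", "change_to_original_dir"):
        step = Mock()
        setattr(converter, name, step)   # instance attribute shadows the method
        manager.attach_mock(step, name)  # step's calls now appear in manager.mock_calls
    converter.run()
    return manager
```

Comparing manager.mock_calls with == checks the exact sequence; manager.assert_has_calls([...]) instead accepts the expected calls as a consecutive run inside a possibly longer list.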
