Say I have a module with the following:
def main():
    pass

if __name__ == "__main__":
    main()
I want to write a unit test for the bottom half (I'd like to achieve 100% coverage). I discovered the runpy builtin module that performs the import/__name__-setting mechanism, but I can't figure out how to mock or otherwise check that the main() function is called.
This is what I've tried so far:
import runpy
import mock

@mock.patch('foobar.main')
def test_main(self, main):
    runpy.run_module('foobar', run_name='__main__')
    main.assert_called_once_with()
I would choose another alternative: exclude the if __name__ == '__main__' block from the coverage report. Of course, you can do that only if you already have a test case for your main() function.
I prefer excluding the block over writing a new test case for the whole script because, if main() is already tested, a second test that runs the script just to reach 100% coverage only duplicates it.
To exclude the if __name__ == '__main__' block, write a coverage configuration file and add the following to its report section:
[report]
exclude_lines =
    if __name__ == .__main__.:
More info about the coverage configuration file can be found here.
Hope this can help.
You can do this using the imp module rather than the import statement. The problem with the import statement is that the test for '__main__' runs as part of the import statement before you get a chance to assign to runpy.__name__.
For example, you could use imp.load_source() like so:
import imp
runpy = imp.load_source('__main__', '/path/to/runpy.py')
The first parameter is assigned to __name__ of the imported module.
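Note that imp was deprecated in Python 3.4 and removed in Python 3.12. On newer interpreters, roughly the same effect can be sketched with importlib. In the sketch below, the script written to a temporary file is a hypothetical stand-in for the module under test, and load_as_main is an illustrative helper name, not part of any library.

```python
import importlib.util
import os
import tempfile

# Hypothetical script standing in for the module under test.
SCRIPT = '''
calls = []

def main():
    calls.append("main")

if __name__ == "__main__":
    main()
'''


def load_as_main(path):
    """Execute the file at `path` with __name__ set to '__main__'."""
    spec = importlib.util.spec_from_file_location("__main__", path)
    module = importlib.util.module_from_spec(spec)
    # sys.modules is left untouched, so the real __main__ is not replaced.
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    script_path = os.path.join(tmp, "foobar.py")
    with open(script_path, "w") as fh:
        fh.write(SCRIPT)
    loaded = load_as_main(script_path)

print(loaded.calls)
```

Because the module body runs inside the test process, the guarded lines are executed where a coverage tool running in that process can see them.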
Whoa, I'm a little late to the party, but I recently ran into this issue and I think I came up with a better solution, so here it is...
I was working on a module that contained a dozen or so scripts all ending with this exact copypasta:
if __name__ == '__main__':
    if '--help' in sys.argv or '-h' in sys.argv:
        print(__doc__)
    else:
        sys.exit(main())
Not horrible, sure, but not testable either. My solution was to write a new function in one of my modules:
def run_script(name, doc, main):
    """Act like a script if we were invoked like a script."""
    if name == '__main__':
        if '--help' in sys.argv or '-h' in sys.argv:
            sys.stdout.write(doc)
        else:
            sys.exit(main())
and then place this gem at the end of each script file:
run_script(__name__, __doc__, main)
Technically, this function runs unconditionally, whether your script was imported as a module or run as a script. That is OK, however, because the function doesn't actually do anything unless the script is being run as a script. So code coverage sees the function run and says "yes, 100% code coverage!" Meanwhile, I wrote three tests to cover the function itself:
@patch('mymodule.utils.sys')
def test_run_script_as_import(self, sysMock):
    """The run_script() func is a NOP when name != __main__."""
    mainMock = Mock()
    sysMock.argv = []
    run_script('some_module', 'docdocdoc', mainMock)
    self.assertEqual(mainMock.mock_calls, [])
    self.assertEqual(sysMock.exit.mock_calls, [])
    self.assertEqual(sysMock.stdout.write.mock_calls, [])

@patch('mymodule.utils.sys')
def test_run_script_as_script(self, sysMock):
    """Invoke main() when run as a script."""
    mainMock = Mock()
    sysMock.argv = []
    run_script('__main__', 'docdocdoc', mainMock)
    mainMock.assert_called_once_with()
    sysMock.exit.assert_called_once_with(mainMock())
    self.assertEqual(sysMock.stdout.write.mock_calls, [])

@patch('mymodule.utils.sys')
def test_run_script_with_help(self, sysMock):
    """Print help when the user asks for help."""
    mainMock = Mock()
    for h in ('-h', '--help'):
        sysMock.argv = [h]
        run_script('__main__', h*5, mainMock)
        self.assertEqual(mainMock.mock_calls, [])
        self.assertEqual(sysMock.exit.mock_calls, [])
        sysMock.stdout.write.assert_called_with(h*5)
Blam! Now you can write a testable main(), invoke it as a script, have 100% test coverage, and not need to ignore any code in your coverage report.
Python 3 solution:
import os
from importlib.machinery import SourceFileLoader
from importlib.util import spec_from_loader, module_from_spec
from importlib import reload
from unittest import TestCase
from unittest.mock import MagicMock, patch
class TestIfNameEqMain(TestCase):
    def test_name_eq_main(self):
        loader = SourceFileLoader('__main__',
                                  os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                               '__main__.py'))
        with self.assertRaises(SystemExit) as e:
            loader.exec_module(module_from_spec(spec_from_loader(loader.name, loader)))
Using the alternative solution of defining your own little function:
# module.py
def main():
    if __name__ == '__main__':
        return 'sweet'
    return 'child of mine'
You can test with:
# Override the `__name__` value in your module to '__main__'
with patch('module_name.__name__', '__main__'):
    import module_name
    self.assertEqual(module_name.main(), 'sweet')

with patch('module_name.__name__', 'anything else'):
    reload(module_name)
    del module_name
    import module_name
    self.assertEqual(module_name.main(), 'child of mine')
I did not want to exclude the lines in question, so based on this explanation of a solution, I implemented a simplified version of the alternate answer given here...
I wrapped if __name__ == "__main__": in a function to make it easily testable, and then called that function to retain logic:
# myapp.module.py
def main():
    pass

def init():
    if __name__ == "__main__":
        main()

init()
I mocked the __name__ using unittest.mock to get at the lines in question:
from unittest.mock import patch, MagicMock
from myapp import module
def test_name_equals_main():
    # Arrange
    with patch.object(module, "main", MagicMock()) as mock_main:
        with patch.object(module, "__name__", "__main__"):
            # Act
            module.init()

    # Assert
    mock_main.assert_called_once()
If you are sending arguments into the mocked function, like so,
if __name__ == "__main__":
    main(main_args)
then you can use assert_called_once_with() for an even better test:
expected_args = ["expected_arg_1", "expected_arg_2"]
mock_main.assert_called_once_with(expected_args)
If desired, you can also add a return_value to the MagicMock() like so:
with patch.object(module, "main", MagicMock(return_value='foo')) as mock_main:
One approach is to run the modules as scripts (e.g. os.system(...)) and compare their stdout and stderr output to expected values.
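A minimal sketch of that approach, using subprocess.run rather than os.system so the output can be captured directly. The script written to a temporary file is a hypothetical stand-in for the module under test, and run_as_script is an illustrative helper name.

```python
import os
import subprocess
import sys
import tempfile

# Hypothetical script standing in for the module under test.
SCRIPT = '''
import sys

def main():
    print("hello from main")
    return 0

if __name__ == "__main__":
    sys.exit(main())
'''


def run_as_script(path, *args):
    """Run `path` in a fresh interpreter and capture stdout and stderr."""
    return subprocess.run(
        [sys.executable, path, *args],
        capture_output=True,
        text=True,
        check=False,
    )


with tempfile.TemporaryDirectory() as tmp:
    script_path = os.path.join(tmp, "script.py")
    with open(script_path, "w") as fh:
        fh.write(SCRIPT)
    result = run_as_script(script_path)

print(result.returncode, repr(result.stdout), repr(result.stderr))
```

One caveat: the script runs in a child process, so its lines are not counted by a coverage tool in the parent unless subprocess measurement has been configured separately.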
I found this solution helpful. Works well if you use a function to keep all your script code.
The code is treated as a single line, so the coverage counter does not care whether the whole line was actually executed (though this is not what you would expect 100% coverage to mean).
The trick is also accepted by pylint. ;-)
if __name__ == '__main__': \
    main()
If it's just to get the 100% and there is nothing "real" to test there, it is easier to ignore that line.
If you are using the regular coverage lib, you can just add a simple comment, and the line will be ignored in the coverage report.
if __name__ == '__main__':
    main()  # pragma: no cover
https://coverage.readthedocs.io/en/coverage-4.3.3/excluding.html
Another comment by @Taylor Edmiston also mentions it.
My solution is to use imp.load_source() and force an exception to be raised early in main() by not providing a required CLI argument, providing a malformed argument, setting paths in such a way that a required file is not found, etc.
import imp
import os
import sys
def mainCond(testObj, srcFilePath, expectedExcType=SystemExit, cliArgsStr=''):
    sys.argv = [os.path.basename(srcFilePath)] + (
        [] if len(cliArgsStr) == 0 else cliArgsStr.split(' '))
    testObj.assertRaises(expectedExcType, imp.load_source, '__main__', srcFilePath)
Then in your test class you can use this function like this:
def testMain(self):
    mainCond(self, 'path/to/main.py', cliArgsStr='-d FailingArg')
To test your "main" code in pytest, you can load the main module like any other module using the standard-library importlib package:
def test_main():
    import importlib.machinery

    loader = importlib.machinery.SourceFileLoader("__main__", "src/glue_jobs/move_data_with_resource_partitionning.py")
    # load_module() executes the file with __name__ set to "__main__"
    # and returns the module object (which is not itself callable).
    runpy_main = loader.load_module()
    assert runpy_main.__name__ == "__main__"
I wrote a package that is using multiprocessing.Pool inside one of its functions.
Due to this reason, it is mandatory (as specified in here under "Safe importing of main module") that the outermost calling function can be imported safely e.g. without starting a new process. This is usually achieved using the if __name__ == "__main__": statement as explicitly explained at the link above.
My understanding (but please correct me if I'm wrong) is that multiprocessing imports the outermost calling module. So, if this is not "import-safe", this will start a new process that will import again the outermost module and so on recursively, until everything crashes.
If the outermost module is not "import-safe" when the main function is launched it usually hangs without printing any warning, error, message, anything.
Since using if __name__ == "__main__": is usually not mandatory, and users are not always aware of every module used inside a package, I would like to check at the beginning of my function whether the user complied with this requirement and, if not, raise a warning/error.
Is this possible? How can I do this?
To illustrate, consider the following example.
Let's say I developed my_module.py and I share it online/in my company.
# my_module.py
from multiprocessing import Pool
def f(x):
    return x*x

def my_function(x_max):
    with Pool(5) as p:
        print(p.map(f, range(x_max)))
If a user (not me) writes his own script as:
# Script_of_a_good_user.py
from my_module import my_function
if __name__ == '__main__':
    my_function(10)
all is good and the output is printed as expected.
However, if a careless user writes his script as:
# Script_of_a_careless_user.py
from my_module import my_function
my_function(10)
then the process hangs and no output is produced, but no error message or warning is issued to the user.
Is there a way inside my_function, BEFORE opening the Pool, to check whether the user put the call under the if __name__ == '__main__': condition in their script and, if not, raise an error saying they should do so?
NOTE: I think this behavior is only a problem on Windows machines where fork() is not available, as explained here.
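Whether the guard matters depends on the start method rather than on the operating system as such: spawn is also the default on macOS since Python 3.8, and both spawn and forkserver re-import the main module. A minimal sketch for checking this at runtime follows. The helper names default_start_method and needs_main_guard are illustrative, and the sketch relies on get_all_start_methods() being documented to list the default method first.

```python
import multiprocessing


def default_start_method():
    """Return the platform's default start method without fixing it."""
    # Documented to list the supported methods with the default first.
    return multiprocessing.get_all_start_methods()[0]


def needs_main_guard():
    """True when child processes re-import the main module (anything but fork)."""
    return default_start_method() != "fork"


print(default_start_method(), needs_main_guard())
```

This only reports the platform default; a caller who selects a different method with set_start_method() or get_context() would need a different check.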
You can use the traceback module to inspect the stack and find the information you're looking for. Parse the top frame, and look for the main shield in the code.
I assume this will fail when you're working with a .pyc file and don't have access to the source code, but I assume developers will test their code in the regular fashion first before doing any kind of packaging, so I think it's safe to assume your error message will get printed when needed.
Version with verbose messages:
import traceback
import re
def called_from_main_shield():
    print("Calling introspect")
    tb = traceback.extract_stack()
    print(traceback.format_stack())
    print(f"line={tb[0].line} lineno={tb[0].lineno} file={tb[0].filename}")
    try:
        with open(tb[0].filename, mode="rt") as f:
            found_main_shield = False
            for i, line in enumerate(f):
                if re.search(r"__name__.*['\"]__main__['\"]", line):
                    found_main_shield = True
                if i == tb[0].lineno:
                    print(f"found_main_shield={found_main_shield}")
                    return found_main_shield
    except:
        print("Couldn't inspect stack, let's pretend the code is OK...")
        return True


print(called_from_main_shield())

if __name__ == "__main__":
    print(called_from_main_shield())
In the output, we see that the first call to called_from_main_shield returns False, while the second returns True:
$ python3 introspect.py
Calling introspect
[' File "introspect.py", line 24, in <module>\n print(called_from_main_shield())\n', ' File "introspect.py", line 7, in called_from_main_shield\n print(traceback.format_stack())\n']
line=print(called_from_main_shield()) lineno=24 file=introspect.py
found_main_shield=False
False
Calling introspect
[' File "introspect.py", line 27, in <module>\n print(called_from_main_shield())\n', ' File "introspect.py", line 7, in called_from_main_shield\n print(traceback.format_stack())\n']
line=print(called_from_main_shield()) lineno=27 file=introspect.py
found_main_shield=True
True
More concise version:
def called_from_main_shield():
    tb = traceback.extract_stack()
    try:
        with open(tb[0].filename, mode="rt") as f:
            found_main_shield = False
            for i, line in enumerate(f):
                if re.search(r"__name__.*['\"]__main__['\"]", line):
                    found_main_shield = True
                if i == tb[0].lineno:
                    return found_main_shield
    except:
        return True
Now, it's not super elegant to use re.search() like I did, but it should be reliable enough. Warning: since I defined this function in my main script, I had to make sure that the line didn't match itself, which is why I used ['\"] to match the quotes instead of a simpler RE like __name__.*__main__. Whatever you choose, just make sure it's flexible enough to match all legal variants of that code, which is what I aimed for.
I think the best you can do is to try to execute the code and provide a hint if it fails. Something like this:
# my_module.py
import sys # Use sys.stderr to print to the error stream.
from multiprocessing import Pool
def f(x):
    return x*x

def my_function(x_max):
    try:
        with Pool(5) as p:
            print(p.map(f, range(x_max)))
    except RuntimeError as e:
        print("Whoops! Did you perhaps forget to put the code in `if __name__ == '__main__'`?", file=sys.stderr)
        raise e
This is of course not a 100% solution, as there might be several other reasons the code throws a RuntimeError.
If it doesn't raise a RuntimeError, an ugly solution would be to explicitly force the user to pass in the name of the module.
# my_module.py
from multiprocessing import Pool
def f(x):
    return x*x

def my_function(x_max, module):
    """`module` must be set to `__name__`, for example `my_function(10, __name__)`"""
    if module == '__main__':
        with Pool(5) as p:
            print(p.map(f, range(x_max)))
    else:
        raise Exception("This can only be called from the main module.")
And call it as:
# Script_of_a_careless_user.py
from my_module import my_function
my_function(10, __name__)
This makes it very explicit to the user.
Here is my code, in which I would like to test the validate_yaml function (I removed the function bodies because they aren't needed for the question):
yaml_file_name = "env.yaml"
def load_yaml(file: str) -> list:
    pass

def validate_yaml(env_list: list):
    pass

def yaml_to_env(env_list: list):
    pass

env_list = load_yaml(f"{yaml_file_name}")
validate_yaml(env_list)
yaml_to_env(env_list)
This is my test file:
import pytest
import jsonschema
from yaml_to_env import load_yaml, validate_yaml
@pytest.mark.parametrize(
    "invalid_yaml",
    [
        (load_yaml("tests/yaml_files/invalid_workload_type.yaml")),
    ],
)
def test_yaml_env(invalid_yaml):
    with pytest.raises(jsonschema.ValidationError):
        validate_yaml(invalid_yaml)
My problem is that when I run pytest, the last three lines are executed too:
env_list = load_yaml(f"{yaml_file_name}")
validate_yaml(env_list)
yaml_to_env(env_list)
Why is it doing this? I would like to test only the validate_yaml function and not run those three lines during pytest.
Thanks in advance
[EDIT1]
This is the best solution that I have found so far:
if __name__ == "__main__":
    env_list = load_yaml(f"{yaml_file_name}")
    validate_yaml(env_list)
    yaml_to_env(env_list)
Your edit is exactly what you need to do here.
In your previous attempt, the code was executed when the test file imported your production code.
This is not an issue with pytest but rather a consequence of how Python modules work: importing a module runs all of its top-level code.
See this link for details:
What does if __name__ == "__main__": do?
First point in the short answer seems to be what you did. :)
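A small sketch of the difference, loading the same file once under an ordinary module name (what the test file's import does) and once as __main__ (what running the script does). The file contents, the name yaml_to_env_demo, and the load helper are hypothetical and only serve the illustration.

```python
import importlib.util
import os
import tempfile

# Hypothetical module: one unguarded statement, one guarded statement.
SOURCE = '''
events = ["top-level code ran"]

if __name__ == "__main__":
    events.append("guarded code ran")
'''


def load(path, name):
    """Execute the file at `path` as a module called `name`."""
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "yaml_to_env_demo.py")
    with open(path, "w") as fh:
        fh.write(SOURCE)
    imported = load(path, "yaml_to_env_demo")  # like `import` in a test file
    as_script = load(path, "__main__")         # like `python yaml_to_env_demo.py`

print(imported.events)
print(as_script.events)
```

The unguarded statement runs in both cases, while the guarded one runs only when the module's name is __main__.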
I want to run doc-tests of a Python script as part of a pre-commit hook in Python.
In the file set_prefix.py, I have doc-tests in front of functions, which I test before running with:
import doctest
import sys
EXTENSIONS = tuple([".%s" % ending for ending in ["jpg", "heic", "nrw"]])
def is_target_for_renaming(filepath):
    """Returns true if this filepath should be renamed.

    >>> is_target_for_renaming("/Users/username/Pictures/document.other_jpg")
    True
    """
    return filepath.lower().endswith(EXTENSIONS)

def get_failed_tests():
    r = doctest.testmod()
    return r.failed

def main():
    pass

if "__main__" == __name__:
    args = sys.argv
    test_only = 2 <= len(sys.argv) and "test" == sys.argv[1]
    test_failures = get_failed_tests()
    print(test_failures)
    assert 0 == test_failures
    if not test_only:
        main()
When I run python3 set_prefix.py test, I get the error I expected.
Yet, when I import the module and call the function:
import set_prefix
if "__main__" == __name__:
    test_failures = set_prefix.get_failed_tests()
    print(test_failures)
I get 0 failures:
$ python3 temp.py
0
The reason I want to import the module is to run the tests in a pre-commit hook similar to that added by flake8:
#!/usr/local/opt/python/bin/python3.7
import sys
from flake8.main import git
if __name__ == '__main__':
    sys.exit(
        git.hook(
            strict=git.config_for('strict'),
            lazy=git.config_for('lazy'),
        )
    )
Why do the doc-tests run when called from the command-line and the script and not when the script is imported? Would unittest be a better framework, as described in this thread?
doctest.testmod() runs the doctests of the __main__ module by default, so the result depends on which script you are actually running.
You can fix this with the m parameter, but you'll still be forced to add boilerplate code to each module that has doctests. Try this:
doctest.testfile("some_module.py")
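For the m parameter mentioned above, a minimal sketch of passing the module explicitly might look like this. The module written to a temporary file is hypothetical, with one passing and one deliberately failing example; in real code a plain import of the module would replace the loading boilerplate.

```python
import doctest
import importlib.util
import os
import tempfile

# Hypothetical module with one passing and one failing doctest.
SOURCE = '''
def double(x):
    """Return twice x.

    >>> double(2)
    4
    >>> double(3)
    7
    """
    return 2 * x
'''

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "some_module.py")
    with open(path, "w") as fh:
        fh.write(SOURCE)
    spec = importlib.util.spec_from_file_location("some_module", path)
    some_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(some_module)
    # Pass the module explicitly instead of relying on the __main__ default.
    results = doctest.testmod(m=some_module, report=False)

print(results.failed, results.attempted)
```

Applied to the question, get_failed_tests() could pass its own module object to testmod so that the result no longer depends on which script is being run.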
I have written a test case which shows the error
from unittest import *
class MyTest(unittest.TestCase):
    def test_add(self):
        self.assertEquals(1,(2-1),"Sample Subraction Test")

if __name__ == '__main__':
    unittest.main()
Output:
Str object is not callable
When I replaced "from unittest import *" with "import unittest", it worked, but I still can't pinpoint the cause. What might be the reason for this?
from ... import * is dangerous practice, and should only be used when the module/package has been designed and advertised that way, and you have a good reason to do so.
It turns out that unittest has not been designed that way, and when that method is used two other 'test cases' are found, but since they aren't really test cases, they create problems.
The correct way to do what you want is:
import unittest
class MyTest(unittest.TestCase):
    def test_subtraction(self):
        self.assertEqual(1, (2-1), "Sample Subraction Test")

if __name__ == '__main__':
    unittest.main()
In researching this issue I discovered that the __all__ variable can and should be used to define the public API -- its presence does not indicate that from ... import * is supported.
I got it working like this.
Override runTest() method, create instance, run your test_add()
from unittest import TestCase
class MyTest(TestCase):
    def runTest(self):
        pass

    def test_add(self):
        self.assertEquals(1,(2-2),"Sample Subraction Test")

if __name__ == '__main__':
    test = MyTest()
    test.test_add()
I use testoob in the following way:
def suite():
    import unittest
    return unittest.TestLoader().loadTestsFromNames([
        'my_module.my_unittest_class',
        'my_module.my_other_unittest_class',
    ])

if __name__ == '__main__':
    import testoob
    testoob.main(defaultTest="suite")
And then run the unittest suite with the following:
python my_unittest.py --coverage=normal
This, however, also prints code coverage figures for all the modules that my module and unittest depend on, which I'm not interested in at all. How can I configure testoob to report coverage only for my own module?
I ended up overriding the private _should_cover_frame function in the testoob Coverage class and comparing the frame's file path to my module. Not the nicest solution, but at least it works.
from testoob.coverage import Coverage
orig_should_cover = Coverage._should_cover_frame
def my_should_cover_frame(self, frame):
    from os.path import abspath
    filename = abspath(frame.f_code.co_filename)
    if filename.find('my_module') == -1:
        return False
    else:
        return orig_should_cover(self, frame)

Coverage._should_cover_frame = my_should_cover_frame