Python unittest: mock open specific paths (don't mock others)

The use case is that I want to mock the opening of two files, ~/.myconf and ./.myconf, but not any others.
I'm testing the setup of a complex object which reads multiple files in its __init__, so I'd like to mock data for some of them and leave the rest untouched.
As an example, here is how I mock the conditional opening of those two files, but it feels complex, and I find it odd that there's no easy built-in way that I'm missing.
import builtins
import configparser
import unittest
from textwrap import dedent
from pathlib import Path
from typing import Callable
from unittest.mock import mock_open, patch

OPEN = builtins.open  # keep a reference to the real open() before it gets patched

def get_hierarchical_config():
    cwd = Path.cwd()
    global_config = configparser.ConfigParser()
    local_config = configparser.ConfigParser()
    full_config = configparser.ConfigParser()
    global_config.read(Path("~/.myconf").expanduser().resolve())
    local_config.read((cwd / ".myconf").expanduser().resolve())
    full_config.read_dict(global_config)
    full_config.read_dict(local_config)
    return full_config["mysection"]

def get_custom_mock_open(global_conf_str, local_conf_str) -> Callable:
    def mocked_open():
        def conditional_open_func(path, *args, **kwargs):
            p = Path(path).expanduser().resolve()
            if p.name == ".myconf":
                if p.parent == Path.home():
                    return mock_open(read_data=global_conf_str)()
                return mock_open(read_data=local_conf_str)()
            return OPEN(path, *args, **kwargs)
        return conditional_open_func
    return mocked_open
[...]
class TestConfig(unittest.TestCase):
    def test_read_confs(self):
        global_conf = dedent(
            """\
            [mysection]
            no_overwrite=path/to/somewhere
            local_overwrite=ERROR:not overwritten
            syntax_test_key= no/space= problem2
            """
        )
        local_conf = dedent(
            """\
            [mysection]
            local_overwrite=SUCCESS:overwritten
            local_new_key=cool value
            """
        )
        with patch(
            "builtins.open",
            new_callable=get_custom_mock_open(global_conf, local_conf),
        ):
            conf = dict(get_hierarchical_config())  # reads the config files
        target = {
            "no_overwrite": "path/to/somewhere",
            "local_overwrite": "SUCCESS:overwritten",
            "syntax_test_key": "no/space= problem2",
            "local_new_key": "cool value",
        }
        self.assertDictEqual(conf, target)
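For comparison, a somewhat lighter variant of the same idea passes the dispatching function as patch's side_effect instead of a nested factory. This is a minimal sketch under the same assumptions (global_conf and local_conf being the strings from the test above):

from pathlib import Path
from unittest.mock import mock_open, patch

REAL_OPEN = open  # keep a reference to the real open() before patching

def make_fake_open(global_conf_str, local_conf_str):
    def fake_open(path, *args, **kwargs):
        p = Path(path).expanduser().resolve()
        if p.name == ".myconf":
            data = global_conf_str if p.parent == Path.home() else local_conf_str
            return mock_open(read_data=data)()
        return REAL_OPEN(path, *args, **kwargs)  # fall through to the real open
    return fake_open

with patch("builtins.open", side_effect=make_fake_open(global_conf, local_conf)):
    conf = dict(get_hierarchical_config())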

Related

Unable to successfully patch functions of Azure ContainerClient

I have been trying to patch the list_blobs() function of ContainerClient, but have not been able to do so successfully. This code outputs a MagicMock; the function isn't patched as I would expect (I'm trying to patch it to return the list ['Blob1', 'Blob2']).
#################Script File
import sys
from datetime import datetime, timedelta
import pyspark
import pytz
import yaml
# from azure.storage.blob import BlobServiceClient, ContainerClient
from pyspark.dbutils import DBUtils as dbutils
import azure.storage.blob

# Open Config
def main():
    spark_context = pyspark.SparkContext.getOrCreate()
    spark_context.addFile(sys.argv[1])
    stream = open(sys.argv[1], "r")
    config = yaml.load(stream, Loader=yaml.FullLoader)
    stream.close()
    account_key = dbutils.secrets.get(scope=config["Secrets"]["Scope"], key=config["Secrets"]["Key Name"])
    target_container = config["Storage Configuration"]["Container"]
    target_account = config["Storage Configuration"]["Account"]
    days_history_to_keep = config["Storage Configuration"]["Days History To Keep"]
    connection_string = (
        "DefaultEndpointsProtocol=https;AccountName="
        + target_account
        + ";AccountKey="
        + account_key
        + ";EndpointSuffix=core.windows.net"
    )
    blob_service_client: azure.storage.blob.BlobServiceClient = (
        azure.storage.blob.BlobServiceClient.from_connection_string(connection_string)
    )
    container_client: azure.storage.blob.ContainerClient = (
        blob_service_client.get_container_client(target_container)
    )
    blobs = container_client.list_blobs()
    print(blobs)
    print(blobs)
    utc = pytz.UTC
    delete_before_date = utc.localize(
        datetime.today() - timedelta(days=days_history_to_keep)
    )
    for blob in blobs:
        if blob.creation_time < delete_before_date:
            print("Deleting Blob: " + blob.name)
            container_client.delete_blob(blob, delete_snapshots="include")

if __name__ == "__main__":
    main()
#################Test File
import unittest
from unittest import mock

import DeleteOldBlobs

class DeleteBlobsTest(unittest.TestCase):
    def setUp(self):
        pass

    @mock.patch("DeleteOldBlobs.azure.storage.blob.ContainerClient")
    @mock.patch("DeleteOldBlobs.azure.storage.blob.BlobServiceClient")
    @mock.patch("DeleteOldBlobs.dbutils")
    @mock.patch("DeleteOldBlobs.sys")
    @mock.patch('DeleteOldBlobs.pyspark')
    def test_main(self, mock_pyspark, mock_sys, mock_dbutils, mock_blobserviceclient, mock_containerclient):
        # mock setup
        config_file = "Delete_Old_Blobs_UnitTest.yml"
        mock_sys.argv = ["unused_arg", config_file]
        mock_dbutils.secrets.get.return_value = "A Secret"
        mock_containerclient.list_blobs.return_value = ["ablob1", "ablob2"]
        # execute test
        DeleteOldBlobs.main()
        # TODO assert actions taken
        # mock_sys.argv.__get__.assert_called_with()
        # dbutils.secrets.get(scope=config['Secrets']['Scope'], key=config['Secrets']['Key Name'])

if __name__ == "__main__":
    unittest.main()
Output:
<MagicMock name='BlobServiceClient.from_connection_string().get_container_client().list_blobs()' id='1143355577232'>
What am I doing incorrectly here?
I can't execute your code at the moment, but I have tried to simulate it. To do this I created the following 3 files in the path /<path-to>/pkg/sub_pkg1 (where pkg and sub_pkg1 are packages).
File ContainerClient.py
def list_blobs(self):
    return "blob1"
File DeleteOldBlobs.py
from pkg.sub_pkg1 import ContainerClient

# Open Config
def main():
    blobs = ContainerClient.list_blobs()
    print(blobs)
    print(blobs)
File DeleteBlobsTest.py
import unittest
from unittest import mock
from pkg.sub_pkg1 import DeleteOldBlobs

class DeleteBlobsTest(unittest.TestCase):
    def setUp(self):
        pass

    def test_main(self):
        mock_containerclient = mock.MagicMock()
        with mock.patch("DeleteOldBlobs.ContainerClient.list_blobs", mock_containerclient.list_blobs):
            mock_containerclient.list_blobs.return_value = ["ablob1", "ablob2"]
            DeleteOldBlobs.main()

if __name__ == '__main__':
    unittest.main()
If you execute the test code you obtain the output:
['ablob1', 'ablob2']
['ablob1', 'ablob2']
This output means that the function list_blobs() is mocked by mock_containerclient.list_blobs.
I don't know whether this will be useful to you, since I couldn't simulate your code more closely at the moment, but I hope it points you toward your real solution.
The structure of the answer above didn't quite match my solution; perhaps both will work, but it was important for me to patch pyspark even though I never call it directly, or exceptions would get thrown when my code tried to interact with Spark.
Perhaps this will be useful to someone:
@mock.patch("DeleteOldBlobs.azure.storage.blob.BlobServiceClient")
@mock.patch("DeleteOldBlobs.dbutils")
@mock.patch("DeleteOldBlobs.sys")
@mock.patch('DeleteOldBlobs.pyspark')
def test_list_blobs_called_once(self, mock_pyspark, mock_sys, mock_dbutils, mock_blobserviceclient):
    # mock setup
    config_file = "Delete_Old_Blobs_UnitTest.yml"
    mock_sys.argv = ["unused_arg", config_file]
    account_key = 'Secret Key'
    mock_dbutils.secrets.get.return_value = account_key
    bsc_mock: mock.Mock = mock.Mock()
    container_client_mock = mock.Mock()
    # Blob is a small test helper exposing .name and .creation_time
    blob1 = Blob('newblob', datetime.today())
    blob2 = Blob('oldfile', datetime.today() - timedelta(days=20))
    container_client_mock.list_blobs.return_value = [blob1, blob2]
    bsc_mock.get_container_client.return_value = container_client_mock
    mock_blobserviceclient.from_connection_string.return_value = bsc_mock
    # execute test
    DeleteOldBlobs.main()
    # assert results
    container_client_mock.list_blobs.assert_called_once()
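From there you can also assert the deletion behaviour; a hedged sketch, assuming the test config's Days History To Keep value is below 20 so that only blob2 qualifies:

# only the 20-day-old blob should have been deleted
container_client_mock.delete_blob.assert_called_once_with(
    blob2, delete_snapshots="include"
)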

How to modify the argument taken from a contextmanager within the with block?

I am trying to modify, on the fly, one of the parameters used by the exit part of a context manager. That is, I'm trying to bind the parameter to a variable known inside the with block:
from contextlib import contextmanager
import tempfile, shutil

@contextmanager
def tempdir(suffix='', prefix='', dir=None, ignore_errors=False,
            remove=True):
    """
    Context manager to generate a temporary directory with write permissions.
    """
    d = tempfile.mkdtemp(suffix, prefix, dir)
    try:
        yield d
    finally:
        print("finalizing tempdir %s, remove=%s" % (d, remove))
        if remove:
            shutil.rmtree(d, ignore_errors)

willremove = True
with tempdir(remove=willremove) as t:
    # attempt to modify parameter
    willremove = False
    print("willremove:%s" % willremove)
I would expect that changing the value of willremove would change the remove variable seen in the finally: part of the context manager, but it doesn't.
This cannot be done, because parameters in Python are passed 'by assignment', as pointed out by Ned Batchelder in the following talk:
https://www.youtube.com/watch?v=_AEJHKGk9ns
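A common workaround, sketched here as my own addition rather than part of the original answer, is to pass a mutable container: the context manager and the with block then share one object, so a change made inside the block is visible in the finally clause.

from contextlib import contextmanager
import tempfile, shutil

@contextmanager
def tempdir(options):
    # options is a shared, mutable dict (hypothetical API for this sketch)
    d = tempfile.mkdtemp()
    try:
        yield d
    finally:
        # reads the current value, as seen through the shared dict
        if options.get("remove", True):
            shutil.rmtree(d, ignore_errors=True)

opts = {"remove": True}
with tempdir(opts) as t:
    opts["remove"] = False  # visible inside the finally clause above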

How to mock a function, in a function map/dictionary?

I am trying to patch the fun_1 function from the worker_functions dictionary and I seem to be struggling:
cli.py:
import sys

from worker_functions import (
    fun_1,
    fun_2,
    fun_3,
)

FUNCTION_MAP = {
    'run_1': fun_1,
    'run_2': fun_2,
    'run_3': fun_3,
}

def main():
    command = sys.argv[1]
    tag = sys.argv[2]
    action = FUNCTION_MAP[command]
    action(tag)
I've tried mocking cli.fun_1, cli.main.action, and cli.action, but each attempt leads to failure.
test_cli.py:
from mock import patch

from cli import main

def make_test_args(tup):
    sample_args = ['cli.py']
    sample_args.extend(tup)
    return sample_args

def test_fun_1_command():
    test_args = make_test_args(['run_1', 'fake_tag'])
    with patch('sys.argv', test_args),\
         patch('cli.fun_1') as mock_action:
        main()
        mock_action.assert_called_once()
What am I missing?
You'll need to patch the references in the FUNCTION_MAP dictionary itself. Use the patch.dict() callable to do so:
from unittest.mock import patch, MagicMock

mock_action = MagicMock()
with patch('sys.argv', test_args),\
     patch.dict('cli.FUNCTION_MAP', {'run_1': mock_action}):
    # ...
That's because the FUNCTION_MAP dictionary is the location where the function reference is looked up.
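Put together, the whole test might look like this; a minimal sketch reusing the helper from the question:

from unittest.mock import patch, MagicMock

from cli import main

def make_test_args(tup):
    return ['cli.py'] + list(tup)

def test_fun_1_command():
    test_args = make_test_args(['run_1', 'fake_tag'])
    mock_action = MagicMock()
    # patch the dictionary entry, not the module-level name
    with patch('sys.argv', test_args), \
         patch.dict('cli.FUNCTION_MAP', {'run_1': mock_action}):
        main()
    mock_action.assert_called_once_with('fake_tag')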

In Python, how can you load YAML mappings as OrderedDicts?

I'd like to get PyYAML's loader to load mappings (and ordered mappings) into the Python 2.7+ OrderedDict type, instead of the vanilla dict and the list of pairs it currently uses.
What's the best way to do that?
Python >= 3.6
In Python 3.6+, it seems that dict loading order is preserved by default without special dictionary types. The default Dumper, on the other hand, sorts dictionaries by key. Starting with PyYAML 5.1, you can turn this off by passing sort_keys=False:
a = dict(zip("unsorted", "unsorted"))
s = yaml.safe_dump(a, sort_keys=False)
b = yaml.safe_load(s)
assert list(a.keys()) == list(b.keys()) # True
This can work due to the new dict implementation that has been in use in PyPy for some time. While still considered an implementation detail in CPython 3.6, "the insertion-order preserving nature of dicts has been declared an official part of the Python language spec" as of 3.7+; see What's New In Python 3.7.
Note that this is still undocumented on PyYAML's side, so you shouldn't rely on it for safety-critical applications.
Original answer (compatible with all known versions)
I like @James' solution for its simplicity. However, it changes the default global yaml.Loader class, which can lead to troublesome side effects. Especially when writing library code, this is a bad idea. Also, it doesn't directly work with yaml.safe_load().
Fortunately, the solution can be improved without much effort:
import yaml
from collections import OrderedDict

def ordered_load(stream, Loader=yaml.SafeLoader, object_pairs_hook=OrderedDict):
    class OrderedLoader(Loader):
        pass
    def construct_mapping(loader, node):
        loader.flatten_mapping(node)
        return object_pairs_hook(loader.construct_pairs(node))
    OrderedLoader.add_constructor(
        yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
        construct_mapping)
    return yaml.load(stream, OrderedLoader)

# usage example:
ordered_load(stream, yaml.SafeLoader)
For serialization, you could use the following function:
def ordered_dump(data, stream=None, Dumper=yaml.SafeDumper, **kwds):
    class OrderedDumper(Dumper):
        pass
    def _dict_representer(dumper, data):
        return dumper.represent_mapping(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
            data.items())
    OrderedDumper.add_representer(OrderedDict, _dict_representer)
    return yaml.dump(data, stream, OrderedDumper, **kwds)

# usage:
ordered_dump(data, Dumper=yaml.SafeDumper)
In each case, you could also make the custom subclasses global, so that they don't have to be recreated on each call.
2018 option:
oyaml is a drop-in replacement for PyYAML which preserves dict ordering. Both Python 2 and Python 3 are supported. Just pip install oyaml, and import as shown below:
import oyaml as yaml
You'll no longer be annoyed by screwed-up mappings when dumping/loading.
Note: I'm the author of oyaml.
The yaml module allows you to specify custom 'representers' to convert Python objects to text and 'constructors' to reverse the process.
import collections
import yaml

_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

def dict_representer(dumper, data):
    return dumper.represent_dict(data.items())

def dict_constructor(loader, node):
    return collections.OrderedDict(loader.construct_pairs(node))

yaml.add_representer(collections.OrderedDict, dict_representer)
yaml.add_constructor(_mapping_tag, dict_constructor)
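A quick roundtrip check (my addition, assuming the representer and constructor above have been registered):

d = collections.OrderedDict([("b", 1), ("a", 2)])
s = yaml.dump(d)
# keys come back in insertion order rather than sorted
assert list(yaml.load(s, Loader=yaml.Loader)) == ["b", "a"]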
2015 (and later) option:
ruamel.yaml is a drop-in replacement for PyYAML (disclaimer: I am the author of that package). Preserving the order of the mappings was one of the things added in the first version (0.1) back in 2015. Not only does it preserve the order of your dictionaries, it also preserves comments, anchor names, and tags, and it supports the YAML 1.2 specification (released 2009).
The specification says that the ordering is not guaranteed, but of course there is ordering in the YAML file, and the appropriate parser can just hold on to that and transparently generate an object that keeps the ordering. You just need to choose the right parser, loader, and dumper:
import sys
from ruamel.yaml import YAML
yaml_str = """\
3: abc
conf:
    10: def
    3: gij     # h is missing
more:
- what
- else
"""
yaml = YAML()
data = yaml.load(yaml_str)
data['conf'][10] = 'klm'
data['conf'][3] = 'jig'
yaml.dump(data, sys.stdout)
will give you:
3: abc
conf:
    10: klm
    3: jig     # h is missing
more:
- what
- else
data is of type CommentedMap, which functions like a dict but carries extra information that is kept around until being dumped (including the preserved comment!).
Note: there is a library, based on the following answer, which also implements the CLoader and CDumper variants: Phynix/yamlloader.
I doubt very much that this is the best way to do it, but this is the way I came up with, and it does work. Also available as a gist.
import yaml
import yaml.constructor

try:
    # included in standard lib from Python 2.7
    from collections import OrderedDict
except ImportError:
    # try importing the backported drop-in replacement
    # it's available on PyPI
    from ordereddict import OrderedDict

class OrderedDictYAMLLoader(yaml.Loader):
    """
    A YAML loader that loads mappings into ordered dictionaries.
    """

    def __init__(self, *args, **kwargs):
        yaml.Loader.__init__(self, *args, **kwargs)
        self.add_constructor(u'tag:yaml.org,2002:map', type(self).construct_yaml_map)
        self.add_constructor(u'tag:yaml.org,2002:omap', type(self).construct_yaml_map)

    def construct_yaml_map(self, node):
        data = OrderedDict()
        yield data
        value = self.construct_mapping(node)
        data.update(value)

    def construct_mapping(self, node, deep=False):
        if isinstance(node, yaml.MappingNode):
            self.flatten_mapping(node)
        else:
            raise yaml.constructor.ConstructorError(None, None,
                'expected a mapping node, but found %s' % node.id, node.start_mark)
        mapping = OrderedDict()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            try:
                hash(key)
            except TypeError, exc:
                raise yaml.constructor.ConstructorError('while constructing a mapping',
                    node.start_mark, 'found unacceptable key (%s)' % exc, key_node.start_mark)
            value = self.construct_object(value_node, deep=deep)
            mapping[key] = value
        return mapping
Update: the library was deprecated in favor of the yamlloader (which is based on the yamlordereddictloader)
I've just found a Python library (https://pypi.python.org/pypi/yamlordereddictloader/0.1.1) which was created based on answers to this question and is quite simple to use:
import yaml
import yamlordereddictloader
datas = yaml.load(open('myfile.yml'), Loader=yamlordereddictloader.Loader)
For my PyYAML installation for Python 2.7, I updated __init__.py, constructor.py, and loader.py. It now supports the object_pairs_hook option for load commands. A diff of the changes I made is below.
__init__.py
$ diff __init__.py Original
64c64
< def load(stream, Loader=Loader, **kwds):
---
> def load(stream, Loader=Loader):
69c69
< loader = Loader(stream, **kwds)
---
> loader = Loader(stream)
75c75
< def load_all(stream, Loader=Loader, **kwds):
---
> def load_all(stream, Loader=Loader):
80c80
< loader = Loader(stream, **kwds)
---
> loader = Loader(stream)
constructor.py
$ diff constructor.py Original
20,21c20
< def __init__(self, object_pairs_hook=dict):
< self.object_pairs_hook = object_pairs_hook
---
> def __init__(self):
27,29d25
< def create_object_hook(self):
< return self.object_pairs_hook()
<
54,55c50,51
< self.constructed_objects = self.create_object_hook()
< self.recursive_objects = self.create_object_hook()
---
> self.constructed_objects = {}
> self.recursive_objects = {}
129c125
< mapping = self.create_object_hook()
---
> mapping = {}
400c396
< data = self.create_object_hook()
---
> data = {}
595c591
< dictitems = self.create_object_hook()
---
> dictitems = {}
602c598
< dictitems = value.get('dictitems', self.create_object_hook())
---
> dictitems = value.get('dictitems', {})
loader.py
$ diff loader.py Original
13c13
< def __init__(self, stream, **constructKwds):
---
> def __init__(self, stream):
18c18
< BaseConstructor.__init__(self, **constructKwds)
---
> BaseConstructor.__init__(self)
23c23
< def __init__(self, stream, **constructKwds):
---
> def __init__(self, stream):
28c28
< SafeConstructor.__init__(self, **constructKwds)
---
> SafeConstructor.__init__(self)
33c33
< def __init__(self, stream, **constructKwds):
---
> def __init__(self, stream):
38c38
< Constructor.__init__(self, **constructKwds)
---
> Constructor.__init__(self)
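With those patches in place, usage would look something like this (my sketch against the locally modified PyYAML described above):

import collections
import yaml  # the locally patched copy described in the diffs

with open('myfile.yml') as f:
    data = yaml.load(f, object_pairs_hook=collections.OrderedDict)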
Here's a simple solution that also checks for duplicate top-level keys in your map.
import re
from collections import OrderedDict

import yaml

def yaml_load_od(fname):
    "load a yaml file as an OrderedDict"
    # detect any duped keys (fail on these) and preserve order of top-level keys
    with open(fname, 'r') as f:
        lines = f.read().splitlines()
    top_keys = []
    duped_keys = []
    for line in lines:
        m = re.search(r'^([A-Za-z0-9_]+) *:', line)
        if m:
            if m.group(1) in top_keys:
                duped_keys.append(m.group(1))
            else:
                top_keys.append(m.group(1))
    if duped_keys:
        raise Exception('ERROR: duplicate keys: {}'.format(duped_keys))
    # 2nd pass to set up the OrderedDict
    with open(fname, 'r') as f:
        d_tmp = yaml.load(f)
    return OrderedDict([(key, d_tmp[key]) for key in top_keys])

Can Python be made to generate tracing similar to bash's set -x?

Is there a similar mechanism in Python, to the effect set -x has on bash?
Here's some example output from bash in this mode:
+ for src in cpfs.c log.c popcnt.c ssse3_popcount.c blkcache.c context.c types.c device.c
++ my_mktemp blkcache.c.o
+++ mktemp -t blkcache.c.o.2160.XXX
++ p=/tmp/blkcache.c.o.2160.IKA
++ test 0 -eq 0
++ echo /tmp/blkcache.c.o.2160.IKA
+ obj=/tmp/blkcache.c.o.2160.IKA
I'm aware of the Python trace module; however, its output seems extremely verbose, and not high-level like that of bash.
Perhaps use sys.settrace:
Use traceit() to turn on tracing, use traceit(False) to turn off tracing.
import sys
import linecache

def _traceit(frame, event, arg):
    '''
    http://www.dalkescientific.com/writings/diary/archive/2005/04/20/tracing_python_code.html
    '''
    if event == "line":
        lineno = frame.f_lineno
        filename = frame.f_globals["__file__"]
        if (filename.endswith(".pyc") or
            filename.endswith(".pyo")):
            filename = filename[:-1]
        name = frame.f_globals["__name__"]
        line = linecache.getline(filename, lineno)
        print "%s # %s:%s" % (line.rstrip(), name, lineno,)
    return _traceit

def _passit(frame, event, arg):
    return _passit

def traceit(on=True):
    if on: sys.settrace(_traceit)
    else: sys.settrace(_passit)

def mktemp(src):
    pass

def my_mktemp(src):
    mktemp(src)
    p = src

traceit()
for src in ('cpfs.c', 'log.c',):
    my_mktemp(src)
traceit(False)
yields
mktemp(src) # __main__:33
pass # __main__:30
p=src # __main__:34
mktemp(src) # __main__:33
pass # __main__:30
p=src # __main__:34
if on: sys.settrace(_traceit) # __main__:26
else: sys.settrace(_passit) # __main__:27
To trace specific calls, you can wrap each interesting function with your own logger. This does lead to arguments expanded to their values rather than just argument names in the output.
Functions have to be passed in as strings to prevent issues where modules redirect to other modules, like os.path / posixpath. I don't think you can extract the right module name to patch from just the function object.
Wrapping code:
import importlib

def wrapper(ffull, f):
    def logger(*args, **kwargs):
        print "TRACE: %s (%s, %s)" % (ffull, args, kwargs)
        return f(*args, **kwargs)
    return logger

def log_execution(ffull):
    parts = ffull.split('.')
    mname = '.'.join(parts[:-1])
    fname = parts[-1]
    m = importlib.import_module(mname)
    f = getattr(m, fname)
    setattr(m, fname, wrapper(ffull, f))
Usage:
for f in ['os.path.join', 'os.listdir', 'sys.exit']:
    log_execution(f)

p = os.path.join('/usr', 'bin')
os.listdir(p)
sys.exit(0)
....
% ./a.py
TRACE: os.path.join (('/usr', 'bin'), {})
TRACE: os.listdir (('/usr/bin',), {})
TRACE: sys.exit ((0,), {})
You should try to instrument the trace module to get a higher level of detail.
What do you need exactly?
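For example, the stdlib trace module can also be driven programmatically; here is a minimal sketch (my addition, not part of the original answer):

import trace

def my_mktemp(src):
    p = src  # stands in for the function being traced

# trace=True prints each line as it executes; count=False disables coverage counting
tracer = trace.Trace(trace=True, count=False)
tracer.run('my_mktemp("cpfs.c")')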
