How to mock an existing function that creates a BigQuery client - Python

I'm relatively new to pytest and unit tests in general. Specifically I'm struggling to implement mocking, but I do understand the concept. I have a Class, let's call it MyClass. It has a constructor which takes a number of arguments used by other functions within MyClass.
I have a get_tables() function that I have to test and it relies on some arguments defined in the constructor. I need to mock the BigQuery connection and return a mocked list of tables. A small snippet of the script is below
from google.cloud import bigquery
from google.cloud import storage
import logging

class MyClass:
    def __init__(self, project_id: str, pro_dataset_id: str, balancing_dataset_id: str, files: list,
                 key_file: str = None, run_local=True):
        """Creates a BigQuery client.
        Args:
            project_id: (string), name of project
            pro_dataset_id: (string), name of production dataset
            balancing_dataset_id: (string), name of balancing dataset
            files: (list), list of tables
            key_file: (string), path to the key file
        """
        if key_file is None:
            self.bq_client = bigquery.Client(project=project_id)
            self.storage_client = storage.Client(project=project_id)
            self.bucket = self.storage_client.get_bucket("{0}-my-bucket".format(project_id))
        else:
            self.bq_client = bigquery.Client.from_service_account_json(key_file)

        self.project_id = project_id
        self.run_local = run_local
        self.pro_dataset_id = pro_dataset_id
        self.balancing_dataset_id = balancing_dataset_id

    def get_tables(self):
        """Gets full list of all tables in a BigQuery dataset.

        Returns:
            List of tables from a specified dataset
        """
        full_table_list = []
        dataset_ref = '{0}.{1}'.format(self.project_id, self.pro_dataset_id)
        tables = self.bq_client.list_tables(dataset_ref)  # Make an API request.
        logging.info(f"Tables contained in '{dataset_ref}':")
        for table in tables:
            full_table_list.append(table.table_id)
        logging.info(f"tables: {full_table_list}")
        return full_table_list
This was my attempt at mocking the connection and response, based on an amalgamation of articles I've read and Stack Overflow answers to other questions, including this one: How can I mock requests and the response?
import pytest
from unittest import mock
from MyPackage import MyClass

class TestMyClass:
    def mocked_list_tables():
        tables = ['mock_table1', 'mock_table2', 'mock_table3', 'mock_table4']
        return tables

    @mock.patch('MyPackage.MyClass', side_effect=mocked_list_tables)
    def test_get_tables(self):
        m = MyClass()
        assert m.get_tables() == ['mock_table1', 'mock_table2', 'mock_table3', 'mock_table4']
This is the error I get with the above unit test
TypeError: test_get_tables() takes 1 positional argument but 2 were given
How do I get this test to work correctly? In case you're wondering, the values passed to the constructor are declared using argparse.ArgumentParser().add_argument().
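One way this could work (a minimal sketch, assuming MyClass lives in a module named MyPackage that does from google.cloud import bigquery / storage) is to patch the client classes where they are looked up, so the constructor never touches the network, and stub list_tables to return objects that expose a table_id attribute, since get_tables reads table.table_id:

from unittest import mock
from MyPackage import MyClass

class TestMyClass:
    # Patch the clients where MyClass looks them up, not where they are defined.
    @mock.patch('MyPackage.storage.Client')
    @mock.patch('MyPackage.bigquery.Client')
    def test_get_tables(self, mock_bq_client, mock_storage_client):
        # list_tables() should yield objects with a .table_id attribute,
        # mirroring the shape of the real API response.
        fake_tables = [mock.Mock(table_id=name) for name in
                       ('mock_table1', 'mock_table2', 'mock_table3', 'mock_table4')]
        mock_bq_client.return_value.list_tables.return_value = fake_tables

        # The constructor arguments here are placeholders.
        m = MyClass(project_id='test-project', pro_dataset_id='pro',
                    balancing_dataset_id='balancing', files=[])
        assert m.get_tables() == ['mock_table1', 'mock_table2',
                                  'mock_table3', 'mock_table4']

Note that with stacked mock.patch decorators the mocks are passed bottom-up, so the bigquery patch arrives as the first argument.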

Related

How to mock list of response objects from boto3?

I'd like to get all archives from a specific directory in an S3 bucket, like the following:
import boto3

def get_files_from_s3(bucket_name, s3_prefix):
    files = []
    s3_resource = boto3.resource("s3")
    bucket = s3_resource.Bucket(bucket_name)
    response = bucket.objects.filter(Prefix=s3_prefix)
    for obj in response:
        if obj.key.endswith('.zip'):
            # get all archives
            files.append(obj.key)
    return files
My question is about testing it: I'd like to mock the list of objects in the response so that I can iterate over it. Here is what I tried:
from unittest.mock import patch
from dataclasses import dataclass

@dataclass
class MockZip:
    key = 'file.zip'

@patch('module.boto3')
def test_get_files_from_s3(self, mock_boto3):
    bucket = mock_boto3.resource('s3').Bucket(self.bucket_name)
    response = bucket.objects.filter(Prefix=S3_PREFIX)
    response.return_value = [MockZip()]
    files = module.get_files_from_s3(BUCKET_NAME, S3_PREFIX)
    self.assertEqual(['file.zip'], files)
I get an assertion error like this: E AssertionError: ['file.zip'] != []
Does anyone have a better approach? I used a dataclass, but I don't think that is the problem; I suspect I get an empty list because the mocked response is not iterable. So how can I mock it to be a list of mock objects instead of just a MagicMock?
Thanks
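For reference, a minimal sketch of how the patch-based approach could be made to iterate, assuming the function lives in a module named module and calls boto3.resource('s3') itself: configure return_value on the chained call that get_files_from_s3 actually makes, rather than on a mock built locally in the test.

from unittest.mock import patch
from dataclasses import dataclass

import module  # placeholder name for the module that defines get_files_from_s3

@dataclass
class MockZip:
    key: str = 'file.zip'

@patch('module.boto3')
def test_get_files_from_s3(mock_boto3):
    # Mirror the exact chain used inside get_files_from_s3:
    # boto3.resource('s3').Bucket(name).objects.filter(Prefix=prefix)
    filter_mock = mock_boto3.resource.return_value.Bucket.return_value.objects.filter
    filter_mock.return_value = [MockZip()]

    files = module.get_files_from_s3('some-bucket', 'some/prefix')
    assert files == ['file.zip']

The key point is that the response built in the original test is a brand-new mock unrelated to the one the production code receives; configuring the attribute chain on the patched module-level mock avoids that.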
You could use moto, an open-source library specifically built to mock boto3 calls. It allows you to work directly with boto3, without having to worry about setting up mocks manually.
The test function that you're currently using would look like this:
import os

import boto3
import pytest
from moto import mock_s3

@pytest.fixture(scope='function')
def aws_credentials():
    """Mocked AWS Credentials, to ensure we're not touching AWS directly"""
    os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
    os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
    os.environ['AWS_SECURITY_TOKEN'] = 'testing'
    os.environ['AWS_SESSION_TOKEN'] = 'testing'

@mock_s3
def test_get_files_from_s3(self, aws_credentials):
    s3 = boto3.resource('s3')
    bucket = s3.Bucket(self.bucket_name)
    # Create the bucket first, as we're interacting with an empty mocked 'AWS account'
    bucket.create()
    # Create some example files that are representative of what the S3 bucket would look like in production
    client = boto3.client('s3', region_name='us-east-1')
    client.put_object(Bucket=self.bucket_name, Key="file.zip", Body="...")
    client.put_object(Bucket=self.bucket_name, Key="file.nonzip", Body="...")
    # Retrieve the files again using whatever logic
    files = module.get_files_from_s3(BUCKET_NAME, S3_PREFIX)
    self.assertEqual(['file.zip'], files)
Full documentation for Moto can be found here:
http://docs.getmoto.org/en/latest/index.html
Disclaimer: I am a maintainer for Moto.

Import a JSON project-wide, so it loads just once

I have a Python project that performs a JSON validation against a specific schema.
It will run as a Transform step in GCP Dataflow, so it's very important that all dependencies are gathered before the run to avoid downloading the same file again and again.
The schema is placed in a separated Git repository.
The nature of the Transformer is that you receive a single record in your class and you work with it. The typical flow is that you load the JSON schema, validate the record against it, and then do things with the invalid and the valid records. Loading the schema this way means that I download it from the repo for every record, and there could be hundreds of thousands of them.
The code gets "cloned" onto the workers, which then work more or less independently.
Inspired by the way Python loads the requirements at the beginning (one single time) and uses them as imports, I thought I could add the repository (where the JSON schema lives) as a Python requirement and then simply use it in my Python code. But of course it's a JSON file, not a Python module to be imported. How can it work?
An example would be something like:
requirements.txt
git+git://github.com/path/to/json/schema#41b95ec
dataflow_transformer.py
import apache_beam as beam
import the_downloaded_schema
from jsonschema import validate
class Verifier(beam.DoFn):
    def process(self, record: dict):
        validate(instance=record, schema=the_downloaded_schema)
        # ... more stuff
        yield record

class Transformer(beam.PTransform):
    def expand(self, record):
        return (
            record
            | "Verify Schema" >> beam.ParDo(Verifier())
        )
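For what it's worth, one way the "install it as a requirement" idea can work is to package the schema file as package data in that repository and read it once at import time with the standard library. A rough sketch, where schema_repo and my_schema.json are placeholder names for however the schema repository is actually packaged:

import json
from importlib import resources  # stdlib, Python 3.7+

# Runs once per worker, at module import time, instead of once per record.
# 'schema_repo' and 'my_schema.json' are assumptions about how the schema
# repository is packaged; adjust to the real package and file names.
the_downloaded_schema = json.loads(
    resources.read_text('schema_repo', 'my_schema.json')
)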
You can load the json schema once and use it as a side input.
An example:
import json

import apache_beam as beam
import requests

json_current = 'https://covidtracking.com/api/v1/states/current.json'

def get_json_schema(url):
    with requests.Session() as session:
        schema = json.loads(session.get(url).text)
    return schema

schema_json = get_json_schema(json_current)

def feed_schema(data, schema):
    yield {'record': data, 'schema': schema[0]}

# p is an existing Pipeline object
schema = p | beam.Create([schema_json])
data = p | beam.Create(range(10))
data_with_schema = data | beam.FlatMap(feed_schema, schema=beam.pvalue.AsSingleton(schema))
# Now do your schema validation
Just a demonstration of what the data_with_schema PCollection looks like: each element pairs a record with the schema.
Why don't you just use a class for loading your resources that uses a cache in order to prevent double loading? Something along the lines of:
import os

class JsonLoader:
    def __init__(self):
        self.cache = set()

    def load(self, filename):  # 'import' is a reserved word, so the method is named load here
        filename = os.path.abspath(filename)
        if filename not in self.cache:
            self._load_json(filename)
            self.cache.add(filename)

    def _load_json(self, filename):
        ...
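A slightly fuller sketch of the same caching idea, returning the parsed schema so callers (for example the Verifier DoFn above) only hit the disk on the first call; the file path used by callers is a placeholder:

import json
import os

class CachingJsonLoader:
    """Parses each JSON file at most once and returns the cached result afterwards."""

    def __init__(self):
        self._cache = {}

    def load(self, filename):
        filename = os.path.abspath(filename)
        if filename not in self._cache:
            with open(filename) as f:
                self._cache[filename] = json.load(f)
        return self._cache[filename]

A single loader instance can then be created at module level (or in the DoFn's setup) so every record processed by a worker reuses the same parsed schema.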

Error executing S3Hook list_keys or read_key methods

I get this error message:
{logging_mixin.py:112} INFO - [2020-03-22 12:34:53,672] {local_task_job.py:103} INFO - Task exited with return code -6
when I use the list_keys or read_key methods of the S3Hook. The get_credentials method works fine, though. I have searched around and can't find why this occurs.
I'm using apache-airflow==1.10.9, boto3==1.12.21, botocore==1.15.21
Here's my code for my custom operator that makes use of the S3Hook:
from airflow.hooks.base_hook import BaseHook
from airflow.hooks.S3_hook import S3Hook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults

class SASValueToRedshiftOperator(BaseOperator):
    """Custom Operator for extracting data from SAS source code.

    Attributes:
        ui_color (str): color code for task in Airflow UI.
    """
    ui_color = '#358150'

    @apply_defaults
    def __init__(self,
                 aws_credentials_id="",
                 redshift_conn_id="",
                 table="",
                 s3_bucket="",
                 s3_key="",
                 sas_value="",
                 columns="",
                 *args, **kwargs):
        """Extracts label mappings from SAS source code and stores them as a Redshift table.

        Args:
            aws_credentials_id (str): Airflow connection ID for AWS key and secret.
            redshift_conn_id (str): Airflow connection ID for redshift database.
            table (str): Name of table to load data to.
            s3_bucket (str): S3 bucket name where the SAS source code is stored.
            s3_key (str): S3 key name for the SAS source code.
            sas_value (str): value to search for in sas file for extraction of data.
            columns (list): resulting data column names.

        Returns:
            None
        """
        super(SASValueToRedshiftOperator, self).__init__(*args, **kwargs)
        self.aws_credentials_id = aws_credentials_id
        self.redshift_conn_id = redshift_conn_id
        self.table = table
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.sas_value = sas_value
        self.columns = columns

    def execute(self, context):
        """Executes task for staging to redshift.

        Args:
            context (:obj:`dict`): Dict with values to apply on content.

        Returns:
            None
        """
        s3 = S3Hook(self.aws_credentials_id)
        redshift_conn = BaseHook.get_connection(self.redshift_conn_id)
        self.log.info(s3)
        self.log.info(s3.get_credentials())
        self.log.info(s3.list_keys(self.s3_bucket))
s3 = S3Hook(self.aws_credentials_id)
s3.list_keys(bucket_name=s3_bucket, prefix=s3_path, delimiter=delimiter)
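Spelled out with the operator's attributes from the question, that call might look like the following inside execute (a sketch; the delimiter value is a placeholder):

s3 = S3Hook(self.aws_credentials_id)
# bucket name and prefix map onto the operator's own attributes
keys = s3.list_keys(bucket_name=self.s3_bucket, prefix=self.s3_key, delimiter='/')
self.log.info(keys)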

Python mock: AssertionError: Expected and actual call not same

I am new to the unittest.mock library and unable to solve the issue I am experiencing.
I have a module called function.py in the folder structure below:
src
    __init__.py
    function.py
tests
    __init__.py
    test_function.py
In test_function.py I have some code like this:
import unittest
from unittest import mock

from boto3.dynamodb import conditions

from ..src.function import get_subscriptions
from ..src import function

class TestCheckOrder(unittest.TestCase):
    @mock.patch.object(function, 'table')
    def test_get_subscriptions_success(self, mocked_table):
        mocked_table.query.return_value = {'Items': []}
        user_id = "test_user"
        status = True
        get_subscriptions(user_id, status)
        mocked_table.query.assert_called_with(
            KeyConditionExpression=conditions.Key('user_id').eq(user_id),
            FilterExpression=conditions.Attr('status').eq(int(status)))
In function.py:
import boto3
from boto3.dynamodb import conditions

dynamodb = boto3.resource("dynamodb")
table = dynamodb.Table("Subscriptions")

def get_subscriptions(user_id, active=True):
    results = table.query(
        KeyConditionExpression=conditions.Key('user_id').eq(user_id),
        FilterExpression=conditions.Attr('status').eq(int(active))
    )
    return results['Items']
If I run this I get the following exception:
AssertionError: Expected call: query(FilterExpression=<boto3.dynamodb.conditions.Equals object at 0x1116011d0>, KeyConditionExpression=<boto3.dynamodb.conditions.Equals object at 0x111601160>)
Actual call: query(FilterExpression=<boto3.dynamodb.conditions.Equals object at 0x1116010f0>, KeyConditionExpression=<boto3.dynamodb.conditions.Equals object at 0x111601080>)
Thanks in advance for helping me out.
The issue is that when you call assert_called_with in your test, you create new instances of conditions.Key and conditions.Attr. As these instances are different from the ones used in the actual call, there's a mismatch (check the hex IDs shown in the traceback).
Instead of this you can fetch the kwargs from the function call itself and test their properties:
name, args, kwargs = mocked_table.query.mock_calls[0]
assert kwargs['KeyConditionExpression'].get_expression()['values'][1] == user_id
assert kwargs['FilterExpression'].get_expression()['values'][1] == int(status)
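On Python 3.8+ the same kwargs can also be pulled from call_args, which reads slightly more directly (a minor variation on the snippet above):

kwargs = mocked_table.query.call_args.kwargs  # .kwargs on call objects requires Python 3.8+
assert kwargs['KeyConditionExpression'].get_expression()['values'][1] == user_id
assert kwargs['FilterExpression'].get_expression()['values'][1] == int(status)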

mock S3 connection and boto S3 Key to check the set_contents_from_string method

I am doing unit testing with python mock. I've gone through blogs and the Python docs related to mocking, but I'm confused about how to mock this test case.
Here is the snippet for which I want to write a test case.
The agenda is to test the method set_contents_from_string() using mock.
def write_to_customer_registry(customer):
    # establish a connection with S3
    conn = _connect_to_s3()
    # build customer registry dict and convert it to json
    customer_registry_dict = json.dumps(build_customer_registry_dict(customer))
    # attempt to access requested bucket
    bucket = _get_customer_bucket(conn)
    s3_key = _get_customer_key(bucket, customer)
    s3_key.set_metadata('Content-Type', 'application/json')
    s3_key.set_contents_from_string(customer_registry_dict)
    return s3_key
As you are testing some private methods I have added them to a module which I called s3.py that contains your code:
import json

def _connect_to_s3():
    raise

def _get_customer_bucket(conn):
    raise

def _get_customer_key(bucket, customer):
    raise

def build_customer_registry_dict(cust):
    raise

def write_to_customer_registry(customer):
    # establish a connection with S3
    conn = _connect_to_s3()
    # build customer registry dict and convert it to json
    customer_registry_dict = json.dumps(build_customer_registry_dict(customer))
    # attempt to access requested bucket
    bucket = _get_customer_bucket(conn)
    s3_key = _get_customer_key(bucket, customer)
    s3_key.set_metadata('Content-Type', 'application/json')
    s3_key.set_contents_from_string(customer_registry_dict)
    return s3_key
Next, in another module test_s3.py, I tested your code taking into account that for Unit Tests all interactions with third parties, such as network calls to s3 should be patched:
from unittest.mock import MagicMock, Mock, patch

from s3 import write_to_customer_registry
import json

@patch('json.dumps', return_value={})
@patch('s3._get_customer_key')
@patch('s3.build_customer_registry_dict')
@patch('s3._get_customer_bucket')
@patch('s3._connect_to_s3')
def test_write_to_customer_registry(connect_mock, get_bucket_mock, build_customer_registry_dict_mock, get_customer_key_mock, json_mock):
    customer = MagicMock()
    connect_mock.return_value = 'connection'
    get_bucket_mock.return_value = 'bucket'
    get_customer_key_mock.return_value = MagicMock()

    write_to_customer_registry(customer)

    assert connect_mock.call_count == 1
    assert get_bucket_mock.call_count == 1
    assert get_customer_key_mock.call_count == 1
    get_bucket_mock.assert_called_with('connection')
    get_customer_key_mock.assert_called_with('bucket', customer)
    get_customer_key_mock.return_value.set_metadata.assert_called_with('Content-Type', 'application/json')
    get_customer_key_mock.return_value.set_contents_from_string.assert_called_with({})
As you can see from the tests, I am not testing that set_contents_from_string does what it is supposed to do (that should already be tested by the boto library), but that it is being called with the proper arguments.
If you still doubt that the boto library properly tests that call, you can always check it yourself in the boto GitHub or boto3 GitHub repositories.
Something else you could test is that you are handling the different exceptions and edge cases in your code properly.
Finally, you can find more about patching and mocking in the docs. Usually the section about where to patch is really useful.
Some other resources are this blog post with python mock gotchas, or this blog post I wrote myself (shameless self plug) after answering related pytest, patching, and mocking questions on Stack Overflow.
I came up with a solution that worked for me. Posting it here in case it is helpful for someone else.
def setup(self):
    self.customer = Customer.objects.create('tiertranstests')
    self.customer.save()

def test_build_customer_registry(self):
    mock_connection = Mock()
    mock_bucket = Mock()
    mock_s3_key = Mock()
    customer_registry_dict = json.dumps(build_customer_registry_dict(self.customer))
    # Patch S3 connection and Key class of registry method
    with patch('<path>.customer_registry.S3Connection', Mock(return_value=mock_connection)),\
         patch('<path>.customer_registry.Key', Mock(return_value=mock_s3_key)):
        mock_connection.get_bucket = Mock(return_value=mock_bucket)
        mock_s3_key.set_metadata.return_value = None
        mock_s3_key.set_contents_from_string = Mock(return_value=customer_registry_dict)
        write_to_customer_registry(self.customer)
        mock_s3_key.set_contents_from_string.assert_called_once_with(customer_registry_dict)
