I'm trying to write a unit test for the some_func function in some_func.py below. I do not want to connect to any database during these tests, and I want to mock out any calls to the DB.
Since the actual DB calls are somewhat nested, I'm not able to get this to work. How do I patch any DB interaction in this case?
db_module.py
import MySQLdb
from random_module.config.config_loader import config
from random_module.passwd_util import get_creds
class MySQLConn:
    """Holds one MySQL connection that is opened when the object is created.

    Connection settings are read from ``config`` and the credentials are
    resolved through ``get_creds``.
    """

    _conn = None

    def __init__(self):
        self._conn = self._get_rds_connection()

    def get_conn(self):
        """Return the connection object opened in ``__init__``."""
        return self._conn

    def _get_rds_connection(self):
        """
        Returns a conn object after establishing connection with the MySQL database

        :return: obj: conn
        :raises ConnectionAbortedError: if the credential lookup or the
            connect call fails
        """
        try:
            logger.info("Establishing connection with MySQL")
            username, password = get_creds(config['mysql_dev_creds'])
            connection = MySQLdb.connect(
                host=config['host'],
                user=username,
                passwd=password,
                db=config['db_name'],
                port=int(config['db_port']))
            # autocommit is a method on MySQLdb connections; assigning to it
            # only shadows the method and never changes the session setting.
            connection.autocommit(True)
        except Exception as err:
            logger.error("Unable to establish connection to MySQL")
            # Chain the cause so the original traceback is preserved.
            raise ConnectionAbortedError(err) from err
        if not connection:
            return None
        logger.info("Connection to MySQL successful")
        return connection
db_mysql = MySQLConn()
some_func.py
from random_module.utils.db_module import db_mysql
def some_func():
    """Return the ``x`` and ``y`` values of every row in ``some_table``.

    The values come back as one flat list: ``[x1, y1, x2, y2, ...]``.
    If anything fails, the error is logged and the values collected before
    the failure (possibly an empty list) are returned.
    """
    # Bound before the try block so the final return cannot hit an
    # unassigned name when the query fails.
    result_set = []
    try:
        db_cursor = db_mysql.get_conn().cursor()
        db_cursor.execute("SELECT * FROM some_table")
        for row in db_cursor.fetchall():
            # Append in fetch order; inserting at an index would need a
            # counter that this function never defines.
            result_set.extend((row['x'], row['y']))
    except Exception:
        logger.error("some error")
    return result_set
dir struct -
src
├──pkg
| ├── common
| |___ some_func.py
| |___ __init__.py
|
| ├── utils
| |___ db_module.py
| |___ __init__.py
|
| __init__.py
You need to mock this line: db_conn = db_mysql.get_conn()
The return value of the get_conn method is what you're interested in.
from random_module.utils.db_module import db_mysql
@mock.patch.object(db_mysql, 'get_conn')
def test_some_func(self, mock_get):
    """Run some_func against a mocked connection and cursor."""
    # get_conn() -> connection mock -> cursor() -> cursor mock
    mock_conn = mock.MagicMock()
    mock_get.return_value = mock_conn
    mock_cursor = mock.MagicMock()
    mock_conn.cursor.return_value = mock_cursor
    expect = ...
    result = some_func()
    self.assertEqual(expect, result)
    self.assertTrue(mock_cursor.execute.called)
As you can see there is a lot of complexity setting up these mocks. That is because you are instantiating objects inside of your function. A better approach would be to refactor your code to inject the cursor because the cursor is the only thing relevant to this function. An even better approach would be to create a database fixture to test the function actually interacts with the database correctly.
You should mock db_mysql in some_module.py and then assert that the expected calls were made after some_func() executes.
from unittest import TestCase
from unittest.mock import patch

from some_module import some_func


class SomeFuncTest(TestCase):
    """Checks some_func against a patched ``some_module.db_mysql``."""

    @patch('some_module.db_mysql')
    def test_some_func(self, mock_db):
        # some_func obtains the cursor from the connection, so the cursor
        # mock sits one level below get_conn().
        mock_conn = mock_db.get_conn.return_value
        mock_cursor = mock_conn.cursor.return_value
        # Rows must be configured before the call; setting them afterwards
        # cannot influence what some_func already returned.
        mock_cursor.fetchall.return_value = [
            {'x': 'foo1', 'y': 'bar1'},
            {'x': 'foo2', 'y': 'bar2'}
        ]

        result_set = some_func()

        mock_db.get_conn.assert_called_once_with()
        mock_conn.cursor.assert_called_once_with()
        mock_cursor.execute.assert_called_once_with("SELECT * FROM some_table")
        mock_cursor.fetchall.assert_called_once_with()
        self.assertEqual(result_set, ['foo1', 'bar1', 'foo2', 'bar2'])
Related
In my Flask app, I'm using Dependency Injection and here's what my app looks like. I have a service which uses S3 as a datastore and I'm trying to instantiate my app with the service injected (which is injected with the S3 client). However, it doesn't look like the S3 client is correctly instantiated or I'm doing something wildly different.
containers.py
class Container(containers.DeclarativeContainer):
    """Dependency-injection container that wires the S3 repository into MyService."""

    wiring_config = containers.WiringConfiguration(modules=[".routes", ".scheduler"])
    config = providers.Configuration()

    # NOTE(review): this dict is built while the class body executes, so
    # config.get("s3_bucket") is evaluated at that moment -- confirm the
    # configuration is already populated by then.
    s3_config = dict()
    s3_config["path"], s3_config["filters"] = config.get("s3_bucket"), [("member_status_nm", "=", "ACTIVE")]

    s3_repository = providers.Singleton(S3Repository, s3_config)

    # Inject the provider, not the S3Repository class: passing the class hands
    # MyService an uninstantiated type, so s3_repository.fetch() has no self.
    my_service = providers.Factory(
        MyService, config, s3_repository
    )
Here's my S3Repository:
import logging
import sys
import time
import some_library as lib
class S3Repository:
    """Reads data from the location described by an S3 configuration mapping."""

    def __init__(self, s3_config):
        # Missing keys fall back to an empty path / empty lists.
        settings = (
            s3_config.get("path", ""),
            s3_config.get("columns", []),
            s3_config.get("filters", []),
        )
        self.path, self.columns, self.filters = settings

    def fetch(self):
        """Return what ``lib.some_fetch_method`` yields for this path, columns and filters."""
        # execute fetch
        return lib.some_fetch_method(self.path, self.columns, self.filters)
and MyService:
import #all relevant imports here
class MyService:
    """Service that reads data through an injected S3 repository."""

    def __init__(self, config: dict, s3_repository: "S3Repository") -> None:
        logging.info("HealthSignalService(): initializing")
        self.config = config["app"]["health_signal_service"]
        # prepare s3_repository for the service
        self.s3_repository = s3_repository
        self.s3_repository.columns, self.s3_repository.filters, self.s3_repository.path = \
            ["x","y"], ["x1","y1"], "file_path"

    def fetch_data(self):
        """Return the repository's fetch result, or None if the read fails."""
        # Pre-bound so a failed fetch returns None instead of raising
        # UnboundLocalError at the return statement.
        summary_result = None
        try:
            summary_result = self.s3_repository.fetch()
        except (FileNotFoundError, IOError):
            print("failure")
        return summary_result

    def get_data(self, memberId):
        """Return the entry for ``memberId``, or None when no data was fetched."""
        sth = self.fetch_data()
        if sth is None:
            return None
        return sth.get(memberId)
and finally tying it together in my routes.py:
@inject
@auth.login_required
def get_signals(
    my_service: MyService = Provide[
        Container.my_service
    ],
):
    """Return, as JSON, the signals for the ``memberId`` in the request body."""
    content = request.json
    member_id = content["memberId"]
    result = my_service.get_signals(member_id)
    return jsonify(result)
When I hit my API endpoint I see this error:
summary_result = self.s3_repository.fetch()
TypeError: fetch() missing 1 required positional argument: 'self'
How do I correctly initialize my S3 client while using dependency injection?
I am trying to write a pytest test for the following method by mocking the database. How can I mock the database connection without actually connecting to the real database server? I tried with a sample test case, but I am not sure if that is the right way to do it. Please correct me if I am wrong.
//fetch.py
import pymssql
def cur_fetch(query):
    """Run ``query`` on a new pymssql connection and return all fetched rows.

    The cursor is opened with ``as_dict=True``.
    """
    with pymssql.connect('host', 'username', 'password') as conn, \
            conn.cursor(as_dict=True) as cursor:
        cursor.execute(query)
        rows = cursor.fetchall()
    return rows
//test_fetch.py
import mock
from unittest.mock import MagicMock, patch
from .fetch import cur_fetch
def test_cur_fetch():
    """cur_fetch returns the rows produced by the mocked cursor."""
    with patch('fetch.pymssql', autospec=True) as mock_pymssql:
        mock_cursor = MagicMock()
        test_data = [{'password': 'secret', 'id': 1}]
        # fetchall() is called once, so its return value is the data itself,
        # not another callable mock.
        mock_cursor.fetchall.return_value = test_data
        # connect() and cursor() are both used as context managers, so each
        # level needs its own __enter__ hop.
        mock_conn = mock_pymssql.connect.return_value.__enter__.return_value
        mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
        x = cur_fetch({})
        assert x == test_data
The result is:
AssertionError: assert <MagicMock name='pymssql.connect().__enter__().cursor().__enter__().fetchall()' id='2250972950288'> == None
Please help.
Trying to mock out a module is difficult. Mocking a method call is simple. Rewrite your test like this:
import unittest.mock as mock
import fetch
def test_cur_tech():
    """cur_fetch forwards the query to the cursor and returns its rows."""
    with mock.patch('fetch.pymssql.connect') as mock_connect:
        # Build the chain bottom-up: cursor first, then the connection that
        # hands it out, then the connect() call that hands out the connection.
        cursor_mock = mock.MagicMock()
        cursor_mock.fetchall.return_value = [{}]
        conn_mock = mock.MagicMock()
        conn_mock.cursor.return_value.__enter__.return_value = cursor_mock
        mock_connect.return_value.__enter__.return_value = conn_mock
        res = fetch.cur_fetch('select * from table')
        first_positional = cursor_mock.execute.call_args.args[0]
        assert first_positional == 'select * from table'
        assert res == [{}]
In the above code we're explicitly mocking pymssql.connect, and
providing appropriate fake context managers to make the code in
cur_fetch happy.
You could simplify things a bit if cur_fetch received the connection
as an argument, rather than calling pymssql.connect itself.
I am currently trying to write unit tests for my Python code using Moto and @mock_dynamodb2. So far it's been working for my "successful operation" test cases, but I'm having trouble getting it to work for my "failure cases".
In my test code I have:
@mock_dynamodb2
class TestClassUnderTestExample(unittest.TestCase):
    def setUp(self):
        ddb = boto3.resource("dynamodb", "us-east-1")
        self.table = ddb.create_table(<the table definition>)
        self.example_under_test = ClassUnderTestExample(ddb)

    def test_some_thing_success(self):
        expected_response = {<some value>}
        assert expected_response == self.example_under_test.write_entry(<some value>)

    def test_some_thing_failure(self):
        response = self.example_under_test.write_entry(<some value>)
        # How to assert exception is thrown by forcing put item to fail?
The ClassUnderTestExample would look something like this:
class ClassUnderTestExample:
    def __init__(self, ddb_resource=None):
        if not ddb_resource:
            ddb_resource = boto3.resource('dynamodb')
        self.table = ddb_resource.Table(.....)

    def write_entry(self, some_value):
        ddb_item = <do stuff with some_value to create sanitized item>
        response = self.table.put_item(
            Item=ddb_item
        )
        if pydash.get(response, "ResponseMetadata.HTTPStatusCode") != 200:
            raise SomeCustomErrorType("Unexpected response from DynamoDB when attempting to PutItem")
        return ddb_item
I've been completely stuck when it comes to actually mocking the .put_item operation to return a non-success value so that I can test that the ClassUnderTestExample will handle it as expected and throw the custom error. I've tried things like deleting the table before running the test, but that just throws an exception when getting the table rather than an executed PutItem with an error code.
I've also tried putting a patch for pydash or for the table above the test but I must be doing something wrong. I can't find anything in moto's documentation. Any help would be appreciated!
The goal of Moto is to completely mimic AWS' behaviour, including how to behave when the user supplies erroneous inputs. In other words, a call to put_item() that fails against AWS would/should also fail against Moto.
There is no built-in way to force an error response on a valid input.
It's difficult to tell from your example how this can be forced, but it looks like it's worth playing around with this line to create an invalid input:
ddb_item = <do stuff with some_value to create sanitized item>
Yes, you can. Use mocking for this. Simple and runnable example:
from unittest import TestCase
from unittest.mock import Mock
from uuid import uuid4
import boto3
from moto import mock_dynamodb2
def create_user_table(table_name: str) -> dict:
    """Build the table-definition keyword arguments for a table keyed by a string ``id``."""
    key_schema = [{'AttributeName': 'id', 'KeyType': 'HASH'}]
    attribute_definitions = [{'AttributeName': 'id', 'AttributeType': 'S'}]
    return {
        'TableName': table_name,
        'KeySchema': key_schema,
        'AttributeDefinitions': attribute_definitions,
        'BillingMode': 'PAY_PER_REQUEST',
    }
class UserRepository:
    """Stores users in the DynamoDB table named by ``table_name``."""

    table_name = 'users'

    def __init__(self, ddb_resource=None):
        # Fall back to a default boto3 resource when none is supplied.
        if not ddb_resource:
            ddb_resource = boto3.resource('dynamodb')
        self.table = ddb_resource.Table(self.table_name)

    def create_user(self, username):
        """Put a new user item with a freshly generated id and return put_item's result."""
        # uuid4 must be called: str(uuid4) is the repr of the function itself,
        # so every item written by this process would share one "id".
        return self.table.put_item(Item={'id': str(uuid4()), 'username': username})
@mock_dynamodb2
class TestUserRepository(TestCase):
    """Exercises UserRepository against a moto-backed DynamoDB table."""

    def setUp(self):
        ddb = boto3.resource("dynamodb", "us-east-1")
        self.table = ddb.create_table(**create_user_table('users'))
        self.test_user_repo = UserRepository(ddb)

    def tearDown(self):
        self.table.delete()

    def test_some_thing_success(self):
        self.test_user_repo.create_user(username='John')
        assert len(self.table.scan()['Items']) == 1

    def test_some_thing_failure(self):
        # Swap the table for a Mock so put_item can be forced to fail.
        self.test_user_repo.table = table = Mock()
        table.put_item.side_effect = Exception('Boto3 Exception')
        with self.assertRaises(Exception) as exc:
            self.test_user_repo.create_user(username='John')
        # An exception object does not support ``in``; compare against its text.
        self.assertIn('Boto3 Exception', str(exc.exception))
I am new to pytest and wanted to add the below 3 methods for unit test coverage without actually using a real mongo db instance but rather mock it.
Could try using a real db instance but it isn't recommended.
Request for an example on how to mock mongodb client and get a document
import os
import logging
import urllib.parse
from dotenv import load_dotenv
from pymongo import MongoClient
from logger import *
load_dotenv()
def getMongoConnection():
    """Create a MongoClient from connection settings held in environment variables.

    Reads USER_NAME, PASSWORD, HOST_NAME1, HOST_NAME2, HOST_NAME3,
    AUTH_SOURCE_DATABASE and REPLICA_SET.

    :return: the MongoClient, or False if building the URI or creating the
             client raised
    """
    userName = urllib.parse.quote_plus(os.getenv("USER_NAME"))
    password = urllib.parse.quote_plus(os.getenv("PASSWORD"))
    hostName1_port = os.getenv("HOST_NAME1")
    hostName2_port = os.getenv("HOST_NAME2")
    hostName3_port = os.getenv("HOST_NAME3")
    authSourceDatabase = os.getenv("AUTH_SOURCE_DATABASE")
    replicaSet = os.getenv("REPLICA_SET")
    connectTimeoutMS = "1000"
    socketTimeoutMS = "30000"
    maxPoolSize = "100"
    try:
        hosts = hostName1_port + ',' + hostName2_port + ',' + hostName3_port
        options = ('ssl=true&replicaSet=' + replicaSet +
                   '&authSource=' + authSourceDatabase +
                   '&retryWrites=true&w=majority' +
                   '&connectTimeoutMS=' + connectTimeoutMS +
                   '&socketTimeoutMS=' + socketTimeoutMS +
                   '&maxPoolSize=' + maxPoolSize)
        # Credentials are separated from the host list by '@'; a '#' in that
        # position would not delimit the user-info part of the URI.
        uri = ('mongodb://' + userName + ':' + password + '@' + hosts +
               '/' + authSourceDatabase + '?' + options)
        return MongoClient(uri)
    except Exception:
        # logging.exception keeps the same message but records the traceback.
        logging.exception("Error while connecting to mongoDB.")
        return False
def connectToDBCollection(client, databaseName, collectionName):
    """Look up ``collectionName`` inside ``databaseName`` on ``client`` by subscripting."""
    return client[databaseName][collectionName]
def getDoc(bucketName, databaseName, collectionName):
    """Fetch one document whose ``bucket`` field equals ``bucketName``.

    :return: the result of ``find_one``, or None when the connection could
             not be created or an exception was logged
    """
    try:
        client = getMongoConnection()
        # getMongoConnection signals failure with the literal False.
        if client is not False:
            collection = connectToDBCollection(
                client, databaseName, collectionName)
            return collection.find_one({'bucket': bucketName})
    except Exception as e:
        # The %s placeholder is required: passing an argument without one
        # makes logging fail to format the record.
        logging.error("An exception occurred while fetching doc, error is %s", e)
    return None
Edit : (Tried using below code and was able to cover most of the cases but seeing an error)
def test_mongo():
db_conn = mongomock.MongoClient()
assert isinstance(getMongoConnection(), MongoClient)
def test_connect_mongo():
return connectToDBCollection(mongomock.MongoClient(), "sampleDB", "sampleCollection")
//trying to cover exception block for getMongoConnection()
def test_exception():
with pytest.raises(Exception) as excinfo:
getMongoConnection()
assert str(excinfo.value) == False
def test_getDoc():
collection = mongomock.MongoClient().db.collection
stored_obj = collection.find_one({'_id': 1})
assert stored_obj == getDoc("bucket", "db", "collection")
def test_createDoc():
collection = mongomock.MongoClient().db.collection
stored_obj = collection.insert_one({'_id': 1})
assert stored_obj == createDoc("bucket", "db", "collection")
def test_updateDoc():
collection = mongomock.MongoClient().db.collection
stored_obj = collection.replace_one({'_id': 1}, {'_id': 2})
assert stored_obj == updateDoc(
{'_id': 1}, {'$set': {'_id': 2}}, "db", "collection")
Errors :
test_exception - Failed: DID NOT RAISE <class 'Exception'>
test_createDoc - TypeError: not all arguments converted during string formatting
AssertionError: assert <pymongo.results.UpdateResult object at 0x7fc0e835a400> == <pymongo.results.UpdateResult object at 0x7fc0e8211900>
Looks like MongoClient is a nested dict with databaseName and collectionName or implemented with a key accessor.
You could mock the client first with
# "import unittest" alone does not load the mock submodule, so import it
# explicitly before using unittest.mock.
import unittest.mock

mocked_collection = unittest.mock.MagicMock()
# mock the find_one method
mocked_collection.find_one.return_value = {'data': 'collection_find_one_result'}
# start() activates the patch and returns the mock standing in for MongoClient
mocked_client = unittest.mock.patch('pymongo.MongoClient').start()
# Calling the patched MongoClient now yields this nested dict, so
# client[databaseName][collectionName] resolves to mocked_collection.
mocked_client.return_value = {
    'databaseName': {'collectionname': mocked_collection}
}
Maybe try a specialized mocking library like mongomock?
In particular, the last example using @mongomock.patch looks like it could be relevant for your code.
I have a module, billing/billing-collect-project-license, which contains the LicenseStatistics class. This class makes calls to Redis, ORMRDS and CE, which are other modules. Following is the LicenseStatistics class, where get_allocate_count is an instance method that calls ce_obj.get_set_ce_obj, get_license_id and many others.
The get_license_id method calls get_customer_details.
class LicenseStatistics():
    """Collects license-usage figures for a project."""

    def __init__(self):
        self.logger = LOGGER

        # REDIS_NODE_LIST is a comma-separated list of HOST:PORT entries,
        # e.g. HOST1:PORT1,HOST2:PORT2,HOST3:PORT3
        nodes = os.environ.get('REDIS_NODE_LIST', '').split(',')
        self.redis_utils = RedisUtils(nodes)

        # Database handle used to read customer details; the settings are
        # passed positionally as host, user, password, name, port.
        db_settings = [
            os.environ.get(key)
            for key in ('DBHOST', 'DBUSER', 'DBPASSWORD', 'DBNAME', 'DBPORT')
        ]
        self.dbutils = ORMDBUtils(*db_settings)

        self.ce_obj = CE()
        self.bill = Billing()

    def get_license_id(self, project_id):
        """
        @Summary: Look up the license id stored in the project's customer record.
        @param project_id (string): CE project Id
        @return (string): license id taken from the customer record.
        @raise InvalidParamException: when the record's license id is empty.
        """
        record = self.get_customer_details(project_id)
        print("customer_details:", record)
        license_id = record["license_id"]
        if license_id:
            print("license_id:", license_id)
            return license_id
        msg = "No license for project {}".format(project_id)
        self.logger.error(msg)
        raise InvalidParamException(msg)

    def get_customer_details(self, project_id):
        """
        @Summary: Read the customer row for a project from the customer table.
        @param project_id (string): CE project Id
        @return: first row returned by dbutils.get_items for this project.
        @raise NoCustomerException: when no row is returned.
        """
        rows = self.dbutils.get_items(
            table_name=RDSCustomerTable.TABLE_NAME.value,
            columns_to_select=["account_id", "license_id"],
            filters={"project_id": project_id},
        )
        if rows:
            return rows[0]
        msg = "No customer found for project {}".format(project_id)
        self.logger.error(msg)
        raise NoCustomerException(msg)

    def is_shared_license(self, license_id):
        """Placeholder, meant to answer True/False; currently returns None."""
        return None

    def get_project_machines_count(self, project_id, license_id):
        """Placeholder count used for shared licenses; always returns 20."""
        return 20

    def get_license_usage(self, project_id, license_id):
        """Placeholder count used for non-shared licenses; always returns 10."""
        return 10

    def get_allocate_count(self, project_id):
        """
        @Summary: Get the number of licenses used by a project.
        @param project_id (string): CloudEndure Project Id.
        @return (int): machine count for a shared license, otherwise the
                       license usage count.
        @raise InvalidParamException: when get_set_ce_obj reports failure.
        """
        session_ok = self.ce_obj.get_set_ce_obj(
            project_id=project_id, redis_utils=self.redis_utils
        )
        print("license_id status--:", session_ok)
        if not session_ok:
            msg = "Unable to set CEproject {}".format(project_id)
            self.logger.critical(msg)
            raise InvalidParamException(msg, "project_id", project_id)
        print("project_id------------:", project_id)

        license_id = self.get_license_id(project_id)
        print("license_id:", license_id)

        # Shared licenses are counted per project machine; others by usage.
        if self.is_shared_license(license_id):
            return self.get_project_machines_count(project_id, license_id)
        return self.get_license_usage(project_id, license_id)
I am writing a unit test for the get_allocate_count method, and I mock Redis, ORMRDS, the custom exceptions and the logger.
This method calls ce_obj.get_set_ce_obj, which returns True/False. I am able to mock/patch the return value of this function successfully.
But when execution reaches the next call, i.e. get_license_id, the actual function is called and fails due to improper inputs. I am not able to patch/mock it.
Following is unit test code:
import responses
import unittest
from unittest.mock import patch
import os
import sys
cwd_path = os.getcwd()
sys.path.append(cwd_path)
sys.path.append(cwd_path+"/../sam-cfns/code")
sys.path.append(cwd_path+"/../sam-cfns/code/billing")
from unit_tests.common.mocks.env_mock import ENV_VAR
from unit_tests.common.mocks.logger import FakeLogger
from unit_tests.common.mocks.cache_mock import RedisUtilsMock
from unit_tests.common.mocks.ormdb_mock import ORMDBUtilsMockProject
from unit_tests.common.mocks.exceptions_mock import NoCustomerExceptionMock
from unit_tests.common.mocks.exceptions_mock import BillingExceptionMock
from unit_tests.common.mocks.exceptions_mock import InvalidParamExceptionMock
from unit_tests.common.mocks.api_responses import mock_response
from unit_tests.common.examples import ce_machines_data
from unit_tests.common.examples import ce_license_data
from unit_tests.common.examples import ce_data
class BillingTest(unittest.TestCase):
    """Base test case that builds one LicenseStatistics object with its collaborators patched."""

    # The patches wrap __init__ only, so they are active just while the
    # LicenseStatistics object below is being constructed.
    @patch("billing-collect-project-license.Logger", FakeLogger)
    @patch("os.environ", ENV_VAR)
    @patch("billing-collect-project-license.RedisUtils", RedisUtilsMock)
    @patch("billing-collect-project-license.ORMDBUtils", ORMDBUtilsMockProject)
    @patch("exceptions.NoCustomerException", NoCustomerExceptionMock)
    @patch("billing.billing_exceptions.BillingException", BillingExceptionMock)
    @patch("billing.billing_exceptions.InvalidParamException", InvalidParamExceptionMock)
    def __init__(self, *args, **kwargs):
        """Initialization"""
        super(BillingTest, self).__init__(*args, **kwargs)
        # The module name contains hyphens, so it cannot be loaded with a
        # plain import statement.
        billing_collect_project_license_module = (
            __import__("cbr-billing-collect-project-license")
        )
        self.licenses_stats_obj = (
            billing_collect_project_license_module.LicenseStatistics()
        )
class BillingCollectProjectLicense(BillingTest):
    """Unit tests for LicenseStatistics.get_allocate_count."""

    def __init__(self, *args, **kwargs):
        """Initialization"""
        super(BillingCollectProjectLicense, self).__init__(*args, **kwargs)

    def setUp(self):
        """Setup for all Test Cases."""
        pass

    #@patch("billing.cbr-billing-collect-project-license.LicenseStatistics."
    #       "get_project_machines_count")
    #@patch("billing.cbr-billing-collect-project-license.LicenseStatistics."
    #       "get_customer_details")
    #@patch("billing.cbr-billing-collect-project-license.LicenseStatistics.get_license_id")
    # Decorators apply bottom-up: the lowest patch supplies the first mock
    # argument (get_set_ce_obj_mock), the one above it the second.
    @patch("billing.cbr-billing-collect-project-license.LicenseStatistics.get_license_id")
    @patch("cbr.ce.CloudEndure.get_set_ce_obj")
    def test_get_allocate_count(self, get_set_ce_obj_mock, get_license_id_mock):
        project_id = ce_data.CE_PROJECT_ID
        license_id = ce_license_data.LICENSE_ID
        get_set_ce_obj_mock.return_value = True
        get_license_id_mock.return_value = license_id
        # LicenseStatistics_mock.return_value.get_license_id.return_value = license_id
        #get_license_id_mock.return_value = license_id
        # self.licenses_stats_obj.get_license_id = get_license_id_mock
        get_customer_details_mock = {"license_id": license_id}
        # is_shared_license_mock.return_value = True
        # get_project_machines_count_mock.return_value = 20
        resp = self.licenses_stats_obj.get_allocate_count(
            project_id
        )
        self.assertEqual(resp, 20)
if __name__ == '__main__':
unittest.main()
I am not able to patch get_license_id function from same class. This actually calling get_license_id function and fails.
I want to mock return value of get_license_id function.
Anyone help help me?
Thank you.
The issue is that I am initializing the object in __init__, so monkeypatching methods of the LicenseStatistics class later has no effect on the already-created instance (thanks @hoefling).
By monkey patching the instance itself, I am able to run the test cases successfully.
sample code:
def test_get_allocate_count_ok_4(self):
    """
    @Summary: Test case for successful response for shared license
    by other unittest method - Monkey Patching
    """
    # The stand-ins below are assigned on the instance, so they are called
    # without an implicit self argument.
    def get_customer_details_mp(_):
        """Monkey Patching function for get_customer_details"""
        data = [
            {
                "account_id": "abc",
                "project_id": "abc",
                "license_id": "abc",
                "status": "Active"
            }
        ]
        return data

    def get_set_ce_obj_mp(project_id, redis_utils):
        """Monkey Patching function for get_set_ce_obj"""
        # Parameter names must match the keyword arguments that
        # get_allocate_count passes (project_id=..., redis_utils=...).
        return True

    def get_license_id_mp(_):
        """Monkey Patching function for get_license_id"""
        return "abc"

    def is_shared_license_mp(_):
        """Monkey Patching function for is_shared_license"""
        return True

    def get_project_machines_count_mp(_, _license_id):
        """Monkey Patching function for get_project_machines_count"""
        return 5

    project_id = "abc"
    # Monkey Patching
    self.licenses_stats_obj.get_customer_details = get_customer_details_mp
    self.licenses_stats_obj.ce_obj.get_set_ce_obj = get_set_ce_obj_mp
    self.licenses_stats_obj.get_license_id = get_license_id_mp
    self.licenses_stats_obj.is_shared_license = is_shared_license_mp
    self.licenses_stats_obj.get_project_machines_count = (
        get_project_machines_count_mp
    )
    resp = self.licenses_stats_obj.get_allocate_count(project_id)
    self.assertEqual(resp, 5)