I am having trouble understanding how mocking works when the get responses involve using the with keyword. Here is an example I am following for my class `Album' and I have been successful when I am mocking a url as seen below:
def find_album_by_id(id):
url = f'https://jsonplaceholder.typicode.com/albums/{id}'
response = requests.get(url)
if response.status_code == 200:
return response.json()['title']
else:
return None
Here the test
class TestAlbum(unittest.TestCase):
#patch('album.requests')
def test_find_album_by_id_success(self, mock_requests):
# mock the response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'userId': 1,
'id': 1,
'title': 'hello',
}
# specify the return value of the get() method
mock_requests.get.return_value = mock_response
# call the find_album_by_id and test if the title is 'hello'
self.assertEqual(find_album_by_id(1), 'hello')
However, I am trying to understand how this would work with the with keyword involved in the code logic which I am using in my project. This is how I changed the method
def find_album_by_id(id):
url = f'https://jsonplaceholder.typicode.com/albums/{id}'
with requests.get(url) as response:
pipelines = response.json()
if response.status_code == 200:
return pipelines['title']
else:
return None
Here is my current test:
#patch('album.requests.get')
def test_find_album_by_id_success(self, mock_requests):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.pipelines.response.json = {
'userId': 1,
'id': 1,
'title': 'hello',
}
mock_requests.return_value.json.return_value = mock_response
self.assertEqual(find_album_by_id(1), 'hello')
Thanks
I have tried debugging the test and it just never receives the status code of 200 so I am not sure I am mocking response correctly at all? From my understanding the mock_response is supposed to have that status code of 200 but breakline indicates that response is just a MagicMock with nothing in it.
I'm planning to call multiple functions and execute SQL queries from random inputs from JSON.
Actual code:
def daily():
db_connection = DatabaseConnection()
aaa(db_connection)
bbb(db_connection)
return {'success': True}
def aaa(db_connection):
database_table = 'aaa'
symbols_list = [{'code': 'XXX', 'database_column': 'aaa_1'}, {'code': 'YYY', 'database_column': 'aaa_2'}]
load_data_from_api_to_database(db_connection, database_table, symbols_list)
def load_data_from_api_to_database(db_connection, database_table, symbols_list):
http_request = HttpRequest()
for _, symbol in enumerate(symbols_list):
code = symbol['code']
database_column = symbol['database_column']
response = http_request.get(f'https://api.example.com/value/{code}', headers={'accept': 'application/json', 'appkey': api_key})
json_data = json.loads(response.text)
if response.status_code != 200:
return
data_points = json_data['dataPoint']
for x, _ in enumerate(data_points):
value = data_points[x]['value']
date = data_points[x]['date']
db_connection.execute(f'INSERT INTO "{database_table}" ("{database_date_column}") VALUES (%(date_time)s) ON CONFLICT ("{database_date_column}") DO NOTHING', {'date_time': date})
db_connection.execute(f'UPDATE "{database_table}" SET "{database_column}" = %(value)s WHERE "{database_date_column}" = %(date_time)s', {'value': value, 'date_time': date})
db_connection.commit()
bbb() similar to aaa() just different json array value.
Test code:
class TestDailyHandler(unittest.TestCase):
#classmethod
def setup_class(cls):
cls.mock_get_patcher = patch('src.daily_handler.HttpRequest.get')
cls.mock_get = cls.mock_get_patcher.start()
cls.mock_database_connection_patcher = patch('src.daily_handler.DatabaseConnection')
cls.mock_database_connection = cls.mock_database_connection_patcher.start()
def load_data_from_api_to_database(self):
assert load_data_from_api_to_database({}, None) == {'success': True}
symbols_list = [{'code': 'XXX', 'database_column': 'aaa_1'}, {'code': 'YYY', 'database_column': 'aaa_2'}]
for x in range(len(symbols_list)):
code = [x]['code']
self.mock_get.assert_any_call(f'https://api.example.com/value/symbols_list{code}', headers={'accept': 'application/json', 'appkey': self.mock_get_aws_secret.return_value})
db_execute_many_args_list = self.mock_database_connection.return_value.execute_many.call_args_list
daily_table_insert_command_length = len([x for x in db_execute_many_args_list if re.search(r'INSERT INTO ', str(x), re.IGNORECASE)])
self.assertEqual(daily_table_insert_command_length, len(db_execute_many_args_list))
self.assertEqual(self.mock_database_connection.return_value.commit.call_count, daily_table_insert_command_length)
db_execute_many_args_list = self.mock_database_connection.return_value.execute_many.call_args_list
daily_table_update_command_length = len([x for x in db_execute_many_args_list if re.search(r'UPDATE ', str(x), re.IGNORECASE)])
self.assertEqual(daily_table_update_command_length, len(db_execute_many_args_list))
self.assertEqual(self.mock_database_connection.return_value.commit.call_count, daily_table_update_command_length)
By the way, I'm not sure how to call multiple functions aaa() and bbb(). I supposely to test starting from daily() instead load_data_from_api_to_database() function. Also JSON array input from each function. Currently it's static value.
Warning: the following code is provided as guidance only
taken from example I have a string whose content is a function name, how to refer to the corresponding function in Python?
additional reading Store functions in list and call them later
new answer
def aaa(db_connection, data):
# bind from data
symbols_list = data
# ...
def bbb(db_connection, data):
# bind from data
symbols_list = data
# ...
dispatcher = { "aaa" : aaa, "bbb" : bbb }
def daily():
db_connection = DatabaseConnection()
#loop the dict for functions
for fn in dispatcher:
if(callable(dispatcher[func_name]))
json_str = input('Enter your JSON data:')
try
data = json.load(json_str);
dispatcher[func_name](db_connection, data)
except JSONDecodeError
print('Error loading json')
return None
return {'success': True}
I am new to pytest and wanted to add the below 3 methods for unit test coverage without actually using a real mongo db instance but rather mock it.
Could try using a real db instance but it isn't recommended.
Request for an example on how to mock mongodb client and get a document
import os
import logging
import urllib.parse
from dotenv import load_dotenv
from pymongo import MongoClient
from logger import *
load_dotenv()
def getMongoConnection():
userName = urllib.parse.quote_plus(os.getenv("USER_NAME"))
password = urllib.parse.quote_plus(os.getenv("PASSWORD"))
hostName1_port = os.getenv("HOST_NAME1")
hostName2_port = os.getenv("HOST_NAME2")
hostName3_port = os.getenv("HOST_NAME3")
authSourceDatabase = os.getenv("AUTH_SOURCE_DATABASE")
replicaSet = os.getenv("REPLICA_SET")
connectTimeoutMS = "1000"
socketTimeoutMS = "30000"
maxPoolSize = "100"
try:
client = MongoClient('mongodb://'+userName+':'+password+'#'+hostName1_port+','+hostName2_port+','+hostName3_port+'/'+authSourceDatabase+'?ssl=true&replicaSet='+replicaSet +
'&authSource='+authSourceDatabase+'&retryWrites=true&w=majority&connectTimeoutMS='+connectTimeoutMS+'&socketTimeoutMS='+socketTimeoutMS+'&maxPoolSize='+maxPoolSize)
return client
except Exception as e:
logging.error("Error while connecting to mongoDB.")
return False
def connectToDBCollection(client, databaseName, collectionName):
db = client[databaseName]
collection = db[collectionName]
return collection
def getDoc(bucketName, databaseName, collectionName):
try:
client = getMongoConnection()
if client != False:
collection = connectToDBCollection(
client, databaseName, collectionName)
return collection.find_one({'bucket': bucketName})
except Exception as e:
logging.error("An exception occurred while fetching doc, error is ", e)
Edit : (Tried using below code and was able to cover most of the cases but seeing an error)
def test_mongo():
db_conn = mongomock.MongoClient()
assert isinstance(getMongoConnection(), MongoClient)
def test_connect_mongo():
return connectToDBCollection(mongomock.MongoClient(), "sampleDB", "sampleCollection")
//trying to cover exception block for getMongoConnection()
def test_exception():
with pytest.raises(Exception) as excinfo:
getMongoConnection()
assert str(excinfo.value) == False
def test_getDoc():
collection = mongomock.MongoClient().db.collection
stored_obj = collection.find_one({'_id': 1})
assert stored_obj == getDoc("bucket", "db", "collection")
def test_createDoc():
collection = mongomock.MongoClient().db.collection
stored_obj = collection.insert_one({'_id': 1})
assert stored_obj == createDoc("bucket", "db", "collection")
def test_updateDoc():
collection = mongomock.MongoClient().db.collection
stored_obj = collection.replace_one({'_id': 1}, {'_id': 2})
assert stored_obj == updateDoc(
{'_id': 1}, {'$set': {'_id': 2}}, "db", "collection")
Errors :
test_exception - Failed: DID NOT RAISE <class 'Exception'>
test_createDoc - TypeError: not all arguments converted during string formatting
AssertionError: assert <pymongo.results.UpdateResult object at 0x7fc0e835a400> == <pymongo.results.UpdateResult object at 0x7fc0e8211900>
Looks like MongoClient is a nested dict with databaseName and collectionName or implemented with a key accessor.
You could mock the client first with
import unittest
mocked_collection = unittest.mock.MagicMock()
# mock the find_one method
mocked_collection.find_one.return_value = {'data': 'collection_find_one_result'}
mocked_client = unittest.mock.patch('pymongo.MongoClient').start()
mocked_client.return_value = {
'databaseName': {'collectionname': mocked_collection}
}
Maybe try a specialized mocking library like MongoMock?
In particular the last example using #mongomock.patch looks like it can be relevant for your code.
Howdie do,
So I have this API that does the following:
1) It receives a XML request. It first parses that request to ensure the request is in the correct format.
2) If the XML is in the correct format, it checks to ensure that the correct user is authenicating.
3) If authenication is successful, it then retreives a URI from a database that the API will send a get request to.
4) If the response is successful, meaning it's a XML reply, it will use XSLT to transform the request into a format
5) It then adds the request to the database and returns the transformed XML to the user that queried the API
I have error handling involved at each step, but the issue is, I've had to nest 5 if else statements to accomplish this.
I know there has to be a better way to rewrite this error handling logic without so many nested if statements, but I'm not sure how. The subsequent steps rely on the previous to ensure that if any error occurs, it's returned properly to the user.
Below is my main Flask-API that I've created. The second file is a module that I've created which does a lot of the error processing. Those functions return a state(True/False) and the response to the main Flask-API.
Can someone give me some ideas on how to rewrite this API without the nesting? The API works 100% and does what it should for catching errors, but I just know there's a better way
Main API:
#dbConnect.app.route('/services/tracking/getShipmentStatus', methods=['POST'])
def parsexml2():
parseStatus, returnValues = func.parseNWRequest(request.data, 'SS')
if parseStatus:
authStatus, authResponse = func.checkAuthorization(returnValues['bu'], request.headers['Authorization'], 'SS')
if authStatus:
getURIStatus, uriResponse = func.getURI('SS')
if getURIStatus:
search = {'bu': returnValues['bu'], 'starttime': returnValues['start'], 'endtime': returnValues['end'],
'requestid': returnValues['requestid'], 'pagesize': returnValues['page']}
responseStatus, depascoResponse = func.sendDepascoRequest(uriResponse, search, 'SS')
if responseStatus:
nakedResponse = func.transformXML(depascoResponse.content, 'transformTracking.xsl')
# Write Request to db
file_name = 'WS_SS_' + returnValues['bu']
request_file_size = request.headers['Content-Length']
if func.addToDb(file_name, 'text/xml', 'SS.Request', returnValues['bu'], 'Y', request.data,
request_file_size) or func.addToDb(file_name, 'text/xml', 'SS.Result',
returnValues['bu'], 'Y', nakedResponse):
pass
return Response(nakedResponse, mimetype='text/xml')
else:
return Response(depascoResponse, mimetype='text/xml')
else:
return Response(uriResponse, mimetype='text/xml')
else:
return Response(authResponse, mimetype='text/xml')
else:
return Response(returnValues, mimetype='text/xml')
Imported module functions:
def transformXML(response, xsl):
xml = ET.tostring(ET.fromstring(response))
xslt = ET.XSLT(ET.parse(xsl))
transformedXML = xslt(ET.XML(xml))
return ET.tostring(transformedXML, pretty_print=True)
def addToDb(filename, mime, docType, customer_code, activeState, document_blob, filesize=None):
try:
response = dbConnect.Documents(file_name=filename, mime_type=mime, file_size=filesize, doc_type=docType,
customer_code=customer_code, is_active=activeState, document_blob=document_blob)
dbConnect.db.session.add(response)
dbConnect.db.session.commit()
dbConnect.db.session.close()
except exc.SQLAlchemyError:
return False
else:
return True
def generateXMLErrorResponse(errorMessage, api):
E = ElementMaker()
if api == 'SS':
GETSHIPMENTSTATUSRESPONSE = E.getShipmentStatusResponse
GETSHIPMENTSTATUSRESULT = E.getShipmentStatusResult
OUTCOME = E.outcome
RESULT = E.result
ERROR = E.error
xml_error = GETSHIPMENTSTATUSRESPONSE(
GETSHIPMENTSTATUSRESULT(
OUTCOME(
RESULT('Failure'),
ERROR(errorMessage)
)
)
)
return ET.tostring(xml_error, pretty_print=True)
elif api == 'IS':
GETINVENTORYSTATUSRESPONSE = E.getInventoryStatusResponse
GETINVENTORYSTATUSRESULT = E.getInventoryStatusResult
OUTCOME = E.outcome
RESULT = E.result
ERROR = E.error
xml_error = GETINVENTORYSTATUSRESPONSE(
GETINVENTORYSTATUSRESULT(
OUTCOME(
RESULT('Failure'),
ERROR(errorMessage)
)
)
)
return ET.tostring(xml_error, pretty_print=True)
def checkAuthorization(bu, headers, status):
error = 'Invalid clientCode for account type'
auth_search = re.search('username="(.*?)"', headers)
auth_user = auth_search.group(1)
if auth_user.upper() != "ADMIN":
if (str(bu).upper() != auth_user.upper()) and status == 'SS':
return False, Response(generateXMLErrorResponse(error, status), mimetype='text/xml')
elif (str(bu).upper() != auth_user.upper()) and status == 'IS':
return False, Response(generateXMLErrorResponse(error, status), mimetype='text/xml')
else:
return True, 'User authenticated'
else:
return True, 'User authenticated'
def getURI(api):
try:
if api == 'SS':
tracking = dbConnect.db.session.query(dbConnect.AppParam).\
filter(dbConnect.AppParam.name == 'TRACKING_WEB_SERVICE_URI').first()
return True, tracking.value
elif api == 'IS':
inventory = dbConnect.db.session.query(dbConnect.AppParam).\
filter(dbConnect.AppParam.name == 'INVENTORY_WEB_SERVICE_URI').first()
return True, inventory.value
except exc.OperationalError:
sendEmail()
return False, generateXMLErrorResponse('Service Unavailable', api)
def sendEmail():
msg = MIMEText('Unable to connect to DB')
msg['Subject'] = "Database server down!"
msg['From'] = ''
msg['To'] = ''
s = smtplib.SMTP('localhost')
s.sendmail(msg['From'], msg['To'], msg.as_string())
s.quit()
return True
def parseNWRequest(nwRequest, api):
returnValues = {}
if api == 'SS':
try:
xml = xmltodict.parse(nwRequest)
returnValues['start'] = xml['getShipmentStatus']['getShipmentStatusRequest']['startTime']
returnValues['end'] = xml['getShipmentStatus']['getShipmentStatusRequest']['endTime']
if validateDate(returnValues['start'], returnValues['end']):
returnValues['bu'] = xml['getShipmentStatus']['getShipmentStatusRequest']['clientCode']
returnValues['page'] = xml['getShipmentStatus']['getShipmentStatusRequest']['pageSize']
returnValues['requestid'] = xml['getShipmentStatus']['getShipmentStatusRequest']['requestId']
return True, returnValues
else:
return False, generateXMLErrorResponse('Invalid startDate/endDate', api)
except xmltodict.expat.ExpatError:
return False, generateXMLErrorResponse('Invalid Formed XML', api)
elif api == 'IS':
try:
xml = xmltodict.parse(request.data)
returnValues['bu'] = xml['getInventoryStatus']['getInventoryStatusRequest']['clientCode']
returnValues['facility'] = xml['getInventoryStatus']['getInventoryStatusRequest']['facility']
return True, returnValues
except xmltodict.expat.ExpatError:
return False, generateXMLErrorResponse('Invalid Formed XML', api)
def validateDate(startDate, endDate):
try:
datetime.strptime(startDate, '%Y-%m-%dT%H:%M:%S')
datetime.strptime(endDate, '%Y-%m-%dT%H:%M:%S')
except ValueError:
return False
else:
return True
def sendDepascoRequest(uri, search, api):
auth=('', '')
depascoResponse = requests.get(uri, auth=auth, params=search)
try:
depascoResponse.raise_for_status()
except requests.exceptions.HTTPError:
return False, generateXMLErrorResponse(depascoResponse.content, api)
else:
return True, requests.get(uri, auth=auth, params=search)
******* UPDATE **********
Thanks to the accepted answer, I removed all layers of the nested if statements. I have my functions just raise a custom exception which is handled in the main program.
__author__ = 'jw1050'
from functions import parseNWRequest, checkAuthorization, getURI, sendDepascoRequest, transformXML, addToDb, APIError
from flask import request
from flask import Response
import dbConnect
import lxml.etree as ET
from sqlalchemy import text
#dbConnect.app.route('/services/inventory/getInventoryStatus', methods=['POST'])
def parsexml():
try:
returnValues = parseNWRequest(request.data, 'IS')
checkAuthorization(returnValues['bu'], request.headers['Authorization'], 'IS')
uri = getURI('IS')
search = {'bu': returnValues['bu'], 'facility': returnValues['facility']}
depascoResponse = sendDepascoRequest(uri, search, 'IS')
s = text("Select sku, allocated from fgw_allocated_sku_count where client_code = :c and fulfillment_location = :t")
result = dbConnect.db.engine.execute(s, c=returnValues['bu'], t=returnValues['facility']).fetchall()
root = ET.fromstring(depascoResponse.content)
for row in result:
for element in root.iter('Item'):
xmlSKU = element.find('SKU').text
if xmlSKU == row[0]:
newQonOrder = int(element.find('QuantityOnOrder').text) + row[1]
element.find('QuantityOnOrder').text = str(newQonOrder)
newQAvailable = int(element.find('QuantityAvailable').text) - newQonOrder
element.find('QuantityAvailable').text = str(newQAvailable)
nakedResponse = transformXML(ET.tostring(root), 'transformInventory.xsl')
# Write Request to db
file_name = 'WS_IS_' + returnValues['bu']
request_file_size = request.headers['Content-Length']
addToDb('SS', file_name, 'text/xml', 'IS.Req', returnValues['bu'], 'Y', request.data, request_file_size)
addToDb('SS', file_name, 'text/xml', 'IS.Result', returnValues['bu'], 'Y', nakedResponse)
return Response(nakedResponse, mimetype='text/xml')
except APIError as error:
return Response(error.errorResponse, mimetype='text/xml')
#dbConnect.app.route('/services/tracking/getShipmentStatus', methods=['POST'])
def parsexml2():
try:
returnValues = parseNWRequest(request.data, 'SS')
checkAuthorization(returnValues['bu'], request.headers['Authorization'], 'SS')
uri = getURI('SS')
search = {'bu': returnValues['bu'], 'starttime': returnValues['start'], 'endtime': returnValues['end'],
'requestid': returnValues['requestid'], 'pagesize': returnValues['page']}
depascoResponse = sendDepascoRequest(uri, search, 'SS')
nakedResponse = transformXML(depascoResponse.content, 'transformTracking.xsl')
# Write Request to db
file_name = 'WS_SS_' + returnValues['bu']
request_file_size = request.headers['Content-Length']
addToDb('SS', file_name, 'text/xml', 'SS.Request', returnValues['bu'], 'Y', request.data, request_file_size)
addToDb('SS', file_name, 'text/xml', 'SS.Result', returnValues['bu'], 'Y', nakedResponse)
return Response(nakedResponse, mimetype='text/xml')
except APIError as error:
return Response(error.errorResponse, mimetype='text/xml')
if __name__ == '__main__':
dbConnect.app.run(host='localhost', port=int("5010"), debug=True)
My custom exception class:
class APIError(Exception):
def __init__(self, errorResponse):
self.errorResponse = errorResponse
An example of a function that raises my custom exception:
def validateDate(startDate, endDate):
try:
datetime.strptime(startDate, '%Y-%m-%dT%H:%M:%S')
datetime.strptime(endDate, '%Y-%m-%dT%H:%M:%S')
except ValueError:
raise APIError(generateXMLErrorResponse('Invalid format for startDate/endDate', 'SS'))
else:
return True
Another way is to raise exceptions and handle them.
import myexceptions
try:
# a complicated hunk of stuff that raises exceptions when
# things don't work out. The exceptions can be raised down
# inside functions, if breaking this lump into functions
# makes it easier to follow.
except MyExceptionA:
return Response( ...) # appropriate response for error A
except MyExceptionB
return Response( ...) # appropriate response for error B
except ...
I like to leave my methods ASAP, so in your case, it would look like:
if !X:
return...
if !Y:
return...
if !Z:
return...
though generally, I prefer
if X:
return...
if Y:
return...
Here is my class:
class WorkflowsCloudant(cloudant.Account):
def __init__(self, account_id):
super(WorkflowsCloudant, self).__init__(settings.COUCH_DB_ACCOUNT_NAME,
auth=(settings.COUCH_PUBLIC_KEY, settings.COUCH_PRIVATE_KEY))
self.db = self.database(settings.COUCH_DB_NAME)
self.account_id = account_id
def get_by_id(self, key, design='by_workflow_id', view='by_workflow_id', limit=None):
params = dict(key=key, include_docs=True, limit=limit)
docs = self.db.design(design).view(view, params=params)
if limit is 1:
doc = [doc['doc'] for doc in docs]
if doc:
workflow = doc[0]
if workflow.get("account_id") != self.account_id:
raise InvalidAccount("Invalid Account")
return workflow
else:
raise NotFound("Autoresponder Cannot Be Found")
return docs
Here is my test:
def test_get_by_id_single_invalid_account(self):
self.klass.account_id = 200
self.klass.db = mock.MagicMock()
self.klass.db.design.return_value.view.return_value = [{
'doc': test_workflow()
}]
# wc.get_by_id = mock.MagicMock(side_effect=InvalidAccount("Invalid Account"))
with self.assertRaises(InvalidAccount("Invalid Account")) as context:
self.klass.get_by_id('workflow_id', limit=1)
self.assertEqual('Invalid Account', str(context.exception))
I'm trying to get the above test to simple raise the exception of InvalidAccount but I'm unsure how to mock out the self.db.design.view response. That's what's causing my test to fail because it's trying to make a real call out
I think this is what you want.
def test_get_by_id_single_invalid_account(self):
self.klass.account_id = 200
self.klass.db = mock.MagicMock()
self.klass.db.design = mock.MagicMock()
view_mock = mock.MagicMock()
view_mock.return_value =[{
'doc': test_workflow()
}]
self.klass.db.design.return_value.view = view_mock
# wc.get_by_id = mock.MagicMock(side_effect=InvalidAccount("Invalid Account"))
with self.assertRaises(InvalidAccount("Invalid Account")) as context:
self.klass.get_by_id('workflow_id', limit=1)
self.assertEqual('Invalid Account', str(context.exception))