I am writing a unit test for a Python app that connects to MySQL using MySQLdb.
There is a function that connects to the MySQL db and returns the connection object.
def connect_to_database():
    conn = MySQLdb.connect(host=db_pb2['mysql_host'],
                           user=db_pb2['mysql_user'],
                           passwd=db_pb2['mysql_password'],
                           db=db_pb2['mysql_db'])
    return conn
There is one more function that executes the query using the above connection
def execute_query():
    cur = connect_to_database().cursor()
    a = cur.execute("query")
    if a > 0:
        result = cur.fetchall()
        return result
I have written a test with @patch to mock the return values of the cur.fetchall() and cur.execute() methods:
@patch('application.module1.data_adapters.connect_to_database')
def test_daily_test_failures(self, db_connection):
    db_connection.cursor().execute.return_value = 1
    db_connection.cursor().fetchall.return_value = ((1,5,6),)
    self.assertEqual(execute_query(), ((1,5,6),))
I get the following error:
if a > 0:
TypeError: '<=' not supported between instances of 'MagicMock' and 'int'
It seems like the return values set in the patched test are not being used as expected.
Your patching requires a bit more effort.
@patch('application.module1.data_adapters.connect_to_database')
def test_daily_test_failures(self, connect_to_database):
    db_connection = MagicMock()
    connect_to_database.return_value = db_connection  # 1.
    mock_cursor = MagicMock()
    db_connection.cursor.return_value = mock_cursor  # 2.
    mock_cursor.execute.return_value = 1
    mock_cursor.fetchall.return_value = ((1,5,6),)
    self.assertEqual(execute_query(), ((1,5,6),))
1. You're patching the symbol connect_to_database - that alone does not mock what the call returns. You need to specify "when that symbol is called, return this" by setting connect_to_database.return_value.
2. When mocking chained calls, don't configure them by calling the mock (db_connection.cursor()....) - make db_connection.cursor return a mock via .return_value, and then set the return values for execute and fetchall on that mock object.
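As an optional follow-up (a small sketch that assumes the exact mock names from the test above), you can also assert how the mocks were used, which helps confirm the patch target is right:

# Optional checks, assuming the mock setup from the test above.
connect_to_database.assert_called_once_with()
db_connection.cursor.assert_called_once_with()
mock_cursor.execute.assert_called_once_with("query")
mock_cursor.fetchall.assert_called_once_with()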
I created a simple example to illustrate my issue. First is the setup, say mydummy.py:
class TstObj:
    def __init__(self, name):
        self.name = name

    def search(self):
        return self.name

MyData = {}
MyData["object1"] = TstObj("object1")
MyData["object2"] = TstObj("object2")
MyData["object3"] = TstObj("object3")

def getObject1Data():
    return MyData["object1"].search()

def getObject2Data():
    return MyData["object2"].search()

def getObject3Data():
    return MyData["object3"].search()

def getExample():
    res = f"{getObject1Data()}{getObject2Data()}{getObject3Data()}"
    return res
Here is the test that failed.
def test_get_dummy1():
    mydummy.MyData = MagicMock()
    mydummy.MyData["object1"].search.side_effect = ["obj1"]
    mydummy.MyData["object2"].search.side_effect = ["obj2"]
    mydummy.MyData["object3"].search.side_effect = ["obj3"]
    assert mydummy.getExample() == "obj1obj2obj3"
The above failed with a runtime error:
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/unittest/mock.py:1078: StopIteration
Here is the test that passed:
def test_get_dummy2():
    mydummy.MyData = MagicMock()
    mydummy.MyData["object1"].search.side_effect = ["obj1", "obj2", "obj3"]
    assert mydummy.getExample() == "obj1obj2obj3"
Am I missing something? I would have expected test_get_dummy1() to work and test_get_dummy2() to fail and not vice versa. Where and how can I find/learn more information about mocking to explain what is going on...
MyData["object1"] is converted to this function call: MyData.__getitem__("object1"). When you call your getExample method, the __getitem__ method is called 3 times with 3 parameters ("object1", "object2", "object3").
To mock the behavior you could have written your test like so:
def test_get_dummy_alternative():
    mydummy.MyData = MagicMock()
    mydummy.MyData.__getitem__.return_value.search.side_effect = ["obj1", "obj2", "obj3"]
    assert mydummy.getExample() == "obj1obj2obj3"
Note the small change from your version: mydummy.MyData["object1"]... became mydummy.MyData.__getitem__.return_value.... This is the regular MagicMock syntax - we want to change the return value of the __getitem__ method.
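If you would rather map each key to its own value instead of relying on call order, you can give __getitem__ a side_effect function that dispatches on the key. The sketch below is only illustrative; the key_to_mock dict is not part of the original code:

def test_get_dummy_per_key():
    mydummy.MyData = MagicMock()
    # One mock per key, each with its own search() return value.
    key_to_mock = {
        "object1": MagicMock(**{"search.return_value": "obj1"}),
        "object2": MagicMock(**{"search.return_value": "obj2"}),
        "object3": MagicMock(**{"search.return_value": "obj3"}),
    }
    # __getitem__ receives the key, so look the matching mock up by key.
    mydummy.MyData.__getitem__.side_effect = key_to_mock.__getitem__
    assert mydummy.getExample() == "obj1obj2obj3"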
BONUS:
I often struggle with mock syntax and understanding what's happening under the hood. This is why I wrote a helper library: the pytest-mock-generator. It can show you the actual calls made to the mock object.
To use it in your case you could have added this "exploration test":
def test_get_dummy_explore(mg):
    mydummy.MyData = MagicMock()
    mydummy.getExample()
    mg.generate_asserts(mydummy.MyData, name='mydummy.MyData')
When you execute this test, the following output is printed to the console, which contains all the asserts to the actual calls to the mock:
from mock import call
mydummy.MyData.__getitem__.assert_has_calls(calls=[call('object1'),call('object2'),call('object3'),])
mydummy.MyData.__getitem__.return_value.search.assert_has_calls(calls=[call(),call(),call(),])
mydummy.MyData.__getitem__.return_value.search.return_value.__str__.assert_has_calls(calls=[call(),call(),call(),])
You can easily derive from here what has to be mocked.
I'm trying to write a unit test for the following function:
import sqlite3

def match_user_id_with_image_uid(image_uid):
    # Given the note id, confirm if the current user is the owner of the note which is
    # being operated.
    _conn = sqlite3.connect(image_db_file_location)
    _c = _conn.cursor()
    command = "SELECT owner FROM images WHERE uid = '" + image_uid + "';"
    _c.execute(command)
    result = _c.fetchone()[0]
    _conn.commit()
    _conn.close()
    return result
Pytest's mocker won't allow me to mock the sqlite3 functions; it fails with the following error:
can't set attributes of built-in/extension type 'sqlite3.Cursor'
I'm using mocker like so:
def test_match_user_to_image(mocker):
    m = mocker.patch('sqlite3.connect')
    m.return_value = mocker.Mock()
Other than that, I wonder what the best practice is to mock out those functions. Do I need to patch each and every one (execute, cursor, commit, close)?
Thanks!
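One possible approach, sketched below: keep patching sqlite3.connect (rather than attributes of sqlite3.Cursor, which is a C extension type and cannot be monkey-patched) and configure the connection/cursor chain on the mock it returns. Because MagicMock creates chained attributes automatically, execute, commit and close do not need separate patches. The module name image_store and the value 'some_owner' are placeholders, not taken from the original code:

import image_store  # placeholder for the module that defines match_user_id_with_image_uid

def test_match_user_to_image(mocker):
    # Replace sqlite3.connect; its return value acts as the connection mock.
    mock_connect = mocker.patch('sqlite3.connect')
    mock_cursor = mock_connect.return_value.cursor.return_value
    mock_cursor.fetchone.return_value = ('some_owner',)

    result = image_store.match_user_id_with_image_uid('some-uid')

    assert result == 'some_owner'
    mock_cursor.execute.assert_called_once()
    mock_connect.return_value.close.assert_called_once()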
I have a class which is intended to create an IBM Cloud Object Storage object. There are two functions I can use for initialization: resource() and client(). In the __init__ function there is an object_type parameter which is used to decide which function to call.
class ObjectStorage:
    def __init__(self, object_type: str, endpoint: str, api_key: str, instance_crn: str, auth_endpoint: str):
        valid_object_types = ("resource", "client")
        if object_type not in valid_object_types:
            raise ValueError("Object initialization error: Status must be one of %r." % valid_object_types)
        method_type = getattr(ibm_boto3, object_type)()
        self._conn = method_type(
            "s3",
            ibm_api_key_id=api_key,
            ibm_service_instance_id=instance_crn,
            ibm_auth_endpoint=auth_endpoint,
            config=Config(signature_version="oauth"),
            endpoint_url=endpoint,
        )

    @property
    def connect(self):
        return self._conn
If I run this, I receive the following error:
TypeError: client() missing 1 required positional argument: 'service_name'
If I use this in a simple function and call it by using ibm_boto3.client() or ibm_boto3.resource(), it works like a charm.
def get_cos_client_connection():
    COS_ENDPOINT = "xxxxx"
    COS_API_KEY_ID = "yyyyy"
    COS_INSTANCE_CRN = "zzzzz"
    COS_AUTH_ENDPOINT = "----"
    cos = ibm_boto3.client("s3",
                           ibm_api_key_id=COS_API_KEY_ID,
                           ibm_service_instance_id=COS_INSTANCE_CRN,
                           ibm_auth_endpoint=COS_AUTH_ENDPOINT,
                           config=Config(signature_version="oauth"),
                           endpoint_url=COS_ENDPOINT
                           )
    return cos

cos = get_cos_client_connection()
It looks like it calls the client function on this line, but I am not sure why:
method_type = getattr(ibm_boto3, object_type)()
I tried using:
method_type = getattr(ibm_boto3, lambda: object_type)()
but it was a silly move.
The client function looks like this btw:
def client(*args, **kwargs):
    """
    Create a low-level service client by name using the default session.
    See :py:meth:`ibm_boto3.session.Session.client`.
    """
    return _get_default_session().client(*args, **kwargs)
which refers to:
def client(self, service_name, region_name=None, api_version=None,
           use_ssl=True, verify=None, endpoint_url=None,
           aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None,
           ibm_api_key_id=None, ibm_service_instance_id=None, ibm_auth_endpoint=None,
           auth_function=None, token_manager=None,
           config=None):
    return self._session.create_client(
        service_name, region_name=region_name, api_version=api_version,
        use_ssl=use_ssl, verify=verify, endpoint_url=endpoint_url,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
        ibm_api_key_id=ibm_api_key_id, ibm_service_instance_id=ibm_service_instance_id,
        ibm_auth_endpoint=ibm_auth_endpoint, auth_function=auth_function,
        token_manager=token_manager, config=config)
The same goes for resource().
If you look at the stack trace, it will probably point to this line:
method_type = getattr(ibm_boto3, object_type)()
And not the one after, where you actually call it. The reason is simple: that trailing pair of parentheses () means you're calling the function you just retrieved via getattr, with no arguments, which is why service_name is reported as missing.
So simply do this:
method_type = getattr(ibm_boto3, object_type)
Which means that method_type is actually the method from the ibm_boto3 object you're interested in.
You can confirm that by debugging with import pdb; pdb.set_trace() and inspecting it, or by just adding a print statement:
print(method_type)
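To make the fix concrete, here is a minimal sketch of the relevant part of __init__ with the trailing parentheses removed from the getattr line (everything else is unchanged from the original class):

# Look the function up without calling it...
method_type = getattr(ibm_boto3, object_type)
# ...then call it here, passing the required service_name ("s3") and credentials.
self._conn = method_type(
    "s3",
    ibm_api_key_id=api_key,
    ibm_service_instance_id=instance_crn,
    ibm_auth_endpoint=auth_endpoint,
    config=Config(signature_version="oauth"),
    endpoint_url=endpoint,
)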
I'm using Python 3.5.2 for my project. I installed MySQLdb via pip for the MySQL connection.
Code:
import MySQLdb

class DB:
    def __init__(self):
        self.con = MySQLdb.connect(host="127.0.0.1", user="*", passwd="*", db="*")
        self.cur = self.con.cursor()

    def query(self, q):
        r = self.cur.execute(q)
        return r

def test():
    db = DB()
    result = db.query('SELECT id, name FROM test')
    return print(result)

def test1():
    db = DB()
    result = db.query('SELECT id, name FROM test').fetchall()
    for res in result:
        id, name = res
        print(id, name)

test()
test1()
#test output >>> '3'
#test1 output >>> AttributeError: 'int' object has no attribute 'fetchall'
Test table:
id | name
1 | 'test'
2 | 'test2'
3 | 'test3'
Please read this link: http://mysql-python.sourceforge.net/MySQLdb.html
At this point your query has been executed and you need to get the results. You have two options:

r = db.store_result()

...or...

r = db.use_result()

Both methods return a result object. What's the difference? store_result() returns the entire result set to the client immediately. If your result set is really large, this could be a problem. One way around this is to add a LIMIT clause to your query, to limit the number of rows returned. The other is to use use_result(), which keeps the result set in the server and sends it row-by-row when you fetch. This does, however, tie up server resources, and it ties up the connection: you cannot do any more queries until you have fetched all the rows. Generally I recommend using store_result() unless your result set is really huge and you can't use LIMIT for some reason.
def test1():
    db = DB()
    db.query('SELECT id, name FROM test')
    result = db.cur.fetchall()
    for res in result:
        id, name = res
        print(id, name)
cursor.execute() returns the number of rows modified or retrieved, just like in PHP, so you cannot chain fetchall() onto its result. Have you tried returning the fetchall() result instead, like so?
def query(self, q):
    self.cur.execute(q)
    return self.cur.fetchall()
See here for more documentation: https://ianhowson.com/blog/a-quick-guide-to-using-mysql-in-python/
I have a query like this:
query = Notification.query(db.func.count(Notification.id))
query = query.filter(Notification.read == False)
query = query.filter(Notification.id == recv_id)
return query.all()
and I get an error like this:
query = Notification.query(db.func.count(Notification.id))
TypeError: 'BaseQuery' object is not callable
Please help, thanks.
Your first line raises the error. query is an instance of BaseQuery and it's not callable.
What you are trying to do is similar to:
class A(object):
    pass

a_obj = A()
print(a_obj())  # TypeError: 'A' object is not callable
You can't call an instance.
You should call some method on the instance.
Not sure why you require the first line in your code.
You can do something like:
Notification.query.filter(Notification.read == False)
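If the original intent was to get the number of unread notifications rather than the rows themselves, one option (a sketch reusing the exact filters from the question) is to ask the query for a count instead of wrapping it in db.func.count:

# Same filters as in the question, but ending with Query.count()
# rather than calling the BaseQuery object.
query = Notification.query
query = query.filter(Notification.read == False)
query = query.filter(Notification.id == recv_id)
return query.count()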