I am writing a unit test case for a function which has multiple SQL queries in it. I am using the psycopg2 module and trying to mock the cursor.
app.py
import psycopg2
def my_function():
    """Run the customer query and the product query; return the product rows.

    Each fetched row becomes a dict keyed by the column names read from
    ``cursor.description``.  The customer rows are built but only the
    product rows are returned.
    """
    # all connection related code goes here ...
    customer_sql = "SELECT name,phone FROM customer WHERE name='shanky'"
    cursor.execute(customer_sql)
    customer_fields = [col[0] for col in cursor.description]
    customer_response = [
        dict(zip(customer_fields, record)) for record in cursor.fetchall()
    ]

    product_sql = "SELECT name,id FROM product WHERE name='soap'"
    cursor.execute(product_sql)
    product_fields = [col[0] for col in cursor.description]
    product_response = [
        dict(zip(product_fields, record)) for record in cursor.fetchall()
    ]
    return product_response
test.py
from pytest_mock import mocker
import psycopg2
def test_my_function(mocker):
    """Question's attempt: configure one mocked cursor per query, then call the function."""
    from my_module import app

    mocker.patch('psycopg2.connect')
    cursor_factory = psycopg2.connect.return_value.cursor

    # first query
    mocked_cursor_one = cursor_factory.return_value
    mocked_cursor_one.description = [['name'],['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
    # NOTE(review): bare comparison -- its result is discarded, nothing is asserted.
    mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"

    # second query
    # NOTE(review): ``return_value`` hands back the same mock as above, so the
    # assignments below overwrite the first query's configuration.
    mocked_cursor_two = cursor_factory.return_value
    mocked_cursor_two.description = [['name'],['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
    mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"

    ret = app.my_function()
    assert ret == {'name' : 'nirma', 'id' : 12313}
But the mocker always takes the last mock object (the second query). I have already tried multiple hacks, but none of them worked. How can I mock multiple queries in one function and successfully pass the unit test case? Is it possible to write a unit test case in this fashion, or do I need to split the queries into different functions?
After digging through the documentation, I was able to achieve this with the help of the unittest mock decorator and side_effect, which was suggested by @Pavel Vergeev. I was able to write a unit test case that is good enough to test the functionality.
from unittest import mock
from my_module import app
@mock.patch('psycopg2.connect')
def test_my_function(mocked_db):
    """Check both queries of ``app.my_function`` against one mocked cursor.

    ``fetchall`` and ``description`` are given ``side_effect`` lists so the
    first query sees the first entry and the second query the second one.
    """
    mocked_cursor = mocked_db.return_value.cursor.return_value

    # ``description`` is read as an attribute, so it needs a PropertyMock on
    # the mock's type to return a different value on each access.
    description_mock = mock.PropertyMock()
    type(mocked_cursor).description = description_mock
    description_mock.side_effect = [
        [['name'], ['phone']],
        [['name'], ['id']],
    ]
    mocked_cursor.fetchall.side_effect = [
        [('shanky', '347539593')],
        [('nirma', 12313)],
    ]

    ret = app.my_function()

    # one fetchall and one execute per query
    assert mocked_cursor.fetchall.call_count == 2
    assert mocked_cursor.execute.call_count == 2

    # The expected text must match the query strings in app.py exactly.
    # first query
    query1 = "SELECT name,phone FROM customer WHERE name='shanky'"
    assert mocked_cursor.execute.call_args_list[0][0][0] == query1
    # second query
    query2 = "SELECT name,id FROM product WHERE name='soap'"
    assert mocked_cursor.execute.call_args_list[1][0][0] == query2

    # my_function returns the list of product rows
    assert ret == [{'name': 'nirma', 'id': 12313}]
In addition to this if there are dynamic parameters in the query, that can be asserted too by the following method.
# The second positional argument of the first execute() call is the parameter tuple.
assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][1] == (parameter_name,)
So when the first query is executed as cursor.execute(query, (parameter_name,)), the query text can be obtained and asserted at call_args_list[0][0][0], and the parameter tuple (parameter_name,) at call_args_list[0][0][1]. Similarly, by incrementing the first index, all the other queries and their parameters can be obtained and asserted.
Try side_effect argument of mocker.patch:
from unittest.mock import MagicMock
from pytest_mock import mocker
import psycopg2
def test_my_function(mocker):
    """Feed two consecutive query results to one cursor through ``side_effect``.

    NOTE(review): assumes ``app.my_function`` obtains its cursor via
    ``psycopg2.connect(...).cursor()`` -- confirm against the elided
    connection code.
    """
    from unittest.mock import PropertyMock

    from my_module import app

    mocked_connect = mocker.patch('psycopg2.connect')
    mocked_cursor = mocked_connect.return_value.cursor.return_value

    # Both queries run on the same cursor, so the per-query values are
    # sequenced on that cursor.  (Giving ``connect`` a side_effect list and
    # calling it from the test would use up the entries before the function
    # under test ever connects.)
    mocked_cursor.fetchall.side_effect = [
        [('shanky', '347539593')],  # first query
        [('nirma', 12313)],         # second query
    ]
    type(mocked_cursor).description = PropertyMock(side_effect=[
        [['name'], ['phone']],
        [['name'], ['id']],
    ])

    ret = app.my_function()

    executed = [call[0][0] for call in mocked_cursor.execute.call_args_list]
    assert executed == [
        "SELECT name,phone FROM customer WHERE name='shanky'",
        "SELECT name,id FROM product WHERE name='soap'",
    ]
    assert ret == [{'name': 'nirma', 'id': 12313}]
As per the docs, side_effect allows you to change returned value each time the patched object is called:
If you pass in an iterable, it is used to retrieve an iterator which must yield a value on every call. This value can either be an exception instance to be raised, or a value to be returned from the call to the mock
As I have mentioned in an earlier comment, the best way to make unit testing portable is to develop a complete Mock of your database's behavior.
I've done it for MySQL but it's pretty much the same for all databases.
First of all, I like using wrapper classes over the packages I'm using, it helps quickly change the database at one place instead of changing it everywhere in the code.
Here's a sample of what I use as a wrapper:
Now, you would need to Mock this MySQL class:
# _database.py
# -----------------------------------------------------------------------------
# Database Metaclass
# -----------------------------------------------------------------------------
"""Metaclass for Database implementation.
"""
# -----------------------------------------------------------------------------
import logging

logger = logging.getLogger(__name__)


class Database:
    """Base class wrapping a connection object.

    Attributes:
        connection: Object returned by ``connect_func(**kwargs)``.
    """

    def __init__(self, connect_func, **kwargs):
        """Open the connection.

        Args:
            connect_func (callable): Called as ``connect_func(**kwargs)``.
            **kwargs: Passed through unchanged to ``connect_func``.
        """
        self.connection = connect_func(**kwargs)

    def execute(self, statement, fetchall=True):
        """Execute a statement.

        Execute the statement passed as argument on a new cursor.

        Args:
            statement (str): SQL Query or Command to execute.
            fetchall (bool): Return ``cursor.fetchall()`` when true,
                ``cursor.fetchone()`` otherwise.

        Returns:
            Whatever the cursor's ``fetchall()`` / ``fetchone()`` returns.
        """
        cursor = self.connection.cursor()
        # lazy %-formatting: the message is only built if DEBUG is enabled
        logger.debug("Executing: %s", statement)
        cursor.execute(statement)
        if fetchall:
            return cursor.fetchall()
        return cursor.fetchone()

    def __del__(self):
        """Close connection on object deletion."""
        # __init__ may have failed before ``connection`` was assigned, and a
        # stand-in connection may not implement close(); neither case should
        # raise from a finalizer.
        connection = getattr(self, "connection", None)
        close = getattr(connection, "close", None)
        if close is not None:
            close()
And the mysql module:
# mysql.py
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# MySQL Database Class
# -----------------------------------------------------------------------------
"""Class for MySQL Database connection."""
# -----------------------------------------------------------------------------
import logging
import mysql.connector
from . import _database
logger = logging.getLogger(__name__)
class MySQL(_database.Database):
    """MySQL flavour of the database wrapper.

    Attributes:
        connection (obj): Object returned from mysql.connector.connect
    """

    def __init__(self, autocommit=True, **kwargs):
        """Connect through ``mysql.connector.connect`` and set ``autocommit``."""
        connector = mysql.connector.connect
        super().__init__(connect_func=connector, **kwargs)
        self.connection.autocommit = autocommit
Instantiate like: db = MySQL(user='...', password='...', ...)
Here's the data file:
# database_mock_data.json
{
"customer": {
"name": [
"shanky",
"nirma"
],
"phone": [
123123123,
232342342
]
},
"product": {
"name": [
"shanky",
"nirma"
],
"id": [
1,
2
]
}
}
The mocks.py
# mocks.py
import json
import re
from . import mysql
_MOCK_DATA_PATH = 'database_mock_data.json'
class MockDatabase(mysql.MySQL):
    """MySQL wrapper whose connection is a ``MockConnection``."""

    def __init__(self, **kwargs):
        # The parent __init__ is deliberately not called, so no real
        # connection is opened; kwargs are accepted and ignored.
        self.connection = MockConnection()
class MockConnection:
    """
    Mock the connection object by returning a mock cursor.
    """

    @staticmethod
    def cursor():
        """Return a new ``MockCursor``."""
        return MockCursor()

    def close(self):
        """Do nothing; lets code that closes its connection work with the mock."""
class MockCursor:
    """
    The Mocked Cursor

    A call to execute() reads the json data file and, for SELECT statements,
    stores the selected rows and sets the description object (containing the
    column names usually).
    You could implement an update function like `_json_sql_update()`
    """

    def __init__(self):
        self.description = []
        self.__result = None

    def execute(self, statement):
        """Run *statement* against the mock data; only SELECT is handled."""
        data = _read_json_file(_MOCK_DATA_PATH)
        # Forget the previous statement's outcome so a non-SELECT does not
        # leave stale rows behind.
        self.__result, self.description = None, []
        # lstrip() so queries written with leading whitespace/newlines count
        if statement.lstrip().upper().startswith('SELECT'):
            self.__result, self.description = _json_sql_select(data, statement)

    def fetchall(self):
        """Return every stored row (``None`` if nothing was selected)."""
        return self.__result

    def fetchone(self):
        """Return the first stored row, or ``None`` when there is none."""
        if not self.__result:
            return None
        return self.__result[0]
def _json_sql_select(data, query):
"""
Takes a dictionary and returns the values from a sql query.
NOTE: It does not work with other where clauses than '='.
Also, note that a where statement is expected.
:param (dict) data: Dictionary with the following structure:
{
'tablename': {
'column_name_1': ['value1', 'value2],
'column_name_2': ['value1', 'value2],
...
},
...
}
:param (str) query: An update sql query as:
`update TABLENAME set column_name_1='value'
where column_name_2='value1'`
:return: List of list of values and header description
"""
try:
match = (re.search("select(.*)from(.*)where(.*)[;]?", query,
re.IGNORECASE | re.DOTALL).groups())
except AttributeError:
print("Select Query pattern mismatch... {}".format(query))
raise
# Parse values from the select query
tablename = match[1].strip().upper()
columns = [col.strip().upper() for col in match[0].split(",")]
if columns == ['*']:
columns = data[tablename].keys()
where = [cmd.upper().strip().replace(' ', '')
for cmd in match[2].split('and')]
# Select values
selected_values = []
nb_lines = len(list(data[tablename].values())[0])
for i in range(nb_lines):
is_match = True
for condition in where:
key_condition, value_condition = (_clean_string(condition)
.split('='))
if data[tablename][key_condition][i].upper() != value_condition:
# Set flag to yes
is_match = False
if is_match:
sub_list = []
for column in columns:
sub_list.append(data[tablename][column][i])
selected_values.append(sub_list)
# Usual descriptor has nested list
description = zip(columns, ['...'] * len(columns))
return selected_values, description
def _read_json_file(file_path):
with open(file_path, 'r') as f_in:
data = json.load(f_in)
return data
And then you have your test in a test_module_yourfunction.py
import pytest
def my_function(db, query):
    # Code goes here
    ...


@pytest.fixture
def db_connection():
    """Provide a database wrapper backed by the mock data."""
    return MockDatabase()


# NOTE(review): the first expected value looks like a product row although
# the query selects from customer -- replace with the real expected result.
@pytest.mark.parametrize(
    ("query", "expected"),
    [
        ("SELECT name,phone FROM customer WHERE name='shanky'", {'name' : 'nirma', 'id' : 12313}),
        ("<second query goes here>", "<second result goes here>")
    ]
)
def test_my_function(db_connection, query, expected):
    assert my_function(db_connection, query) == expected
Now I'm sorry if you can't copy/paste this code and make it work, but you get the feeling :) just trying to help
Related
#file utils.py
def update_configuration(configuration, mysql_client):
    """Fill ``configuration["formula_mapping"]`` from the database.

    Runs the query through ``mysql_client.execute_query`` and stores each
    returned ``(display_name, formula_name)`` row as
    ``formula_name -> display_name``.  ``configuration["mapping"]`` is reset
    to an empty dict, as before.

    Args:
        configuration (dict): Updated in place.
        mysql_client: Object whose ``execute_query(query)`` returns a cursor.
    """
    query = "SELECT * from some database"
    cursor = mysql_client.execute_query(query)
    try:
        function_mapping = cursor.fetchall()
    finally:
        # release the cursor even if fetching fails
        cursor.close()
    configuration["mapping"] = {}
    # The loop writes to "formula_mapping", so make sure that key exists
    # instead of failing with KeyError on the first row.
    formula_mapping = configuration.setdefault("formula_mapping", {})
    for display_name, formula_name in function_mapping:
        formula_mapping[formula_name] = display_name
#main.py
def main_func(configuration):
    """Open a MySQL connection and let utils update *configuration* from it."""
    client = get_mysql_connection()
    ut.update_configuration(configuration, client)  # ut means utils.py
#test.py
# NOTE(review): ``cursor`` is a local variable inside update_configuration,
# not an attribute of src.utils -- this first patch target is what raises the
# ModuleNotFoundError quoted below.
@mock.patch("src.utils.cursor.fetchall",return_value = [1,2,3,4])
@mock.patch("src.main.get_sql_connection", return_value = mock.Mock())
def test_initiate_calc(self, dummy1):
    # perform integration testing on "main_func"
    ...
The project structure is as shown below
--project
--src
--tests
--test_main.py
--main.py
--utils.py
When I try to mock the cursor.fetchall() to return some value I get an error saying "ModuleNotFoundError: No module named 'src.utils.cursor'; 'src.utils' is not a package"
Need help in finding a way to get the function_mapping = cursor.fetchall() value some return value
You can't patch the local variable cursor, nor do you need to. Configure the get_mysql_connection mock properly instead.
# main.py calls get_mysql_connection(), so that is the name to patch.
@mock.patch("src.main.get_mysql_connection")
def test_initiate_calc(self, mock_conn):
    # mock_conn() is the client; client.execute_query(...) is the cursor.
    # Rows are (display_name, formula_name) pairs because
    # update_configuration unpacks each row into two names.
    mock_conn.return_value.execute_query.return_value.fetchall.return_value = [
        ("Display name", "formula_name"),
    ]
    # perform integration testing on "main_func"
You can also try cursor.fetchall.side_effect = "some-value"
Something like this:
# NOTE(review): patch() needs a dotted target such as
# 'package.module.connector'; 'connector' is a placeholder.
@patch('connector')
def test_create_table(self, connector):
    connection = MagicMock()
    cursor = MagicMock()
    connector.connect.return_value = connection
    connection.cursor.return_value = cursor
    # return_value, not side_effect: an iterable side_effect is consumed one
    # item per call, so a string would be handed out character by character.
    cursor.fetchall.return_value = "List of tuple"

    # ... call the code under test here, before the assertions ...

    connector.connect.assert_called_with("creds")
    cursor.execute.assert_called_with("statements - you - are - trying - to - execute")
For multiple statements use assert_has_calls.
Disclaimer: Yes I am well aware this is a mad attempt.
Use case:
I am reading from a config file to run a test collection where each such collection comprises of set of test cases with corresponding results and a fixed setup.
Flow (for each test case):
Setup: wipe and setup database with specific test case dataset (glorified SQL file)
load expected test case results from csv
execute collections query/report
compare results.
Sounds good, except the people writing the test cases are more from a tech admin perspective, so the goal is to enable this without writing any python code.
code
Assume these functions exist.
# test_queries.py
# Placeholder signatures for helpers that are assumed to exist.  The
# expression after each colon only sketches the shape of the value the real
# helper returns; it is not an implementation.
# NOTE(review): execute() is declared with one parameter here but is called
# as execute(query, case_name) in the attempts below -- confirm the signature.
def gather_collections(): (collection, query, config)
def gather_cases(collection): (test_case)
def load_collection_stubs(collection): None
def load_case_dataset(test_case): None
def read_case_result_csv(test_case): [csv_result]
def execute(query): [query_result]
class TestQueries(unittest.TestCase):
    """Container that receives one generated test method per collection."""

    def setup_method(self, method):
        """Load the stubs of the collection named after the running test."""
        test_name = self._item.name
        load_collection_stubs(test_name.replace('test_', ''))
# conftest.py
import pytest
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_protocol(item, nextitem):
    """Expose the running item on its test class as ``_item``."""
    # Tests defined as plain functions have no class to attach the item to.
    cls = getattr(item, "cls", None)
    if cls is not None:
        cls._item = item
    yield
Example Data
Collection stubs / data (setting up of environment)
-- stubs/test_setup_log.sql
DROP DATABASE IF EXISTS `test`;
CREATE DATABASE `test`;
USE test;
CREATE TABLE log (`id` int(9) NOT NULL AUTO_INCREMENT, `timestamp` datetime NOT NULL DEFAULT NOW(), `username` varchar(100) NOT NULL, `message` varchar(500));
Query to test
-- queries/count.sql
SELECT count(*) as `log_count` from test.log where username = 'unicorn';
Test case 1 input data
-- test_case_1.sql
INSERT INTO log (`id`, `timestamp`, `username`, `message`)
VALUES
(1,'2020-12-18T11:23.01Z', 'unicorn', 'user logged in'),
(2,'2020-12-18T11:23.02Z', 'halsey', 'user logged off'),
(3,'2020-12-18T11:23.04Z', 'unicorn', 'user navigated to home')
Test case 1 expected result
test_case_1.csv
log_count
2
Attempt 1
for collection, query, config in gather_collections():
    test_method_name = 'test_{}'.format(collection)
    LOGGER.debug("collections.{}.test - {}".format(collection, config))
    cases = gather_cases(collection)
    LOGGER.debug("collections.{}.cases - {}".format(collection, cases))
    setattr(
        TestQueries,
        test_method_name,
        pytest.mark.parametrize(
            'case_name',
            cases,
            ids=cases
        )(
            # ``query=query`` freezes this iteration's query; a plain closure
            # would look the name up when the test runs and always see the
            # last collection's query.
            lambda self, case_name, query=query: (
                load_case_dataset(case_name),
                self.assertEqual(execute(query, case_name), read_case_result_csv(case_name))
            )
        )
    )
Attempt 2
for collection, query, config in gather_collections():
    test_method_name = 'test_{}'.format(collection)
    LOGGER.debug("collections.{}.test - {}".format(collection, config))
    setattr(
        TestQueries,
        test_method_name,
        # ``query=query`` freezes this iteration's query; a plain closure
        # would always see the last collection's query at run time.
        lambda self, case_name, query=query: (
            load_case_dataset(case_name),
            self.assertEqual(execute(query, case_name), read_case_result_csv(case_name))
        )
    )
def pytest_generate_tests(metafunc):
    """Parametrize ``case_name`` with the cases of the test's collection.

    The collection name is the test function's name without ``test_``.
    """
    # Leave tests that do not take ``case_name`` alone; parametrizing an
    # argument the test does not use is an error.
    if 'case_name' not in metafunc.fixturenames:
        return
    collection = metafunc.function.__name__.replace('test_', '')
    # FIXME logs and id setting not working
    cases = gather_cases(collection)
    LOGGER.info("collections.{}.pytest.cases - {}".format(collection, cases))
    metafunc.parametrize(
        'case_name',
        cases,
        ids=cases
    )
So I figured it out, but it's not the most elegant solution.
Essentially you use one function and then use some of pytests hooks to change the function names for reporting.
There are numerous issues, e.g. if you don't use pytest.param to pass the parameters to parametrize then you do not have the required information available.
Also the method passed to setup_method is not aware of the actual iteration being run when its called, so I had to hack that in with the iter counter.
# test_queries.py
def gather_tests():
    """Yield one tuple per test case found on disk.

    Yields:
        ``(test_case_name, query_path, collection_name, collection_config)``
        for every ``*.csv`` file in each collection's cases directory.
    """
    for test_collection_name, collection_config in TESTS.items():
        LOGGER.debug("collections.{}.gather - {}".format(test_collection_name, collection_config))
        query = path.join(SRC_DIR, collection_config['query'])
        cases_dir = collection_config['cases']
        result_sets = path.join(TEST_DIR, cases_dir, '*.csv')
        # glob returns files in arbitrary order; sort so the generated
        # parameter list is the same on every run.
        for case_result_csv in sorted(glob.glob(result_sets)):
            test_case_name = path.splitext(path.basename(case_result_csv))[0]
            yield test_case_name, query, test_collection_name, collection_config
class TestQueries():
    """Runs every report query against every test case of its collection."""

    # Index of the parametrized case the next setup_method call belongs to.
    # It lives on the class because pytest creates a new instance per test.
    iter = 0

    def setup_method(self, method):
        """Load the stubs of the collection the upcoming test case belongs to.

        The collection is read from the ``parametrize`` mark of
        ``test_scripts_reports``, using the class-level counter to pick the
        current parameter set.
        """
        method_name = method.__name__  # or self._item.originalname
        if method_name != 'test_scripts_reports':
            return
        _mark = next((m for m in method.pytestmark if m.name == 'parametrize' and 'collection_name' in m.args[0]), None)
        if not _mark:
            raise Exception('test {} missing collection_name parametrization'.format(method_name))  # nothing to do here
        _args = _mark.args[0]
        _params = _mark.args[1]
        LOGGER.debug('setup_method: _params - {}'.format(_params))
        if not _params:
            raise Exception('test {} missing pytest.params'.format(method_name))  # nothing to do here
        # Advance the counter on the class: ``self.iter += 1`` would only set
        # an attribute on this throw-away instance and every test would read
        # index 0 again.
        cls = type(self)
        _currparams = _params[cls.iter]
        cls.iter += 1
        _argpos = [arg.strip() for arg in _args.split(',')].index('collection_name')
        collection = _currparams.values[_argpos]
        LOGGER.debug('collections.{}.setup_method[{}] - {}'.format(collection, cls.iter, _currparams))
        load_collection_stubs(collection)

    @pytest.mark.parametrize(
        'case_name, collection_query, collection_name, collection_config',
        [pytest.param(*c, id='{}:{}'.format(c[2], c[0])) for c in gather_tests()]
    )
    def test_scripts_reports(self, case_name, collection_query, collection_name, collection_config):
        """Load the case dataset, run the report query, compare with the csv."""
        if not path.isfile(collection_query):
            pytest.skip("report query does not exist: {}".format(collection_query))
        LOGGER.debug("test_scripts_reports.{}.{} - ".format(collection_name, case_name))
        load_case_dataset(case_name)
        assert execute(collection_query, case_name) == read_case_result_csv(case_name)
Then to make the test ids more human you can do this
# conftest.py
def pytest_collection_modifyitems(items):
    """Shorten the node ids of ``test_scripts_reports`` cases for reporting.

    ``file.py::Class::test[id]`` becomes ``file.py[id]``.
    """
    # https://stackoverflow.com/questions/61317809/pytest-dynamically-generating-test-name-during-runtime
    for item in items:
        # not every collected item type carries ``originalname``
        if getattr(item, 'originalname', None) == 'test_scripts_reports':
            item._nodeid = re.sub(r'::\w+::\w+\[', '[', item.nodeid)
the result with the following files:
stubs/
00-wipe-db.sql
setup-db.sql
queries/
report1.sql
collection/
report1/
case1.sql
case1.csv
case2.sql
case2.csv
# results (with setup_method firing before each test and loading the appropriate stubs as per configuration)
FAILED test_queries.py[report1:case1]
FAILED test_queries.py[report1:case2]
My general idea is, that I have a or more test functions which can call another function.
So:
test_function does some things, then calls another function
Called other function will do something, then return a result and
Wait for the test_function to finish, then do everything that is to be done after the "return"
Of course, that does not work with a literal return, so my question is, if there is a possibility I do not see.
A practical example:
Test function in pytest - test_config is a fixture:
def test_load_into_database(test_config):
    """Insert one row through the accessor and check it can be read back."""
    logger.info('Testing load_into_database')
    data = {'filename': 'file_0'}

    accessor = list(database_accessor(test_config))[0]
    session = accessor['session']
    table = accessor['table']

    # before the insert the first row must not be the new file
    access_data = session.query(table).all()
    assert access_data[0].filename != data['filename']

    session.add(table(filename='file_0'))
    session.commit()

    # after the insert the first row is the new file
    access_data = session.query(table).all()
    assert access_data[0].filename == data['filename']
This test function calls another one:
def database_accessor(cfg):
setup = 'mysql+pymysql://{}:{}#{}/{}'.format(
cfg['database_user'],
cfg['database_passwd'],
cfg['database_host'],
cfg['database_db']
)
Base = automap_base()
engine = create_engine(setup, echo=False)
Base.prepare(engine, reflect=True)
table = Base.classes.get(cfg['database_table'])
session = Session(engine)
->return and wait<- {
'session': session,
'table': table
}
session.close()
with engine.connect() as con:
con.execution_options(autocommit=True).execute("TRUNCATE TABLE {}".format(cfg['database_table']))
What I want is, that the database_accessor returns the dictionary, then waits for the underlying function (in this case the test_function) to finish, then resumes.
That way, it is possible to use a variable number of test functions with the same database_accessor without executing all the different test functions in the database_accessor.
Callbacks
I know that there is a way, with a callback function, but that doubles my functions, which I do not want. E.g.
def test_load_into_database():
    database_accessor(load_into_database_1, var1, var2)


def database_accessor(function, *args):
    # do stuff
    function(args, stuff)
    # do other stuff


def load_into_database_1(args, stuff):
    # do something
    ...
I think you'd want to make database_accessor a context manager:
# The block below runs between the accessor's setup and its clean-up.
with database_accessor(test_config) as accessor:
    session = accessor['session']
    table = accessor['table']

    access_data = session.query(table).all()
    assert access_data[0].filename != data['filename']

    session.add(table(filename='file_0'))
    session.commit()

    access_data = session.query(table).all()
    assert access_data[0].filename == data['filename']
The implementation would looks something like this:
from contextlib import contextmanager
@contextmanager
def database_accessor(cfg):
    """Yield a session and the reflected table, then clean up.

    Args:
        cfg (dict): Needs ``database_user``, ``database_passwd``,
            ``database_host``, ``database_db`` and ``database_table``.

    Yields:
        dict: ``{'session': session, 'table': table}``.

    After the ``with`` body finishes (normally or with an exception) the
    session is closed and the table is truncated.
    """
    # user:password@host/database
    setup = 'mysql+pymysql://{}:{}@{}/{}'.format(
        cfg['database_user'],
        cfg['database_passwd'],
        cfg['database_host'],
        cfg['database_db']
    )
    Base = automap_base()
    engine = create_engine(setup, echo=False)
    Base.prepare(engine, reflect=True)
    table = Base.classes.get(cfg['database_table'])
    session = Session(engine)
    try:
        yield {
            'session': session,
            'table': table
        }
    finally:
        session.close()
        with engine.connect() as con:
            con.execution_options(autocommit=True).execute("TRUNCATE TABLE {}".format(cfg['database_table']))
In this case, I would modify it slightly to yield a tuple instead of a dict:
yield session, table
and then, when using it you could do:
with database_accessor(test_config) as (session, table):
Brand new to this library
Here is the call stack of my mocked object
[call(),
call('test'),
call().instance('test'),
call().instance().database('test'),
call().instance().database().snapshot(),
call().instance().database().snapshot().__enter__(),
call().instance().database().snapshot().__enter__().execute_sql('SELECT * FROM users'),
call().instance().database().snapshot().__exit__(None, None, None),
call().instance().database().snapshot().__enter__().execute_sql().__iter__()]
Here is the code I have used
#mock.patch('testmodule.Client')
def test_read_with_query(self, mock_client):
mock = mock_client()
pipeline = TestPipeline()
records = pipeline | ReadFromSpanner(TEST_PROJECT_ID, TEST_INSTANCE_ID, self.database_id).with_query('SELECT * FROM users')
pipeline.run()
print mock_client.mock_calls
exit()
I want to mock this whole stack that eventually it gives me some fake data which I will provide as a return value.
The code being tested is
spanner_client = Client(self.project_id)
instance = spanner_client.instance(self.instance_id)
database = instance.database(self.database_id)
with database.snapshot() as snapshot:
results = snapshot.execute_sql(self.query)
So my requirements is that the results variable should contain the data I will provide.
How can I provide a return value to such a nested calls
Thanks
Create separate MagicMock instances for the instance, database and snapshot objects in the code under test. Use return_value to configure the return values of each method. Here is an example. I simplified the method under test to just be a free standing function called mut.
# test_module.py : the module under test
class Client:
    """Stand-in for the real client class; tests replace it with a mock."""


def mut(project_id, instance_id, database_id, query):
    """Run *query* in a snapshot of the given database and return the results.

    Walks client -> instance -> database -> snapshot and returns whatever
    ``snapshot.execute_sql(query)`` returns.
    """
    spanner_client = Client(project_id)
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)
    with database.snapshot() as snapshot:
        results = snapshot.execute_sql(query)
    return results
# test code (pytest)
from unittest.mock import MagicMock
from unittest import mock
from test_module import mut
@mock.patch('test_module.Client')
def test_read_with_query(mock_client_class):
    """Drive ``mut`` through a chain of mocks and check every hop."""
    mock_client = MagicMock()
    mock_instance = MagicMock()
    mock_database = MagicMock()
    mock_snapshot = MagicMock()
    expected = 'fake query results'

    # one mock per object in the chain, wired together with return_value
    mock_client_class.return_value = mock_client
    mock_client.instance.return_value = mock_instance
    mock_instance.database.return_value = mock_database
    mock_database.snapshot.return_value = mock_snapshot
    mock_snapshot.execute_sql.return_value = expected
    # ``with database.snapshot() as snapshot`` binds __enter__'s return value
    mock_snapshot.__enter__.return_value = mock_snapshot

    observed = mut(29, 42, 77, 'select *')

    mock_client_class.assert_called_once_with(29)
    mock_client.instance.assert_called_once_with(42)
    mock_instance.database.assert_called_once_with(77)
    mock_database.snapshot.assert_called_once_with()
    mock_snapshot.__enter__.assert_called_once_with()
    mock_snapshot.execute_sql.assert_called_once_with('select *')
    assert observed == expected
This test is kind of portly. Consider breaking it apart by using a fixture and a before function that sets up the mocks.
Either set the value directly to your Mock instance (those enters and exit should have not seen) with:
mock.return_value.instance.return_value.database.return_value.snapshot.return_value.execute_sql.return_value = MY_MOCKED_DATA
or patch and set return_value to target method, something like:
#mock.patch('database_engine.execute_sql', return_value=MY_MOCKED_DATA)
Is there an elegant way to do an INSERT ... ON DUPLICATE KEY UPDATE in SQLAlchemy? I mean something with a syntax similar to inserter.insert().execute(list_of_dictionaries) ?
ON DUPLICATE KEY UPDATE post version-1.2 for MySQL
This functionality is now built into SQLAlchemy for MySQL only. somada141's answer below has the best solution:
https://stackoverflow.com/a/48373874/319066
ON DUPLICATE KEY UPDATE in the SQL statement
If you want the generated SQL to actually include ON DUPLICATE KEY UPDATE, the simplest way involves using a @compiles decorator.
The code (linked from a good thread on the subject on reddit) for an example can be found on github:
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert
@compiles(Insert)
def append_string(insert, compiler, **kw):
    """Compile an INSERT and append the text given as ``append_string``.

    When the insert construct carries an ``append_string`` keyword, its value
    is added after the generated SQL, separated by a space.
    """
    s = compiler.visit_insert(insert, **kw)
    if 'append_string' in insert.kwargs:
        return s + " " + insert.kwargs['append_string']
    return s
my_connection.execute(my_table.insert(append_string = 'ON DUPLICATE KEY UPDATE foo=foo'), my_values)
But note that in this approach, you have to manually create the append_string. You could probably change the append_string function so that it automatically changes the insert string into an insert with 'ON DUPLICATE KEY UPDATE' string, but I'm not going to do that here due to laziness.
ON DUPLICATE KEY UPDATE functionality within the ORM
SQLAlchemy does not provide an interface to ON DUPLICATE KEY UPDATE or MERGE or any other similar functionality in its ORM layer. Nevertheless, it has the session.merge() function that can replicate the functionality only if the key in question is a primary key.
session.merge(ModelObject) first checks if a row with the same primary key value exists by sending a SELECT query (or by looking it up locally). If it does, it sets a flag somewhere indicating that ModelObject is in the database already, and that SQLAlchemy should use an UPDATE query. Note that merge is quite a bit more complicated than this, but it replicates the functionality well with primary keys.
But what if you want ON DUPLICATE KEY UPDATE functionality with a non-primary key (for example, another unique key)? Unfortunately, SQLAlchemy doesn't have any such function. Instead, you have to create something that resembles Django's get_or_create(). Another StackOverflow answer covers it, and I'll just paste a modified, working version of it here for convenience.
def get_or_create(session, model, defaults=None, **kwargs):
    """Return the first *model* row matching ``kwargs``, or a new instance.

    Args:
        session: Session used for the lookup.
        model: Mapped class to query and, if needed, instantiate.
        defaults (dict): Extra constructor values used only when creating.
        **kwargs: Filter criteria; non-clause values are also passed to the
            constructor of a new instance.

    Returns:
        The existing instance, or a newly constructed one.  A new instance is
        not added to the session here; the caller has to do that.
    """
    instance = session.query(model).filter_by(**kwargs).first()
    if instance:
        return instance
    # items() instead of the Python 2-only iteritems()
    params = {k: v for k, v in kwargs.items() if not isinstance(v, ClauseElement)}
    if defaults:
        params.update(defaults)
    return model(**params)
I should mention that ever since the v1.2 release, the SQLAlchemy 'core' has a solution to the above with that's built in and can be seen under here (copied snippet below):
from sqlalchemy.dialects.mysql import insert
insert_stmt = insert(my_table).values(
id='some_existing_id',
data='inserted value')
on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
data=insert_stmt.inserted.data,
status='U'
)
conn.execute(on_duplicate_key_stmt)
Based on phsource's answer, and for the specific use-case of using MySQL and completely overriding the data for the same key without performing a DELETE statement, one can use the following @compiles decorated insert expression:
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert
@compiles(Insert)
def append_string(insert, compiler, **kw):
    """Compile an INSERT, optionally adding ON DUPLICATE KEY UPDATE.

    When the insert construct carries a truthy ``on_duplicate_key_update``
    keyword, every column listed in the generated statement gets a
    ``col=VALUES(col)`` assignment.
    """
    s = compiler.visit_insert(insert, **kw)
    if insert.kwargs.get('on_duplicate_key_update'):
        # column list = text inside the first pair of parentheses
        fields = s[s.find("(") + 1:s.find(")")].replace(" ", "").split(",")
        generated_directive = ["{0}=VALUES({0})".format(field) for field in fields]
        return s + " ON DUPLICATE KEY UPDATE " + ",".join(generated_directive)
    return s
It depends on what you want. If you want to replace existing rows, pass OR REPLACE in prefixes instead of OR IGNORE:
def bulk_insert(self, objects, table, commit_every=100):
    """Insert rows one by one with ``OR IGNORE``, committing in batches.

    Args:
        objects: List of dictionaries, e.g. ``[{col1: val1, col2: val2}]``.
        table: Your table class (its ``__table__`` builds the insert).
        commit_every (int): Commit after this many rows; a final commit
            always follows the loop.

    Raises:
        ValueError: if ``commit_every`` is smaller than 1.
    """
    if commit_every < 1:
        raise ValueError("commit_every must be at least 1")
    # start=1 so a commit happens after each full batch, not after row one
    for counter, row in enumerate(objects, start=1):
        inserter = table.__table__.insert(prefixes=['OR IGNORE'], values=row)
        try:
            self.db.execute(inserter)
        except Exception as E:
            # best effort: report the failing row and carry on
            print(E)
        if counter % commit_every == 0:
            self.db.commit()
    self.db.commit()
Here commit interval can be changed to speed up or speed down
My way
import typing
from datetime import datetime
from sqlalchemy.dialects import mysql
class MyRepository:
    """Repository with a MySQL ``INSERT ... ON DUPLICATE KEY UPDATE`` helper."""

    def model(self):
        """Return the mapped class this repository works on."""
        return MySqlAlchemyModel

    def upsert(self, data: typing.List[typing.Dict]):
        """Insert the rows in *data*, updating rows whose key already exists.

        All dictionaries must have the same keys; the keys of the first one
        decide which columns are updated on a duplicate.  When the model has
        a ``created_at`` attribute it is filled in for inserted rows.
        """
        if not data:
            return
        model = self.model()
        # work on copies so the caller's dictionaries are left untouched
        rows = [dict(item) for item in data]
        if hasattr(model, 'created_at'):
            now = datetime.now()
            for row in rows:
                row['created_at'] = now
        stmt = mysql.insert(getattr(model, '__table__')).values(rows)
        # Update only the caller's columns: ``created_at`` is left out so an
        # update of an existing row keeps its original creation time.
        dup = {
            k: getattr(stmt.inserted, k)
            for k in data[0]
            if k != 'created_at'
        }
        stmt = stmt.on_duplicate_key_update(**dup)
        self.db.session.execute(stmt)
        self.db.session.commit()
Usage:
# every row uses the same column names; only the values differ
myrepo.upsert([
    {
        "field1": "value11",
        "field2": "value21",
        "field3": "value31",
    },
    {
        "field1": "value12",
        "field2": "value22",
        "field3": "value32",
    },
])
The other answers have this covered but figured I'd reference another good example for mysql I found in this gist. This also includes the use of LAST_INSERT_ID, which may be useful depending on your innodb auto increment settings and whether your table has a unique key. Lifting the code here for easy reference but please give the author a star if you find it useful.
from app import db
from sqlalchemy import func
from sqlalchemy.dialects.mysql import insert
def upsert(model, insert_dict):
    """model can be a db.Model or a table(), insert_dict should contain a primary or unique key.

    Returns:
        ``lastrowid`` of the executed statement.
    """
    inserted = insert(model).values(**insert_dict)
    update_columns = {k: inserted.inserted[k] for k in insert_dict}
    # Assigned after the comprehension so an "id" entry in insert_dict does
    # not collide with it as a duplicate keyword argument.
    update_columns['id'] = func.LAST_INSERT_ID(model.id)
    upserted = inserted.on_duplicate_key_update(**update_columns)
    res = db.engine.execute(upserted)
    return res.lastrowid
ORM
Use an upsert function based on on_duplicate_key_update:
class Model():
    """Mixin that gives a mapped class ``save`` and ``upsert`` helpers.

    The keyword arguments given to the constructor are kept in
    ``__input_data__`` and are the values ``upsert`` writes.
    """

    __input_data__ = dict()

    def __init__(self, **kwargs) -> None:
        """Remember the constructor values and open a session on ``engine``."""
        self.__input_data__ = kwargs
        self.session = Session(engine)

    def save(self):
        """Add this object to the session and commit."""
        self.session.add(self)
        self.session.commit()

    def upsert(self, *, ingore_keys=None):
        """Insert this row, or update it when a duplicate key exists.

        Only constructor values whose key is a column of ``__table__`` are
        written. On a duplicate key every written column is updated, except
        ``id`` and the columns named in ``ingore_keys``.

        Args:
            ingore_keys: Column name, or list/tuple/set of column names,
                that must not be overwritten on a duplicate key. ``None``
                (the default) ignores only ``id``.
        """
        column_keys = self.__table__.columns.keys()
        insert_values = {
            key: value
            for key, value in self.__input_data__.items()
            if key in column_keys
        }
        insert_stmt = insert(self.__table__).values(**insert_values)

        # ``id`` is never overwritten; caller-supplied keys extend the list.
        all_ignore_keys = ['id']
        if ingore_keys is None:
            pass
        elif isinstance(ingore_keys, (list, tuple, set, frozenset)):
            all_ignore_keys.extend(ingore_keys)
        else:
            all_ignore_keys.append(ingore_keys)

        update_columns = {
            key: insert_stmt.inserted[key]
            for key in insert_values
            if key not in all_ignore_keys
        }
        on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
            **update_columns
        )
        self.session.execute(on_duplicate_key_stmt)
        self.session.commit()
class ManagerAssoc(ORM_Base, Model):
    """Mapped class combining ``ORM_Base`` with the ``Model`` helpers."""

    def __init__(self, **kwargs):
        """Assign a generated id and initialise both base classes.

        ``ORM_Base`` receives only the keyword arguments that name a column
        of ``__table__``; ``Model`` receives all of them plus the new id.
        """
        self.id = idWorker.get_id()
        known_columns = self.__table__.columns.keys()
        column_values = {
            name: value
            for name, value in kwargs.items()
            if name in known_columns
        }
        ORM_Base.__init__(self, **column_values)
        Model.__init__(self, **kwargs, id=self.id)
....
# You can call it as follows (presumably on instances of Model subclasses
# such as ManagerAssoc - confirm against the elided code above):
# with no arguments, only the `id` column is left untouched on a duplicate key.
manager_assoc.upsert()
# Additionally keep the existing `manager_id` value on a duplicate key.
manager.upsert(ingore_keys = ['manager_id'])
Got a simpler solution:
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert
@compiles(Insert)
def replace_string(insert, compiler, **kw):
    """Compile an ``Insert`` construct as ``REPLACE INTO`` instead of ``INSERT INTO``.

    Registered through ``@compiles(Insert)``, so it is used whenever an
    ``Insert`` construct is compiled.

    Args:
        insert: The ``Insert`` construct being compiled.
        compiler: The SQL compiler; its ``visit_insert`` produces the
            standard statement text.
        **kw: Extra compiler keyword arguments, forwarded unchanged.

    Returns:
        The compiled SQL string with the leading ``INSERT INTO`` replaced
        by ``REPLACE INTO``.
    """
    s = compiler.visit_insert(insert, **kw)
    # Only rewrite the first occurrence (the statement keyword) so that the
    # same text appearing later in the statement is left alone.
    return s.replace("INSERT INTO", "REPLACE INTO", 1)
# Execute an insert on my_table with my_values as parameters.
# NOTE(review): replace_string() above never reads the `replace_string`
# keyword passed here - confirm whether it is still required.
my_connection.execute(my_table.insert(replace_string=""), my_values)
I just used plain sql as:
# Plain-SQL REPLACE using named bind parameters; `data` is passed as the
# parameters for the statement.
# NOTE(review): the second bind name is spelled `columnn_2_bind` (three n's) -
# confirm it matches the corresponding key in `data`.
insert_stmt = "REPLACE INTO tablename (column1, column2) VALUES (:column_1_bind, :columnn_2_bind) "
session.execute(insert_stmt, data)
Update Feb 2023: SQLAlchemy version 2 was recently released and supports on_duplicate_key_update in the MySQL dialect. Many many thanks to Federico Caselli of the SQLAlchemy project who helped me develop sample code in a discussion at https://github.com/sqlalchemy/sqlalchemy/discussions/9328
Please see https://stackoverflow.com/a/75538576/1630244
If it's ok to post the same answer twice (?) here is my small self-contained code example:
import sqlalchemy as db
import sqlalchemy.dialects.mysql as mysql
from sqlalchemy import delete, select, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
    """Declarative base class for the mapped classes in this example."""
class User(Base):
    """Mapped class for the ``foo`` table used by the upsert demonstration."""

    __tablename__ = "foo"

    # Integer primary key; the upsert below collides on this column.
    id: Mapped[int] = mapped_column(primary_key=True)
    # Name column, declared as String(30).
    name: Mapped[str] = mapped_column(String(30))
# Self-contained demonstration of a MySQL upsert (INSERT ... ON DUPLICATE KEY
# UPDATE). Replace the USER-NAME-HERE / PASS-WORD-HERE / SCHEMA-NAME-HERE
# placeholders before running. The credentials are separated from the host
# by "@", as required by the database URL format.
engine = db.create_engine('mysql+mysqlconnector://USER-NAME-HERE:PASS-WORD-HERE@localhost/SCHEMA-NAME-HERE')

# setup step 0 - ensure the table exists
# (metadata lives on the declarative base class; no instance is needed)
Base.metadata.create_all(bind=engine)

sel_stmt = select(User)

# The connection is used as a context manager so it is always closed,
# even if one of the statements below fails.
with engine.connect() as conn:
    # setup step 1 - clean out rows with id 1..5
    del_stmt = delete(User).where(User.id.in_([1, 2, 3, 4, 5]))
    conn.execute(del_stmt)
    conn.commit()
    users = list(conn.execute(sel_stmt))
    print(f'Table size after cleanout: {len(users)}')

    # setup step 2 - insert 4 rows
    ins_stmt = mysql.insert(User).values(
        [
            {"id": 1, "name": "x"},
            {"id": 2, "name": "y"},
            {"id": 3, "name": "w"},
            {"id": 4, "name": "z"},
        ]
    )
    conn.execute(ins_stmt)
    conn.commit()
    users = list(conn.execute(sel_stmt))
    print(f'Table size after insert: {len(users)}')

    # demonstrate upsert: ids 1..3 already exist and get a new name,
    # id 5 is new and gets inserted
    ups_stmt = mysql.insert(User).values(
        [
            {"id": 1, "name": "xx"},
            {"id": 2, "name": "yy"},
            {"id": 3, "name": "ww"},
            {"id": 5, "name": "new"},
        ]
    )
    # On a duplicate key, overwrite `name` with the value from the insert.
    ups_stmt = ups_stmt.on_duplicate_key_update(name=ups_stmt.inserted.name)
    # if you want to see the compiled result
    # x = ups_stmt.compile(dialect=mysql.dialect())
    # print(x.string, x.construct_params())
    conn.execute(ups_stmt)
    conn.commit()
    users = list(conn.execute(sel_stmt))
    print(f'Table size after upsert: {len(users)}')
None of these solutions seems all that elegant. A brute-force way is to query whether the row exists: if it does, delete the row and then insert; otherwise just insert. Obviously there is some overhead involved, but it does not rely on modifying the raw SQL and it works for non-ORM code.