Is there an elegant way to do an INSERT ... ON DUPLICATE KEY UPDATE in SQLAlchemy? I mean something with a syntax similar to inserter.insert().execute(list_of_dictionaries) ?
ON DUPLICATE KEY UPDATE post version-1.2 for MySQL
This functionality is now built into SQLAlchemy for MySQL only. somada141's answer below has the best solution:
https://stackoverflow.com/a/48373874/319066
ON DUPLICATE KEY UPDATE in the SQL statement
If you want the generated SQL to actually include ON DUPLICATE KEY UPDATE, the simplest way involves using a @compiles decorator.
The code for an example (linked from a good thread on the subject on reddit) can be found on GitHub:
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert

@compiles(Insert)
def append_string(insert, compiler, **kw):
    s = compiler.visit_insert(insert, **kw)
    if 'append_string' in insert.kwargs:
        return s + " " + insert.kwargs['append_string']
    return s
my_connection.execute(my_table.insert(append_string = 'ON DUPLICATE KEY UPDATE foo=foo'), my_values)
But note that with this approach you have to build the append_string manually. You could probably change the append_string function so that it automatically rewrites the INSERT statement into one with an 'ON DUPLICATE KEY UPDATE' clause, but I'm not going to do that here out of laziness.
ON DUPLICATE KEY UPDATE functionality within the ORM
SQLAlchemy does not provide an interface to ON DUPLICATE KEY UPDATE or MERGE or any other similar functionality in its ORM layer. Nevertheless, it has the session.merge() function, which can replicate the functionality, but only when the key in question is a primary key.
session.merge(ModelObject) first checks if a row with the same primary key value exists by sending a SELECT query (or by looking it up locally). If it does, it sets a flag somewhere indicating that ModelObject is in the database already, and that SQLAlchemy should use an UPDATE query. Note that merge is quite a bit more complicated than this, but it replicates the functionality well with primary keys.
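For illustration, here is a minimal sketch of that merge-based pattern; the User model, its primary-key id column, and the open session are assumptions for the example, not part of the answer above:

# Hypothetical model and session; User.id is assumed to be the primary key.
obj = session.merge(User(id=42, name="new name"))  # SELECTs id=42 first, then INSERTs or UPDATEs on flush
session.commit()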
But what if you want ON DUPLICATE KEY UPDATE functionality with a non-primary key (for example, another unique key)? Unfortunately, SQLAlchemy doesn't have any such function. Instead, you have to create something that resembles Django's get_or_create(). Another StackOverflow answer covers it, and I'll just paste a modified, working version of it here for convenience.
from sqlalchemy.sql.expression import ClauseElement

def get_or_create(session, model, defaults=None, **kwargs):
    instance = session.query(model).filter_by(**kwargs).first()
    if instance:
        return instance
    else:
        params = {k: v for k, v in kwargs.items() if not isinstance(v, ClauseElement)}
        if defaults:
            params.update(defaults)
        instance = model(**params)
        return instance
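A hypothetical usage example of the helper above (the Book model, its unique title column, and the open session are assumptions for illustration):

# Returns the existing Book titled 'Dune' if present; otherwise builds a new instance.
book = get_or_create(session, Book, defaults={'pages': 412}, title='Dune')
session.add(book)   # note: the helper above does not add newly created instances itself
session.commit()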
I should mention that since the v1.2 release, SQLAlchemy 'core' has a built-in solution for the above, which can be seen here (snippet copied below):
from sqlalchemy.dialects.mysql import insert

insert_stmt = insert(my_table).values(
    id='some_existing_id',
    data='inserted value')

on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
    data=insert_stmt.inserted.data,
    status='U'
)

conn.execute(on_duplicate_key_stmt)
Based on phsource's answer, and for the specific use-case of using MySQL and completely overriding the data for the same key without performing a DELETE statement, one can use the following @compiles-decorated insert expression:
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert

@compiles(Insert)
def append_string(insert, compiler, **kw):
    s = compiler.visit_insert(insert, **kw)
    if insert.kwargs.get('on_duplicate_key_update'):
        fields = s[s.find("(") + 1:s.find(")")].replace(" ", "").split(",")
        generated_directive = ["{0}=VALUES({0})".format(field) for field in fields]
        return s + " ON DUPLICATE KEY UPDATE " + ",".join(generated_directive)
    return s
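With that hook in place, usage could look roughly like the append_string example earlier. Note that this relies on the same legacy insert kwargs mechanism as that answer, so it may not be accepted by newer SQLAlchemy versions, and my_table, my_connection and my_values are placeholders:

# Every column in the generated INSERT is rewritten as `col=VALUES(col)` by the hook above.
my_connection.execute(
    my_table.insert(on_duplicate_key_update=True),
    my_values,
)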
It depends on what you want. If you want to replace existing rows, pass OR REPLACE in prefixes; the example below uses OR IGNORE instead:
def bulk_insert(self, objects, table):
    # table: your table class; objects is a list of dictionaries [{col1: val1, col2: val2}]
    for counter, row in enumerate(objects):
        inserter = table.__table__.insert(prefixes=['OR IGNORE'], values=row)
        try:
            self.db.execute(inserter)
        except Exception as e:
            print(e)
        if counter % 100 == 0:
            self.db.commit()
    self.db.commit()
The commit interval can be tuned here to trade speed for safety.
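A hypothetical call might look like this; repo (an instance of the class holding bulk_insert and a self.db handle) and MyTable are assumptions. Also note that OR IGNORE is SQLite syntax; for MySQL the prefix would be IGNORE instead:

rows = [
    {"col1": "a", "col2": 1},
    {"col1": "b", "col2": 2},
]
repo.bulk_insert(rows, MyTable)  # rows that hit an existing key are silently skipped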
My way
import typing
from datetime import datetime

from sqlalchemy.dialects import mysql

class MyRepository:
    def model(self):
        return MySqlAlchemyModel

    def upsert(self, data: typing.List[typing.Dict]):
        if not data:
            return

        model = self.model()
        if hasattr(model, 'created_at'):
            for item in data:
                item['created_at'] = datetime.now()

        stmt = mysql.insert(getattr(model, '__table__')).values(data)

        for_update = []
        for k, v in data[0].items():
            for_update.append(k)

        dup = {k: getattr(stmt.inserted, k) for k in for_update}
        stmt = stmt.on_duplicate_key_update(**dup)

        self.db.session.execute(stmt)
        self.db.session.commit()
Usage:
myrepo.upsert([
    {
        "field11": "value11",
        "field21": "value21",
        "field31": "value31",
    },
    {
        "field12": "value12",
        "field22": "value22",
        "field32": "value32",
    },
])
The other answers have this covered, but I figured I'd reference another good example for MySQL that I found in this gist. It also uses LAST_INSERT_ID, which may be useful depending on your InnoDB auto-increment settings and whether your table has a unique key. I'm lifting the code here for easy reference, but please give the author a star if you find it useful.
from app import db
from sqlalchemy import func
from sqlalchemy.dialects.mysql import insert

def upsert(model, insert_dict):
    """model can be a db.Model or a table(); insert_dict should contain a primary or unique key."""
    inserted = insert(model).values(**insert_dict)
    upserted = inserted.on_duplicate_key_update(
        id=func.LAST_INSERT_ID(model.id),
        **{k: inserted.inserted[k] for k, v in insert_dict.items()}
    )
    res = db.engine.execute(upserted)
    return res.lastrowid
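Hypothetical usage, assuming a Flask-SQLAlchemy Widget model with a unique code column:

# Inserts the row, or updates it in place if code 'W-1' already exists;
# LAST_INSERT_ID() makes lastrowid point at the affected row either way.
widget_id = upsert(Widget, {"code": "W-1", "name": "Widget One"})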
ORM
Use an upsert function based on on_duplicate_key_update:
from sqlalchemy.orm import Session
from sqlalchemy.dialects.mysql import insert

class Model():
    __input_data__ = dict()

    def __init__(self, **kwargs) -> None:
        self.__input_data__ = kwargs
        self.session = Session(engine)

    def save(self):
        self.session.add(self)
        self.session.commit()

    def upsert(self, *, ignore_keys=[]):
        column_keys = self.__table__.columns.keys()
        update_data = dict()
        for key in self.__input_data__.keys():
            if key not in column_keys:
                continue
            else:
                update_data[key] = self.__input_data__[key]

        insert_stmt = insert(self.__table__).values(**update_data)

        all_ignore_keys = ['id']
        if isinstance(ignore_keys, list):
            all_ignore_keys = [*all_ignore_keys, *ignore_keys]
        else:
            all_ignore_keys.append(ignore_keys)

        update_columns = dict()
        for key in self.__input_data__.keys():
            if key not in column_keys or key in all_ignore_keys:
                continue
            else:
                update_columns[key] = insert_stmt.inserted[key]

        on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
            **update_columns
        )
        # self.session.add(self)
        self.session.execute(on_duplicate_key_stmt)
        self.session.commit()

class ManagerAssoc(ORM_Base, Model):
    def __init__(self, **kwargs):
        self.id = idWorker.get_id()
        column_keys = self.__table__.columns.keys()
        update_data = dict()
        for key in kwargs.keys():
            if key not in column_keys:
                continue
            else:
                update_data[key] = kwargs[key]
        ORM_Base.__init__(self, **update_data)
        Model.__init__(self, **kwargs, id=self.id)
    ....

# you can call it as follows:
manager_assoc.upsert()
manager.upsert(ignore_keys=['manager_id'])
Got a simpler solution:
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert

@compiles(Insert)
def replace_string(insert, compiler, **kw):
    s = compiler.visit_insert(insert, **kw)
    s = s.replace("INSERT INTO", "REPLACE INTO")
    return s
my_connection.execute(my_table.insert(replace_string=""), my_values)
I just used plain SQL:
insert_stmt = "REPLACE INTO tablename (column1, column2) VALUES (:column_1_bind, :column_2_bind)"
session.execute(insert_stmt, data)
Update Feb 2023: SQLAlchemy version 2 was recently released and supports on_duplicate_key_update in the MySQL dialect. Many many thanks to Federico Caselli of the SQLAlchemy project who helped me develop sample code in a discussion at https://github.com/sqlalchemy/sqlalchemy/discussions/9328
Please see https://stackoverflow.com/a/75538576/1630244
If it's ok to post the same answer twice (?) here is my small self-contained code example:
import sqlalchemy as db
import sqlalchemy.dialects.mysql as mysql
from sqlalchemy import delete, select, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "foo"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(30))

engine = db.create_engine('mysql+mysqlconnector://USER-NAME-HERE:PASS-WORD-HERE@localhost/SCHEMA-NAME-HERE')
conn = engine.connect()

# setup step 0 - ensure the table exists
Base().metadata.create_all(bind=engine)

# setup step 1 - clean out rows with id 1..5
del_stmt = delete(User).where(User.id.in_([1, 2, 3, 4, 5]))
conn.execute(del_stmt)
conn.commit()
sel_stmt = select(User)
users = list(conn.execute(sel_stmt))
print(f'Table size after cleanout: {len(users)}')

# setup step 2 - insert 4 rows
ins_stmt = mysql.insert(User).values(
    [
        {"id": 1, "name": "x"},
        {"id": 2, "name": "y"},
        {"id": 3, "name": "w"},
        {"id": 4, "name": "z"},
    ]
)
conn.execute(ins_stmt)
conn.commit()
users = list(conn.execute(sel_stmt))
print(f'Table size after insert: {len(users)}')

# demonstrate upsert
ups_stmt = mysql.insert(User).values(
    [
        {"id": 1, "name": "xx"},
        {"id": 2, "name": "yy"},
        {"id": 3, "name": "ww"},
        {"id": 5, "name": "new"},
    ]
)
ups_stmt = ups_stmt.on_duplicate_key_update(name=ups_stmt.inserted.name)
# if you want to see the compiled result
# x = ups_stmt.compile(dialect=mysql.dialect())
# print(x.string, x.construct_params())
conn.execute(ups_stmt)
conn.commit()
users = list(conn.execute(sel_stmt))
print(f'Table size after upsert: {len(users)}')
None of these solutions seems especially elegant. A brute-force approach is to query first to see whether the row exists: if it does, delete the row and then insert; otherwise just insert. There is obviously some overhead involved, but it doesn't rely on modifying the raw SQL and it works with non-ORM code too.
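A minimal sketch of that brute-force approach, assuming a Thing model with a unique key column and an open session (all names are illustrative):

existing = session.query(Thing).filter_by(key=payload["key"]).first()
if existing:
    session.delete(existing)
    session.flush()            # make sure the DELETE reaches the DB before the new INSERT
session.add(Thing(**payload))  # insert the fresh data unconditionally
session.commit()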
Related
I am writing a unit test case for a function that has multiple SQL queries in it. I am using the psycopg2 module and trying to mock the cursor.
app.py
import psycopg2

def my_function():
    # all connection related code goes here ...

    query = "SELECT name,phone FROM customer WHERE name='shanky'"
    cursor.execute(query)
    columns = [i[0] for i in cursor.description]
    customer_response = []
    for row in cursor.fetchall():
        customer_response.append(dict(zip(columns, row)))

    query = "SELECT name,id FROM product WHERE name='soap'"
    cursor.execute(query)
    columns = [i[0] for i in cursor.description]
    product_response = []
    for row in cursor.fetchall():
        product_response.append(dict(zip(columns, row)))

    return product_response
test.py
from pytest_mock import mocker
import psycopg2

def test_my_function(mocker):
    from my_module import app
    mocker.patch('psycopg2.connect')

    # first query
    mocked_cursor_one = psycopg2.connect.return_value.cursor.return_value
    mocked_cursor_one.description = [['name'], ['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
    mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"

    # second query
    mocked_cursor_two = psycopg2.connect.return_value.cursor.return_value
    mocked_cursor_two.description = [['name'], ['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
    mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"

    ret = app.my_function()
    assert ret == {'name': 'nirma', 'id': 12313}
But the mocker always takes the last mock object (the second query). I have already tried multiple hacks, but they didn't work out. How can I mock multiple queries in one function and successfully pass the unit test case? Is it possible to write a unit test case in this fashion, or do I need to split the queries into different functions?
After drilling through the documentation a lot, I was able to achieve this with the help of the unittest mock decorator and side_effect, which was suggested by @Pavel Vergeev. I was able to write a unit test case that is good enough to test the functionality.
from unittest import mock
from my_module import app

@mock.patch('psycopg2.connect')
def test_my_function(mocked_db):
    mocked_cursor = mocked_db.return_value.cursor.return_value
    description_mock = mock.PropertyMock()
    type(mocked_cursor).description = description_mock

    fetchall_return_one = [('shanky', '347539593')]
    fetchall_return_two = [('nirma', 12313)]
    descriptions = [
        [['name'], ['phone']],
        [['name'], ['id']]
    ]

    mocked_cursor.fetchall.side_effect = [fetchall_return_one, fetchall_return_two]
    description_mock.side_effect = descriptions

    ret = app.my_function()

    # assert whether called with mocked side effect objects
    mocked_db.assert_has_calls(mocked_cursor.fetchall.side_effect)

    # assert db query count is 2
    assert mocked_db.return_value.cursor.return_value.execute.call_count == 2

    # first query
    query1 = "SELECT name,phone FROM customer WHERE name='shanky'"
    assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][0] == query1

    # second query
    query2 = "SELECT name,id FROM product WHERE name='soap'"
    assert mocked_db.return_value.cursor.return_value.execute.call_args_list[1][0][0] == query2

    # assert the data of response
    assert ret == {'name': 'nirma', 'id': 12313}
In addition, if there are dynamic parameters in the query, those can be asserted too with the following approach:
assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][1] == (parameter_name,)
So when the first query is executed as cursor.execute(query, (parameter_name,)), the query can be obtained and asserted at call_args_list[0][0][0], and the first parameter parameter_name at call_args_list[0][0][1]. By incrementing the indices in the same way, all the other parameters and queries can be obtained and asserted.
Try the side_effect argument of mocker.patch:
from unittest.mock import MagicMock
from pytest_mock import mocker
import psycopg2

def test_my_function(mocker):
    from my_module import app
    mocker.patch('psycopg2.connect', side_effect=[MagicMock(), MagicMock()])

    # first query
    mocked_cursor_one = psycopg2.connect().cursor.return_value  # note that we actually call psycopg2.connect -- it's important
    mocked_cursor_one.description = [['name'], ['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
    mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"

    # second query
    mocked_cursor_two = psycopg2.connect().cursor.return_value
    mocked_cursor_two.description = [['name'], ['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
    mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"

    assert mocked_cursor_one is not mocked_cursor_two  # show that they are different

    ret = app.my_function()
    assert ret == {'name': 'nirma', 'id': 12313}
As per the docs, side_effect allows you to change returned value each time the patched object is called:
If you pass in an iterable, it is used to retrieve an iterator which must yield a value on every call. This value can either be an exception instance to be raised, or a value to be returned from the call to the mock
As I have mentioned in an earlier comment, the best way to make unit testing portable is to develop a complete Mock of your database's behavior.
I've done it for MySQL but it's pretty much the same for all databases.
First of all, I like using wrapper classes over the packages I'm using, it helps quickly change the database at one place instead of changing it everywhere in the code.
Here's a sample of what I use as a wrapper:
# _database.py
# -----------------------------------------------------------------------------
# Database Metaclass
# -----------------------------------------------------------------------------
"""Metaclass for Database implementation."""
# -----------------------------------------------------------------------------
import logging

logger = logging.getLogger(__name__)

class Database:
    """Database Metaclass"""

    def __init__(self, connect_func, **kwargs):
        self.connection = connect_func(**kwargs)

    def execute(self, statement, fetchall=True):
        """Execute a statement.

        Execute the statement passed as argument.

        Args:
            statement (str): SQL query or command to execute.

        Returns:
            set: List of objects returned by the cursor.
        """
        cursor = self.connection.cursor()
        logger.debug(f"Executing: {statement}")
        cursor.execute(statement)
        if fetchall:
            return cursor.fetchall()
        else:
            return cursor.fetchone()

    def __del__(self):
        """Close connection on object deletion."""
        self.connection.close()
And the mysql module:
# mysql.py
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# MySQL Database Class
# -----------------------------------------------------------------------------
"""Class for MySQL Database connection."""
# -----------------------------------------------------------------------------
import logging

import mysql.connector

from . import _database

logger = logging.getLogger(__name__)

class MySQL(_database.Database):
    """MySQL Database Class Wrapper.

    Attributes:
        connection (obj): Object returned from mysql.connector.connect
    """

    def __init__(self, autocommit=True, **kwargs):
        super().__init__(connect_func=mysql.connector.connect, **kwargs)
        self.connection.autocommit = autocommit
Instantiate like: db = MySQL(user='...', password='...', ...)
Now, you would need to mock this MySQL class. Here's the mock data file:
# database_mock_data.json
{
  "customer": {
    "name": ["shanky", "nirma"],
    "phone": [123123123, 232342342]
  },
  "product": {
    "name": ["shanky", "nirma"],
    "id": [1, 2]
  }
}
The mocks.py
# mocks.py
import json
import re

from .mysql import MySQL

_MOCK_DATA_PATH = 'database_mock_data.json'

class MockDatabase(MySQL):
    """Database mock that never opens a real connection."""

    def __init__(self, **kwargs):
        self.connection = MockConnection()

class MockConnection:
    """
    Mock the connection object by returning a mock cursor.
    """

    @staticmethod
    def cursor():
        return MockCursor()

class MockCursor:
    """
    The Mocked Cursor

    A call to execute() will initiate the read on the json data file and will set
    the description object (containing the column names usually).
    You could implement an update function like `_json_sql_update()`.
    """

    def __init__(self):
        self.description = []
        self.__result = None

    def execute(self, statement):
        data = _read_json_file(_MOCK_DATA_PATH)
        if statement.upper().startswith('SELECT'):
            self.__result, self.description = _json_sql_select(data, statement)

    def fetchall(self):
        return self.__result

    def fetchone(self):
        return self.__result[0]
def _json_sql_select(data, query):
    """
    Takes a dictionary and returns the values matching a sql select query.

    NOTE: It does not work with other where clauses than '='.
    Also, note that a where statement is expected.

    :param (dict) data: Dictionary with the following structure:
        {
            'tablename': {
                'column_name_1': ['value1', 'value2'],
                'column_name_2': ['value1', 'value2'],
                ...
            },
            ...
        }
    :param (str) query: A select sql query such as:
        `select column_name_1 from TABLENAME
         where column_name_2='value1'`
    :return: List of lists of values and the header description
    """
    try:
        match = (re.search("select(.*)from(.*)where(.*)[;]?", query,
                           re.IGNORECASE | re.DOTALL).groups())
    except AttributeError:
        print("Select Query pattern mismatch... {}".format(query))
        raise

    # Parse values from the select query
    tablename = match[1].strip().upper()
    columns = [col.strip().upper() for col in match[0].split(",")]
    if columns == ['*']:
        columns = data[tablename].keys()
    where = [cmd.upper().strip().replace(' ', '')
             for cmd in match[2].split('and')]

    # Select values
    selected_values = []
    nb_lines = len(list(data[tablename].values())[0])
    for i in range(nb_lines):
        is_match = True
        for condition in where:
            key_condition, value_condition = (_clean_string(condition)
                                              .split('='))
            if data[tablename][key_condition][i].upper() != value_condition:
                # this row does not match the condition
                is_match = False
        if is_match:
            sub_list = []
            for column in columns:
                sub_list.append(data[tablename][column][i])
            selected_values.append(sub_list)

    # The usual descriptor is a nested list
    description = zip(columns, ['...'] * len(columns))
    return selected_values, description

def _read_json_file(file_path):
    with open(file_path, 'r') as f_in:
        data = json.load(f_in)
    return data
And then you have your test in a test_module_yourfunction.py
import pytest

def my_function(db, query):
    # Code goes here
    ...

@pytest.fixture
def db_connection():
    return MockDatabase()

@pytest.mark.parametrize(
    ("query", "expected"),
    [
        ("SELECT name,phone FROM customer WHERE name='shanky'", {'name': 'nirma', 'id': 12313}),
        ("<second query goes here>", "<second result goes here>"),
    ]
)
def test_my_function(db_connection, query, expected):
    assert my_function(db_connection, query) == expected
Now I'm sorry if you can't copy/paste this code and make it work, but you get the feeling :) just trying to help
I have a table of equipment, and each piece of equipment has dates for its maintenance levels. The user can select the maintenance level, so I have to adjust my SQLAlchemy query for each combination of maintenance levels chosen. For example:
SELECT * WHERE (equipment IN []) AND m_level1 = DATE AND m_level2 = DATE ...
So there can be a combination for each if condition, depending on the checkboxes. I used multiple query strings to reach my goal, but I want to improve the query using SQLAlchemy.
I assume you are using the ORM.
In that case, the filter function returns a query object, so you can conditionally build up the query by doing something like:
query = Session.query(schema.Object).filter_by(attribute=value)
if condition:
    query = query.filter_by(condition_attr=condition_val)
if another_condition:
    query = query.filter_by(another=another_val)

# then finally execute it
results = query.all()
The function filter(*criterion) takes any number of criteria as positional arguments, so you can build them up and unpack a tuple or list into it; @Wolph has the details here: SQLAlchemy dynamic filter_by.
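For example, a small sketch of collecting criteria in a list and unpacking it into filter() (the model and condition names reuse the placeholders from the answer above):

criteria = []
if condition:
    criteria.append(schema.Object.condition_attr == condition_val)
if another_condition:
    criteria.append(schema.Object.another == another_val)

# one call to filter() with all accumulated criteria, ANDed together
results = Session.query(schema.Object).filter(*criteria).all()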
If we are speaking of SQLAlchemy Core, there is another way:
from sqlalchemy import and_
filters = [table.c.col1 == filter1, table.c.col2 > filter2]
query = table.select().where(and_(*filters))
If you're trying to filter based on incoming form criteria:
form = request.form.to_dict()
filters = []
for col in form:
    sqlalchemybinaryexpression = (getattr(MODEL, col) == form[col])
    filters.append(sqlalchemybinaryexpression)
query = table.select().where(and_(*filters))
Where MODEL is your SQLAlchemy Model
Another solution to this question: this approach handles the case in a more secure way, since it verifies that the field to be filtered actually exists on the model.
To apply operators to the value being filtered, and to avoid adding a new parameter to the query, we can prefix the value with an operator, e.g. ?foo=>1, ?foo=<1, ?foo=>=1, ?foo=<=1, ?foo=!1, ?foo=1, and finally between, which would look like ?foo=a,b.
from sqlalchemy.orm import class_mapper
import re

# input parameters
filter_by = {
    "column1": "!1",  # not equal to
    "column2": "1",   # equal to
    "column3": ">1",  # greater than, etc...
}

def computed_operator(column, v):
    if re.match(r"^!", v):
        """__ne__"""
        val = re.sub(r"!", "", v)
        return column.__ne__(val)
    if re.match(r">(?!=)", v):
        """__gt__"""
        val = re.sub(r">(?!=)", "", v)
        return column.__gt__(val)
    if re.match(r"<(?!=)", v):
        """__lt__"""
        val = re.sub(r"<(?!=)", "", v)
        return column.__lt__(val)
    if re.match(r">=", v):
        """__ge__"""
        val = re.sub(r">=", "", v)
        return column.__ge__(val)
    if re.match(r"<=", v):
        """__le__"""
        val = re.sub(r"<=", "", v)
        return column.__le__(val)
    if re.match(r"(\w*),(\w*)", v):
        """between"""
        a, b = re.split(r",", v)
        return column.between(a, b)
    """ default __eq__ """
    return column.__eq__(v)

query = Table.query
filters = []
for k, v in filter_by.items():
    mapper = class_mapper(Table)
    if not hasattr(mapper.columns, k):
        continue
    filters.append(computed_operator(mapper.columns[k], "{}".format(v)))
query = query.filter(*filters)
query.all()
Here is a solution that works for both AND and OR...
Just replace or_ with and_ in the code if you need that case:
from sqlalchemy import or_, and_

my_filters = set()  # use a set to contain only unique values and avoid duplicates

if condition_1:
    my_filters.add(MySQLClass.id == some_id)
if condition_2:
    my_filters.add(MySQLClass.name == some_name)

fetched = db_session.execute(select(MySQLClass).where(or_(*my_filters))).scalars().all()
I have a Flask application with a RESTful API. One of the API calls is a 'mass upsert' call with a JSON payload. I am struggling with performance.
The first thing I tried was to use merge-result on a Query object, because...
This is an optimized method which will merge all mapped instances, preserving the structure of the result rows and unmapped columns with less method overhead than that of calling Session.merge() explicitly for each value.
This was the initial code:
class AdminApiUpdateTasks(Resource):
    """Bulk task creation / update endpoint"""

    def put(self, slug):
        taskdata = json.loads(request.data)
        existing = db.session.query(Task).filter_by(challenge_slug=slug)
        existing.merge_result(
            [task_from_json(slug, **task) for task in taskdata])
        db.session.commit()
        return {}, 200
A request to that endpoint with ~5000 records, all of them already existing in the database, takes more than 11m to return:
real 11m36.459s
user 0m3.660s
sys 0m0.391s
As this would be a fairly typical use case, I started looking into alternatives to improve performance. Against my better judgement, I tried to merge the session for each individual record:
class AdminApiUpdateTasks(Resource):
    """Bulk task creation / update endpoint"""

    def put(self, slug):
        # Get the posted data
        taskdata = json.loads(request.data)
        for task in taskdata:
            db.session.merge(task_from_json(slug, **task))
        db.session.commit()
        return {}, 200
To my surprise, this turned out to be more than twice as fast:
real 4m33.945s
user 0m3.608s
sys 0m0.258s
I have two questions:
Why is the second strategy using merge faster than the supposedly optimized first one that uses merge_result?
What other strategies should I pursue to optimize this more, if any?
This is an old question, but I hope this answer can still help people.
I used the same idea as this example set by SQLAlchemy, but I added benchmarking for UPSERT (update the existing record if it exists, otherwise insert) operations. The results on a PostgreSQL 11 database are below:
Tests to run: test_customer_individual_orm_select, test_customer_batched_orm_select, test_customer_batched_orm_select_add_all, test_customer_batched_orm_merge_result
test_customer_individual_orm_select : UPSERT statements via individual checks on whether objects exist and add new objects individually (10000 iterations); total time 9.359603 sec
test_customer_batched_orm_select : UPSERT statements via batched checks on whether objects exist and add new objects individually (10000 iterations); total time 1.553555 sec
test_customer_batched_orm_select_add_all : UPSERT statements via batched checks on whether objects exist and add new objects in bulk (10000 iterations); total time 1.358680 sec
test_customer_batched_orm_merge_result : UPSERT statements using batched merge_results (10000 iterations); total time 7.191284 sec
As you can see, merge-result is far from the most efficient option. I'd suggest checking in batches whether the results exist and should be updated. Hope this helps!
"""
This series of tests illustrates different ways to UPSERT
or INSERT ON CONFLICT UPDATE a large number of rows in bulk.
"""
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from profiler import Profiler

Base = declarative_base()
engine = None

class Customer(Base):
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))
    description = Column(String(255))

Profiler.init("bulk_upserts", num=100000)

@Profiler.setup
def setup_database(dburl, echo, num):
    global engine
    engine = create_engine(dburl, echo=echo)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)
    s = Session(engine)
    for chunk in range(0, num, 10000):
        # Insert half of the customers we want to merge
        s.bulk_insert_mappings(
            Customer,
            [
                {
                    "id": i,
                    "name": "customer name %d" % i,
                    "description": "customer description %d" % i,
                }
                for i in range(chunk, chunk + 10000, 2)
            ],
        )
    s.commit()
@Profiler.profile
def test_customer_individual_orm_select(n):
    """
    UPSERT statements via individual checks on whether objects exist
    and add new objects individually
    """
    session = Session(bind=engine)
    for i in range(0, n):
        customer = session.query(Customer).get(i)
        if customer:
            customer.description += "updated"
        else:
            session.add(Customer(
                id=i,
                name=f"customer name {i}",
                description=f"customer description {i} new"
            ))
        session.flush()
    session.commit()

@Profiler.profile
def test_customer_batched_orm_select(n):
    """
    UPSERT statements via batched checks on whether objects exist
    and add new objects individually
    """
    session = Session(bind=engine)
    for chunk in range(0, n, 1000):
        customers = {
            c.id: c for c in
            session.query(Customer)
                .filter(Customer.id.between(chunk, chunk + 1000))
        }
        for i in range(chunk, chunk + 1000):
            if i in customers:
                customers[i].description += "updated"
            else:
                session.add(Customer(
                    id=i,
                    name=f"customer name {i}",
                    description=f"customer description {i} new"
                ))
        session.flush()
    session.commit()
@Profiler.profile
def test_customer_batched_orm_select_add_all(n):
    """
    UPSERT statements via batched checks on whether objects exist
    and add new objects in bulk
    """
    session = Session(bind=engine)
    for chunk in range(0, n, 1000):
        customers = {
            c.id: c for c in
            session.query(Customer)
                .filter(Customer.id.between(chunk, chunk + 1000))
        }
        to_add = []
        for i in range(chunk, chunk + 1000):
            if i in customers:
                customers[i].description += "updated"
            else:
                to_add.append({
                    "id": i,
                    "name": "customer name %d" % i,
                    "description": "customer description %d new" % i,
                })
        if to_add:
            session.bulk_insert_mappings(
                Customer,
                to_add
            )
            to_add = []
        session.flush()
    session.commit()

@Profiler.profile
def test_customer_batched_orm_merge_result(n):
    "UPSERT statements using batched merge_results"
    session = Session(bind=engine)
    for chunk in range(0, n, 1000):
        customers = session.query(Customer)\
            .filter(Customer.id.between(chunk, chunk + 1000))
        customers.merge_result(
            Customer(
                id=i,
                name=f"customer name {i}",
                description=f"customer description {i} new"
            ) for i in range(chunk, chunk + 1000)
        )
        session.flush()
    session.commit()
I think this is what was causing your slowness in the first query:
existing = db.session.query(Task).filter_by(challenge_slug=slug)
Also you should probably change this:
existing.merge_result(
    [task_from_json(slug, **task) for task in taskdata])
To:
existing.merge_result(
    (task_from_json(slug, **task) for task in taskdata))
That should save you some memory and time, since the list won't be fully built in memory before being passed to the merge_result method.
Using SQLAlchemy, I am trying to print out all of the attributes of each model that I have in a manner similar to:
SELECT * from table;
However, I would like to do something with each model instance's information as I get it. So far the best that I've been able to come up with is:
for m in session.query(model).all():
    print([getattr(m, x.__str__().split('.')[1]) for x in model.__table__.columns])
    # additional code
And this will give me what I'm looking for, but it's a fairly roundabout way of getting it. I was kind of hoping for an attribute along the lines of:
m.attributes
# or
m.columns.values
I feel I'm missing something and that there is a much better way of doing this. I'm doing this because I'll be printing everything to CSV files, and I don't want to have to specify the columns/attributes I'm interested in; I want everything (there are a lot of columns across a lot of models to be printed).
This is an old post, but I ran into a problem with the actual database column names not matching the mapped attribute names on the instance. We ended up going with this:
from sqlalchemy import inspect
inst = inspect(model)
attr_names = [c_attr.key for c_attr in inst.mapper.column_attrs]
Hope that helps somebody with the same problem!
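Building on that, a small hypothetical helper for the CSV use-case from the question; row_as_dict is an illustrative name, not a SQLAlchemy API:

from sqlalchemy import inspect

def row_as_dict(obj):
    """Map each mapped attribute name to its value for one model instance."""
    return {attr.key: getattr(obj, attr.key)
            for attr in inspect(obj).mapper.column_attrs}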
Probably the shortest solution (see the recent documentation):
from sqlalchemy.inspection import inspect
columns = [column.name for column in inspect(model).c]
The last line might be more readable if rewritten as three lines:
table = inspect(model)
for column in table.c:
    print(column.name)
Building on Rodney L's answer:
model = MYMODEL
columns = [m.key for m in model.__table__.columns]
Take a look at SQLAlchemy's metadata reflection feature.
A Table object can be instructed to load information about itself from the corresponding database schema object already existing within the database. This process is called reflection.
print(repr(model.__table__))
Or just the columns:
print(str(list(model.__table__.columns)))
I believe this is the easiest way:
print([cname for cname in m.__dict__.keys()])
EDIT: The answer above me using sqlalchemy.inspection.inspect() seems to be a better solution.
Put this together and found it helpful:
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
engine = create_engine('mysql+pymysql://testuser:password@localhost:3306/testdb')
DeclarativeBase = declarative_base()
metadata = DeclarativeBase.metadata
metadata.bind = engine
# configure Session class with desired options
Session = sessionmaker()
# associate it with our custom Session class
Session.configure(bind=engine)
# work with the session
session = Session()
And then:
d = {k: metadata.tables[k].columns.keys() for k in metadata.tables.keys()}
Example output print(d):
{'orderdetails': ['orderNumber', 'productCode', 'quantityOrdered', 'priceEach', 'orderLineNumber'],
'offices': ['addressLine1', 'addressLine2', 'city', 'country', 'officeCode', 'phone', 'postalCode', 'state', 'territory'],
'orders': ['comments', 'customerNumber', 'orderDate', 'orderNumber', 'requiredDate', 'shippedDate', 'status'],
'products': ['MSRP', 'buyPrice', 'productCode', 'productDescription', 'productLine', 'productName', 'productScale', 'productVendor', 'quantityInStock'],
'employees': ['employeeNumber', 'lastName', 'firstName', 'extension', 'email', 'officeCode', 'reportsTo', 'jobTitle'],
'customers': ['addressLine1', 'addressLine2', 'city', 'contactFirstName', 'contactLastName', 'country', 'creditLimit', 'customerName', 'customerNumber', 'phone', 'postalCode', 'salesRepEmployeeNumber', 'state'],
'productlines': ['htmlDescription', 'image', 'productLine', 'textDescription'],
'payments': ['amount', 'checkNumber', 'customerNumber', 'paymentDate']}
Or, alternatively:
from sqlalchemy.sql import text
cmd = "SELECT * FROM information_schema.columns WHERE table_schema = :db ORDER BY table_name,ordinal_position"
result = session.execute(
    text(cmd),
    {"db": "classicmodels"}
)
result.fetchall()
I'm using SQLAlchemy v1.0.14 on Python 3.5.2.
Assuming you can connect to an engine with create_engine(), I was able to display all columns using the following code. Replace "my connection string" and "my table name" with the appropriate values.
from sqlalchemy import create_engine, MetaData, Table, select
engine = create_engine('my connection string')
conn = engine.connect()
metadata = MetaData(conn)
t = Table("my table name", metadata, autoload=True)
columns = [m.key for m in t.columns]
columns
The last line just displays the column names from the previous statement.
You may be interested in what I came up with to do this.
from sqlalchemy.orm import class_mapper, ColumnProperty, RelationshipProperty
import collections

# structure returned by the get_metadata functions.
MetaDataTuple = collections.namedtuple("MetaDataTuple",
        "coltype, colname, default, m2m, nullable, uselist, collection")

def get_metadata_iterator(class_):
    for prop in class_mapper(class_).iterate_properties:
        name = prop.key
        if name.startswith("_") or name == "id" or name.endswith("_id"):
            continue
        md = _get_column_metadata(prop)
        if md is None:
            continue
        yield md

def get_column_metadata(class_, colname):
    prop = class_mapper(class_).get_property(colname)
    md = _get_column_metadata(prop)
    if md is None:
        raise ValueError("Not a column name: %r." % (colname,))
    return md

def _get_column_metadata(prop):
    name = prop.key
    m2m = False
    default = None
    nullable = None
    uselist = False
    collection = None
    proptype = type(prop)
    if proptype is ColumnProperty:
        coltype = type(prop.columns[0].type).__name__
        try:
            default = prop.columns[0].default
        except AttributeError:
            default = None
        else:
            if default is not None:
                default = default.arg(None)
        nullable = prop.columns[0].nullable
    elif proptype is RelationshipProperty:
        coltype = RelationshipProperty.__name__
        m2m = prop.secondary is not None
        nullable = prop.local_side[0].nullable
        uselist = prop.uselist
        if prop.collection_class is not None:
            collection = type(prop.collection_class()).__name__
        else:
            collection = "list"
    else:
        return None
    return MetaDataTuple(coltype, str(name), default, m2m, nullable, uselist, collection)
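Hypothetical usage with some mapped class (MyModel is a placeholder):

for md in get_metadata_iterator(MyModel):
    print(md.colname, md.coltype, md.nullable)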
I use this because it's slightly shorter:
for m in session.query(*model.__table__.columns).all():
    print(m)
I have some problems with setting up a dictionary collection in Python's SQLAlchemy:
I am using the declarative definition of tables. I have an Item table in a 1:N relation with a Record table. I set up the relation using the following code:
_Base = declarative_base()

class Record(_Base):
    __tablename__ = 'records'
    item_id = Column(String(M_ITEM_ID), ForeignKey('items.id'))
    id = Column(String(M_RECORD_ID), primary_key=True)
    uri = Column(String(M_RECORD_URI))
    name = Column(String(M_RECORD_NAME))

class Item(_Base):
    __tablename__ = 'items'
    id = Column(String(M_ITEM_ID), primary_key=True)
    records = relation(Record,
                       collection_class=column_mapped_collection(Record.name),
                       backref='item')
Now I want to work with the Items and Records. Let's create some objects:
i1 = Item(id='id1')
r = Record(id='mujrecord')
And now I want to associate these objects using the following code:
i1.records['source_wav'] = r
but the Record r doesn't get its name attribute (the dictionary key column) set. Is there any way to ensure this automatically? (I know that setting the name during Record creation works, but that doesn't feel right to me.)
Many thanks
You want something like this:
from sqlalchemy.orm import validates

class Item(_Base):
    [...]

    @validates('records')
    def validate_record(self, key, record):
        assert record.name is not None, "Record fails validation, must have a name"
        return record
With this, you get the desired validation:
>>> i1 = Item(id='id1')
>>> r = Record(id='mujrecord')
>>> i1.records['source_wav'] = r
Traceback (most recent call last):
[...]
AssertionError: Record fails validation, must have a name
>>> r.name = 'foo'
>>> i1.records['source_wav'] = r
>>>
I can't comment yet, so I'm just going to write this as a separate answer:
from sqlalchemy.orm import validates

class Item(_Base):
    [...]

    @validates('records')
    def validate_record(self, key, record):
        record.name = key
        return record
This is basically a copy of Gunnlaugur's answer but abusing the validates decorator to do something more useful than exploding.
You have:
backref='item'
Is this a typo for
backref='name'
?