I am trying to run the exact same test on a single obj, which is a models.Model instance and has some relations with other models. I do not want changes to that instance to persist, so effectively I want the same effect as the tearDown method, which rolls back transactions.
To illustrate this:
import django.test

class MyTestCase(django.test.TestCase):
    def test(self):
        # main test that calls the same test using all
        # different states of `obj` that need to be tested

        # different values of data that update the state of `obj`
        # with state I simply mean the values of `obj`'s attributes and relationships
        data = [state1, state2, state3]
        for state in data:
            obj = Obj.objects.get(pk=self.pk)  # gets that SINGLE object from the test db
            # applies the state data to `obj` to change its state
            obj.update(state)
            # performs the actual test on `obj` with this particular state
            self._test_obj(obj)

    def _test_obj(self, obj):
        self.assertEqual(len(obj.vals), 10)
        self.assertLess(obj.threshold, 99)
        # more assert statements...
This design has two problems:
1. The changes to obj persist in the test database, so on the next iteration the data would be tainted. I want to roll back those changes and get a fresh instance of obj, as if the test method had just been called and the data came straight from the fixtures.
2. If an assert statement fails, I can see which one it is, but I can't tell which case (state) failed because of the for loop. I could wrap the _test_obj call in the test method in a try/except, but then I wouldn't be able to tell which assert failed.
Does django.test provide any tool to run the same test for different states of the same model? If it doesn't, how can I do what I am trying to do while solving both points mentioned above?
Simply roll back after you're done with the object.
You can use subTest, available in Python 3.4+, to tell the states apart.
Here's how your code should look:
from django.db import DatabaseError, transaction
from django.test import TestCase
# Product is your model; import it from your app's models module

class TestProductApp(TestCase):
    def setUp(self):
        self.product1 = ...

    def test_multistate(self):
        state1 = dict(name='p1')
        state2 = dict(name='p2')
        data = [state1, state2]
        for i, state in enumerate(data):
            with self.subTest(i=i):
                try:
                    with transaction.atomic():
                        product = Product.objects.get(id=self.product1.id)
                        product.name = state['name']
                        product.save()
                        self.assertEqual(len(product.name), 2)
                        raise DatabaseError  # forces a rollback
                except DatabaseError:
                    pass
                print(Product.objects.get(id=self.product1.id))  # prints data created in setUp/fixture
This answer can be improved: rather than forcing a rollback by raising an error, you can simply mark the atomic block for rollback. See set_rollback().
from django.db import transaction
from django.test import TestCase
# Product is your model; import it from your app's models module

class TestProductApp(TestCase):
    def setUp(self):
        self.product1 = ...

    def test_multistate(self):
        state1 = dict(name='p1')
        state2 = dict(name='p2')
        data = [state1, state2]
        for i, state in enumerate(data):
            with self.subTest(i=i):
                with transaction.atomic():
                    product = Product.objects.get(id=self.product1.id)
                    product.name = state['name']
                    product.save()
                    # If this assert fails, the exception leaving the atomic
                    # block also rolls it back
                    self.assertEqual(len(product.name), 2)
                    transaction.set_rollback(True)  # forces a rollback
                print(Product.objects.get(id=self.product1.id))  # prints data created in setUp/fixture
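To try the same rollback-per-state pattern without a Django project, here is a minimal stdlib sketch. It swaps Django's transaction.atomic()/set_rollback() for a raw sqlite3 SAVEPOINT and uses unittest's subTest; the product table, the MultiStateTest class and the state values are made up for illustration.

```python
import sqlite3
import unittest


class MultiStateTest(unittest.TestCase):
    """Stdlib stand-in: a sqlite3 SAVEPOINT plays the role of transaction.atomic()."""

    def setUp(self):
        # isolation_level=None: the sqlite3 module won't open transactions
        # implicitly, so the SAVEPOINT statements below control them.
        self.conn = sqlite3.connect(":memory:", isolation_level=None)
        self.conn.execute("CREATE TABLE product (id INTEGER PRIMARY KEY, name TEXT)")
        # The "fixture" row every state starts from (hypothetical schema)
        self.conn.execute("INSERT INTO product (id, name) VALUES (1, 'original')")

    def tearDown(self):
        self.conn.close()

    def current_name(self):
        return self.conn.execute("SELECT name FROM product WHERE id = 1").fetchone()[0]

    def test_multistate(self):
        states = [{"name": "p1"}, {"name": "p2"}]
        for i, state in enumerate(states):
            # A failure inside is reported with i and state, and the loop continues
            with self.subTest(i=i, state=state):
                self.conn.execute("SAVEPOINT per_state")
                try:
                    self.conn.execute(
                        "UPDATE product SET name = ? WHERE id = 1", (state["name"],)
                    )
                    self.assertEqual(len(self.current_name()), 2)
                finally:
                    # Undo this state's changes whether the assertions passed or not
                    self.conn.execute("ROLLBACK TO SAVEPOINT per_state")
                    self.conn.execute("RELEASE SAVEPOINT per_state")
                # Each iteration starts again from the fixture data
                self.assertEqual(self.current_name(), "original")
```

Run it with `python -m unittest`; a failing state is reported with its i and state values, which addresses both problems from the question at once.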
I want to create a method on a Django model, call it model.duplicate(), that duplicates the model instance, including all the foreign keys pointing to it. I know that you can do this:
def duplicate(self):
    self.pk = None
    self.save()
...but this way all the related models still point to the old instance.
I can't simply save a reference to the original object because what self points to changes during execution of the method:
def duplicate(self):
    original = self
    self.pk = None
    self.save()
    assert original is not self  # fails
I could try to save a reference to just the related object:
def duplicate(self):
    original_fkeys = self.fkeys.all()
    self.pk = None
    self.save()
    self.fkeys.add(*original_fkeys)
...but this transfers them from the original record to the new one. I need them copied over and pointed at the new record.
Several answers elsewhere (and here before I updated the question) have suggested using Python's copy, which I suspect works for foreign keys on this model, but not foreign keys on another model pointing to it.
import copy

def duplicate(self):
    new_model = copy.deepcopy(self)
    new_model.pk = None
    new_model.save()
If you do this, new_model.fkeys.all() (to follow my naming scheme so far) will be empty.
You can create a new instance and save it like this:
def duplicate(self):
    kwargs = {}
    for field in self._meta.fields:
        kwargs[field.name] = getattr(self, field.name)
        # or self.__dict__[field.name]
    kwargs.pop('id')
    new_instance = self.__class__(**kwargs)
    new_instance.save()
    # now you have id for the new instance so you can
    # create related models in similar fashion
    fkeys_qs = self.fkeys.all()
    new_fkeys = []
    for fkey in fkeys_qs:
        fkey_kwargs = {}
        for field in fkey._meta.fields:
            fkey_kwargs[field.name] = getattr(fkey, field.name)
        fkey_kwargs.pop('id')
        # 'foreign_key_field' stands for the FK on the related model that points
        # back here; a ForeignKey takes the instance itself (assign
        # new_instance.id only to 'foreign_key_field_id')
        fkey_kwargs['foreign_key_field'] = new_instance
        new_fkeys.append(fkeys_qs.model(**fkey_kwargs))
    fkeys_qs.model.objects.bulk_create(new_fkeys)
    return new_instance
I'm not sure how it'll behave with ManyToMany fields. But for simple fields it works. And you can always pop the fields you are not interested in for your new instance.
The parts where I iterate over _meta.fields could be done with copy, but the important thing is to point foreign_key_field at the new instance.
I'm sure it's programmatically possible to detect which fields are foreign keys to self.__class__ (foreign_key_field), but since there can be more than one, it's better to name the one (or more) explicitly.
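Stripped of the ORM, the key step is: insert the new parent, read back its id, then insert copies of the children with that new id in their foreign-key column. A minimal sqlite3 sketch; the parent/child tables and the duplicate_parent and child_count helpers are hypothetical stand-ins for your models.

```python
import sqlite3


def duplicate_parent(conn, parent_id):
    """Copy one parent row plus all of its child rows; return the new parent's id."""
    # Copy the parent row itself (every column except the primary key)
    cur = conn.execute(
        "INSERT INTO parent (name) SELECT name FROM parent WHERE id = ?", (parent_id,)
    )
    new_parent_id = cur.lastrowid
    # Copy the children, putting the NEW parent id in their foreign-key column
    conn.execute(
        "INSERT INTO child (parent_id, label) "
        "SELECT ?, label FROM child WHERE parent_id = ?",
        (new_parent_id, parent_id),
    )
    return new_parent_id


def child_count(conn, parent_id):
    return conn.execute(
        "SELECT COUNT(*) FROM child WHERE parent_id = ?", (parent_id,)
    ).fetchone()[0]


conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE parent (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE child (
        id INTEGER PRIMARY KEY,
        parent_id INTEGER REFERENCES parent(id),
        label TEXT
    );
    INSERT INTO parent (id, name) VALUES (1, 'original');
    INSERT INTO child (parent_id, label) VALUES (1, 'a'), (1, 'b');
    """
)
new_id = duplicate_parent(conn, 1)
print(new_id, child_count(conn, 1), child_count(conn, new_id))
```

The original parent keeps its two children and the copy gets two new ones, which is the "copied over and pointed at the new record" behaviour the question asks for.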
Although I accepted the other poster's answer (since it helped me get here), I wanted to post the solution I ended up with in case it helps someone else stuck in the same place.
def duplicate(self):
    """
    Duplicate a model instance, making copies of all foreign keys pointing
    to it. This is an in-place method in the sense that the record the
    instance is pointing to will change once the method has run. The old
    record is still accessible but must be retrieved again from
    the database.
    """
    # I had a known set of related objects I wanted to carry over, so I
    # listed them explicitly rather than looping over obj._meta.fields
    fks_to_copy = list(self.fkeys_a.all()) + list(self.fkeys_b.all())

    # Now we can make the new record
    self.pk = None
    # Make any changes you like to the new instance here, then
    self.save()

    foreign_keys = {}
    for fk in fks_to_copy:
        fk.pk = None
        # Likewise make any changes to the related model here. In particular,
        # fk still stores the original record's id in its foreign key, so point
        # it at the new record (e.g. fk.parent = self, where 'parent' stands in
        # for the name of your FK field).
        # However, we avoid calling fk.save() here to prevent
        # hitting the database once per iteration of this loop
        try:
            # Use fk.__class__ here to avoid hard-coding the class name
            foreign_keys[fk.__class__].append(fk)
        except KeyError:
            foreign_keys[fk.__class__] = [fk]

    # Now we can issue just two calls to bulk_create,
    # one for fkeys_a and one for fkeys_b
    for cls, list_of_fks in foreign_keys.items():
        cls.objects.bulk_create(list_of_fks)
What it looks like when you use it:
In [6]: model.id
Out[6]: 4443
In [7]: model.duplicate()
In [8]: model.id
Out[8]: 17982
In [9]: old_model = Model.objects.get(id=4443)
In [10]: old_model.fkeys_a.count()
Out[10]: 2
In [11]: old_model.fkeys_b.count()
Out[11]: 1
In [12]: model.fkeys_a.count()
Out[12]: 2
In [13]: model.fkeys_b.count()
Out[13]: 1
Model and related_model names changed to protect the innocent.
I tried the other answers in Django 2.1/Python 3.6 and they didn't seem to copy one-to-many and many-to-many related objects (self._meta.fields doesn't include one-to-many related fields but self._meta.get_fields() does). Also, the other answers required prior knowledge of the related field name or knowledge of which foreign keys to copy.
I wrote a way to do this in a more generic fashion, handling one-to-many and many-to-many related fields. Comments included, and suggestions welcome:
def duplicate_object(self):
    """
    Duplicate a model instance, making copies of all foreign keys pointing to it.
    There are 3 steps that need to occur in order:

        1.  Enumerate the related child objects and m2m relations, saving in lists/dicts
        2.  Copy the parent object per django docs (doesn't copy relations)
        3a. Copy the child objects, relating to the copied parent object
        3b. Re-create the m2m relations on the copied parent object
    """
    related_objects_to_copy = []
    relations_to_set = {}
    # Iterate through all the fields in the parent object looking for related fields
    for field in self._meta.get_fields():
        # Reverse relations (auto-created, not concrete) are reached through their
        # accessor name (related_name, or e.g. 'selectedphrase_set'), not field.name
        if field.auto_created and not field.concrete:
            accessor_name = field.get_accessor_name()
        else:
            accessor_name = field.name

        if field.one_to_many:
            # One to many fields are backward relationships where many child objects are related to the
            # parent (i.e. SelectedPhrases). Enumerate them and save a list so we can copy them after
            # duplicating our parent object.
            print(f'Found a one-to-many field: {field.name}')
            # 'field' is a ManyToOneRel which is not iterable, we need to get the object attribute itself
            related_object_manager = getattr(self, accessor_name)
            related_objects = list(related_object_manager.all())
            if related_objects:
                print(f' - {len(related_objects)} related objects to copy')
                related_objects_to_copy += related_objects

        elif field.many_to_one:
            # In testing so far, these relationships are preserved when the parent object is copied,
            # so they don't need to be copied separately.
            print(f'Found a many-to-one field: {field.name}')

        elif field.many_to_many:
            # Many to many fields are relationships where many parent objects can be related to many
            # child objects. Because of this the child objects don't need to be copied when we copy
            # the parent, we just need to re-create the relationship to them on the copied parent.
            print(f'Found a many-to-many field: {field.name}')
            related_object_manager = getattr(self, accessor_name)
            relations = list(related_object_manager.all())
            if relations:
                print(f' - {len(relations)} relations to set')
                relations_to_set[accessor_name] = relations

    # Duplicate the parent object
    self.pk = None
    self.save()
    print(f'Copied parent object ({str(self)})')

    # Copy the one-to-many child objects and relate them to the copied parent
    for related_object in related_objects_to_copy:
        # Iterate through the fields in the related object to find the one that relates to the
        # parent model (I feel like there might be an easier way to get at this).
        for related_object_field in related_object._meta.fields:
            if related_object_field.related_model == self.__class__:
                # If the related_model on this field matches the parent object's class, perform the
                # copy of the child object and set this field to the parent object, creating the
                # new child -> parent relationship.
                related_object.pk = None
                setattr(related_object, related_object_field.name, self)
                related_object.save()

                text = str(related_object)
                text = (text[:40] + '..') if len(text) > 40 else text
                print(f'|- Copied child object ({text})')

    # Set the many-to-many relations on the copied parent
    for field_name, relations in relations_to_set.items():
        # Get the field by name and set the relations, creating the new relationships
        field = getattr(self, field_name)
        field.set(relations)
        text_relations = []
        for relation in relations:
            text_relations.append(str(relation))
        print(f'|- Set {len(relations)} many-to-many relations on {field_name} {text_relations}')

    return self
Here is a somewhat simple-minded solution. This does not depend on any undocumented Django APIs. It assumes that you want to duplicate a single parent record, along with its child, grandchild, etc. records. You pass in a whitelist of classes that should actually be duplicated, in the form of a list of names of the one-to-many relationships on each parent object that point to its child objects. This code assumes that, given the above whitelist, the entire tree is self-contained, with no external references to worry about.
One more thing about this code: it is truly recursive, in that it calls itself for each new level of descendants.
from collections import OrderedDict

def duplicate_model_with_descendants(obj, whitelist, _new_parent_pk=None):
    kwargs = {}
    children_to_clone = OrderedDict()
    for field in obj._meta.get_fields():
        if field.name == "id":
            pass
        elif field.one_to_many:
            if field.name in whitelist:
                these_children = list(getattr(obj, field.name).all())
                if field.name in children_to_clone:
                    children_to_clone[field.name] += these_children
                else:
                    children_to_clone[field.name] = these_children
            else:
                pass
        elif field.many_to_one:
            if _new_parent_pk:
                kwargs[field.name + '_id'] = _new_parent_pk
        elif field.concrete and not field.many_to_many:
            # many-to-many fields can't be passed to the model constructor
            kwargs[field.name] = getattr(obj, field.name)
        else:
            pass
    new_instance = obj.__class__(**kwargs)
    new_instance.save()
    new_instance_pk = new_instance.pk
    for ky in children_to_clone.keys():
        child_collection = getattr(new_instance, ky)
        for child in children_to_clone[ky]:
            child_collection.add(duplicate_model_with_descendants(child, whitelist=whitelist, _new_parent_pk=new_instance_pk))
    return new_instance
Example usage:
from django.db import models

class Book(models.Model):
    pass

class Chapter(models.Model):
    book = models.ForeignKey(Book, related_name='chapters', on_delete=models.CASCADE)

class Page(models.Model):
    chapter = models.ForeignKey(Chapter, related_name='pages', on_delete=models.CASCADE)

WHITELIST = ['books', 'chapters', 'pages']
original_record = Book.objects.get(pk=1)
duplicate_record = duplicate_model_with_descendants(original_record, WHITELIST)
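The same recursion can be sketched without Django. In this minimal sqlite3 version, a hypothetical CHILDREN mapping plays the role of the whitelist: for each table it lists the child tables to copy and their foreign-key column. Each call copies one row, then recurses into that row's children, handing them the new id. The schema and duplicate_with_descendants are illustrative assumptions, not part of the answer above.

```python
import sqlite3

# Hypothetical whitelist: table -> [(child_table, fk_column_in_child)]
# Table/column names are trusted constants, never user input (they go into f-strings).
CHILDREN = {
    "book": [("chapter", "book_id")],
    "chapter": [("page", "chapter_id")],
    "page": [],
}


def duplicate_with_descendants(conn, table, row_id, fk_col=None, new_parent_id=None):
    """Copy row `row_id` of `table` and, recursively, all whitelisted descendants."""
    # Copy every column except the primary key...
    cols = [r[1] for r in conn.execute(f"PRAGMA table_info({table})") if r[1] != "id"]
    values = list(conn.execute(
        f"SELECT {', '.join(cols)} FROM {table} WHERE id = ?", (row_id,)
    ).fetchone())
    # ...except the link to the parent, which must point at the parent's copy
    if fk_col is not None:
        values[cols.index(fk_col)] = new_parent_id
    placeholders = ", ".join("?" for _ in cols)
    new_id = conn.execute(
        f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({placeholders})", values
    ).lastrowid
    for child_table, child_fk in CHILDREN[table]:
        # Materialize the ORIGINAL row's children before inserting any copies
        child_ids = [r[0] for r in conn.execute(
            f"SELECT id FROM {child_table} WHERE {child_fk} = ?", (row_id,)
        )]
        for child_id in child_ids:
            duplicate_with_descendants(conn, child_table, child_id, child_fk, new_id)
    return new_id


conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE book (id INTEGER PRIMARY KEY, title TEXT);
    CREATE TABLE chapter (id INTEGER PRIMARY KEY, book_id INTEGER REFERENCES book(id), title TEXT);
    CREATE TABLE page (id INTEGER PRIMARY KEY, chapter_id INTEGER REFERENCES chapter(id), body TEXT);
    INSERT INTO book (id, title) VALUES (1, 'Original');
    INSERT INTO chapter (id, book_id, title) VALUES (1, 1, 'Ch1'), (2, 1, 'Ch2');
    INSERT INTO page (chapter_id, body) VALUES (1, 'p1'), (1, 'p2'), (2, 'p3');
    """
)
new_book = duplicate_with_descendants(conn, "book", 1)
```

Like the Django version, it assumes the whitelisted tree is self-contained; any other foreign keys on copied rows are copied unchanged.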
I want to get an object from the database if it already exists (based on provided parameters) or create it if it does not.
Django's get_or_create (or source) does this. Is there an equivalent shortcut in SQLAlchemy?
I'm currently writing it out explicitly like this:
def get_or_create_instrument(session, serial_number):
    instrument = session.query(Instrument).filter_by(serial_number=serial_number).first()
    if instrument:
        return instrument
    else:
        instrument = Instrument(serial_number)
        session.add(instrument)
        return instrument
Following the solution of @WoLpH, this is the code that worked for me (simple version):
def get_or_create(session, model, **kwargs):
    instance = session.query(model).filter_by(**kwargs).first()
    if instance:
        return instance
    else:
        instance = model(**kwargs)
        session.add(instance)
        session.commit()
        return instance
With this, I'm able to get_or_create any object of my model.
Suppose my model is:
class Country(Base):
    __tablename__ = 'countries'
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True)
To get or create my object, I write:
myCountry = get_or_create(session, Country, name=countryName)
That's basically the way to do it; there is no shortcut readily available AFAIK.
You could generalize it, of course:
from sqlalchemy.sql.expression import ClauseElement

def get_or_create(session, model, defaults=None, **kwargs):
    instance = session.query(model).filter_by(**kwargs).one_or_none()
    if instance:
        return instance, False
    else:
        params = {k: v for k, v in kwargs.items() if not isinstance(v, ClauseElement)}
        params.update(defaults or {})
        instance = model(**params)
        try:
            session.add(instance)
            session.commit()
        except Exception:
            # The actual exception depends on the specific database, so we catch all
            # exceptions. This is similar to the official documentation:
            # https://docs.sqlalchemy.org/en/latest/orm/session_transaction.html
            session.rollback()
            instance = session.query(model).filter_by(**kwargs).one()
            return instance, False
        else:
            return instance, True
2020 update (Python 3.9+ ONLY)
Here is a cleaner version using Python 3.9's new dict union operator (|):
def get_or_create(session, model, defaults=None, **kwargs):
    instance = session.query(model).filter_by(**kwargs).one_or_none()
    if instance:
        return instance, False
    else:
        # | builds a new dict, so kwargs keeps only the lookup fields
        # for the re-query in the except branch
        params = kwargs | (defaults or {})
        instance = model(**params)
        try:
            session.add(instance)
            session.commit()
        except Exception:
            # The actual exception depends on the specific database, so we catch all
            # exceptions. This is similar to the official documentation:
            # https://docs.sqlalchemy.org/en/latest/orm/session_transaction.html
            session.rollback()
            instance = session.query(model).filter_by(**kwargs).one()
            return instance, False
        else:
            return instance, True
Note:
Similar to the Django version, this will catch duplicate key constraint violations and similar errors. If your get-or-create is not guaranteed to return a single result, it can still result in race conditions.
To alleviate some of that issue, you would need to add another one_or_none()-style fetch right after the session.commit(). This is still no 100% guarantee against race conditions unless you also use with_for_update() or a serializable transaction isolation level.
I've been playing with this problem and have ended up with a fairly robust solution:
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound

def get_one_or_create(session,
                      model,
                      create_method='',
                      create_method_kwargs=None,
                      **kwargs):
    try:
        return session.query(model).filter_by(**kwargs).one(), False
    except NoResultFound:
        # Keep kwargs as the lookup filter; extra creation-only values go in a copy
        create_kwargs = {**kwargs, **(create_method_kwargs or {})}
        created = getattr(model, create_method, model)(**create_kwargs)
        try:
            session.add(created)
            session.flush()
            return created, True
        except IntegrityError:
            session.rollback()
            return session.query(model).filter_by(**kwargs).one(), False
I just wrote a fairly expansive blog post on all the details, but here are a few quick ideas of why I used this:
It unpacks to a tuple that tells you if the object existed or not. This can often be useful in your workflow.
The function gives the ability to work with @classmethod-decorated creator functions (and attributes specific to them).
The solution protects against race conditions when you have more than one process connected to the datastore.
EDIT: I've changed session.commit() to session.flush() as explained in this blog post. Note that these decisions are specific to the datastore used (Postgres in this case).
EDIT 2: I've stopped using {} as a default argument value, since that is a typical Python gotcha. Thanks for the comment, Nigel! If you're curious about this gotcha, check out this StackOverflow question and this blog post.
A modified version of erik's excellent answer:
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound

def get_one_or_create(session,
                      model,
                      create_method='',
                      create_method_kwargs=None,
                      **kwargs):
    # Note: unlike erik's version, the flag here is True when the object
    # already existed and False when it was just created
    try:
        return session.query(model).filter_by(**kwargs).one(), True
    except NoResultFound:
        create_kwargs = {**kwargs, **(create_method_kwargs or {})}
        try:
            with session.begin_nested():
                created = getattr(model, create_method, model)(**create_kwargs)
                session.add(created)
            return created, False
        except IntegrityError:
            return session.query(model).filter_by(**kwargs).one(), True
Use a nested transaction to roll back only the addition of the new item instead of rolling back everything (see this answer to use nested transactions with SQLite).
Move create_method inside the nested transaction. If the created object has relationships and members are assigned through them, it is automatically added to the session. E.g. create a book, which has user_id and user as the corresponding relationship; doing book.user=<user object> inside create_method will add book to the session. This means create_method must be inside the with block to benefit from a rollback if one occurs. Note that begin_nested automatically triggers a flush.
Note that if you use MySQL, the transaction isolation level must be set to READ COMMITTED rather than REPEATABLE READ for this to work. Django's get_or_create (and here) uses the same strategy; see also the Django documentation.
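The nested-transaction idea is not specific to SQLAlchemy: begin_nested() emits a SAVEPOINT. Below is a minimal sqlite3 sketch of the same strategy (try the INSERT inside a savepoint; on a unique-constraint violation, roll back only to the savepoint and re-read the row), so earlier work in the outer transaction survives. The instrument/audit tables and the helper names are hypothetical, and the "lost race" is simulated by calling the insert step directly for a row that already exists.

```python
import sqlite3


def _select_id(conn, serial_number):
    row = conn.execute(
        "SELECT id FROM instrument WHERE serial_number = ?", (serial_number,)
    ).fetchone()
    return row[0] if row else None


def _insert_or_fetch(conn, serial_number):
    """Try the INSERT inside a savepoint; on a unique violation undo only that savepoint."""
    conn.execute("SAVEPOINT get_or_create")
    try:
        cur = conn.execute(
            "INSERT INTO instrument (serial_number) VALUES (?)", (serial_number,)
        )
    except sqlite3.IntegrityError:
        # Another writer got there first: discard our attempt, keep the outer transaction
        conn.execute("ROLLBACK TO SAVEPOINT get_or_create")
        conn.execute("RELEASE SAVEPOINT get_or_create")
        return _select_id(conn, serial_number), False
    conn.execute("RELEASE SAVEPOINT get_or_create")
    return cur.lastrowid, True


def get_or_create_instrument(conn, serial_number):
    """Return (id, created), in the spirit of Django's get_or_create."""
    existing = _select_id(conn, serial_number)
    if existing is not None:
        return existing, False
    return _insert_or_fetch(conn, serial_number)


# isolation_level=None: transactions and savepoints are issued explicitly
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE instrument (id INTEGER PRIMARY KEY, serial_number TEXT UNIQUE)")
conn.execute("CREATE TABLE audit (msg TEXT)")

conn.execute("BEGIN")
conn.execute("INSERT INTO audit (msg) VALUES ('work done earlier in the transaction')")
first = get_or_create_instrument(conn, "SN-1")
again = get_or_create_instrument(conn, "SN-1")
lost_race = _insert_or_fetch(conn, "SN-1")  # simulates losing the race
conn.execute("COMMIT")
```

The audit row written before the failed INSERT is still committed, which is exactly what a full session.rollback() would have thrown away.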
This SQLAlchemy recipe does the job nicely and elegantly.
The first thing to do is to define a function that is given a Session to work with, and associates a dictionary with the Session() which keeps track of current unique keys.
def _unique(session, cls, hashfunc, queryfunc, constructor, arg, kw):
    cache = getattr(session, '_unique_cache', None)
    if cache is None:
        session._unique_cache = cache = {}

    key = (cls, hashfunc(*arg, **kw))
    if key in cache:
        return cache[key]
    else:
        with session.no_autoflush:
            q = session.query(cls)
            q = queryfunc(q, *arg, **kw)
            obj = q.first()
            if not obj:
                obj = constructor(*arg, **kw)
                session.add(obj)
        cache[key] = obj
        return obj
An example of utilizing this function would be in a mixin:
class UniqueMixin(object):
    @classmethod
    def unique_hash(cls, *arg, **kw):
        raise NotImplementedError()

    @classmethod
    def unique_filter(cls, query, *arg, **kw):
        raise NotImplementedError()

    @classmethod
    def as_unique(cls, session, *arg, **kw):
        return _unique(
            session,
            cls,
            cls.unique_hash,
            cls.unique_filter,
            cls,
            arg, kw
        )
And finally creating the unique get_or_create model:
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
engine = create_engine('sqlite://', echo=True)
Session = sessionmaker(bind=engine)

class Widget(UniqueMixin, Base):
    __tablename__ = 'widget'

    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, nullable=False)

    @classmethod
    def unique_hash(cls, name):
        return name

    @classmethod
    def unique_filter(cls, query, name):
        return query.filter(Widget.name == name)

Base.metadata.create_all(engine)

session = Session()

w1, w2, w3 = Widget.as_unique(session, name='w1'), \
             Widget.as_unique(session, name='w2'), \
             Widget.as_unique(session, name='w3')
w1b = Widget.as_unique(session, name='w1')

assert w1 is w1b
assert w2 is not w3
assert w2 is not w1

session.commit()
The recipe goes deeper into the idea and provides different approaches but I've used this one with great success.
The closest semantically is probably:
def get_or_create(model, **kwargs):
    """SqlAlchemy implementation of Django's get_or_create.
    """
    session = Session()
    instance = session.query(model).filter_by(**kwargs).first()
    if instance:
        return instance, False
    else:
        instance = model(**kwargs)
        session.add(instance)
        session.commit()
        return instance, True
Not sure how kosher it is to rely on a globally defined Session in SQLAlchemy, but the Django version doesn't take a connection, so...
The tuple returned contains the instance and a boolean indicating if the instance was created (i.e. it's False if we read the instance from the db).
Django's get_or_create is often used to make sure that global data is available, so I'm committing at the earliest point possible.
I slightly simplified @Kevin's solution to avoid wrapping the whole function in an if/else statement. This way there's only one return, which I find cleaner:
def get_or_create(session, model, **kwargs):
    instance = session.query(model).filter_by(**kwargs).first()
    if not instance:
        instance = model(**kwargs)
        session.add(instance)
    return instance
There is a Python package that has @erik's solution as well as a version of update_or_create(): https://github.com/enricobarzetti/sqlalchemy_get_or_create
Depending on the isolation level you use, none of the above solutions may be safe.
The best solution I have found is raw SQL in the following form:
INSERT INTO table(f1, f2, unique_f3)
SELECT 'v1', 'v2', 'v3'
WHERE NOT EXISTS (SELECT 1 FROM table WHERE unique_f3 = 'v3')
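This closes the gap between a separate SELECT and INSERT, so it holds up much better under concurrency than the solutions above. Keep a UNIQUE constraint on the column anyway: under READ COMMITTED, two concurrent statements can both pass the NOT EXISTS check before either commits, and the constraint is what turns that into an error instead of a duplicate row.
Beware: to make it efficient, the unique column should be indexed (a UNIQUE constraint normally provides that index).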
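Here is the statement above driven from Python's sqlite3 module, followed by a SELECT that fetches the row whether or not it was just inserted. The items table and get_or_create_raw are hypothetical names; the UNIQUE constraint is kept as the backstop.

```python
import sqlite3


def get_or_create_raw(conn, f1, f2, f3):
    """Insert (f1, f2, f3) only if no row with this unique_f3 exists; return (id, created)."""
    cur = conn.execute(
        "INSERT INTO items (f1, f2, unique_f3) "
        "SELECT ?, ?, ? WHERE NOT EXISTS (SELECT 1 FROM items WHERE unique_f3 = ?)",
        (f1, f2, f3, f3),
    )
    created = cur.rowcount == 1  # 0 rows inserted means the row already existed
    (row_id,) = conn.execute(
        "SELECT id FROM items WHERE unique_f3 = ?", (f3,)
    ).fetchone()
    return row_id, created


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE items (id INTEGER PRIMARY KEY, f1 TEXT, f2 TEXT, unique_f3 TEXT UNIQUE)"
)
a = get_or_create_raw(conn, "v1", "v2", "v3")
b = get_or_create_raw(conn, "x", "y", "v3")  # same unique value: nothing inserted
```

Values are passed as bound parameters rather than spliced into the SQL string, which keeps the pattern safe to use with arbitrary input.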
One problem I regularly encounter is when a field has a max length (say, String(40)) and you'd like to perform a get-or-create with a longer string: the above solutions will fail.
Building on the above solutions, here's my approach:
from sqlalchemy import Column, String
from sqlalchemy.exc import IntegrityError

# Assumes a module-level `session`. Meant to be added as a classmethod on your
# declarative base class (see the NOTE below), so `self` is the model class.
def get_or_create(self, add=True, flush=True, commit=False, **kwargs):
    """
    Get an entity based on the kwargs, or create an entity with those kwargs.

    Params:
        add: (default True) should the instance be added to the session?
        flush: (default True) flush the instance to the session?
        commit: (default False) commit the session?
        kwargs: key, value pairs of parameters to look up/create.

    Ex: SocialPlatform.get_or_create(**{'name':'facebook'})
        returns --> existing record or, will create a new record

    ---------

    NOTE: I like to add this as a classmethod in the base class of my tables, so that
    all data models inherit the base class --> functionality is transmitted across
    all orm defined models.
    """
    # Truncate values if necessary
    for key, value in kwargs.items():
        # Only use strings
        if not isinstance(value, str):
            continue

        # Only use if it's a column (.get() returns None for non-column keys)
        my_col = self.__table__.columns.get(key)
        if not isinstance(my_col, Column):
            continue

        # Skip non strings again here
        if not isinstance(my_col.type, String):
            continue

        # Get the max length
        max_len = my_col.type.length
        if value and max_len and len(value) > max_len:
            # Update the value
            value = value[:max_len]
            kwargs[key] = value

    # -------------------------------------------------

    # Make the query...
    instance = session.query(self).filter_by(**kwargs).first()
    if instance:
        return instance
    else:
        # Max length isn't accounted for here.
        # The assumption is that auto-truncation will happen on the child-model
        # Or directly in the db
        instance = self(**kwargs)

    # You'll usually want to add to the session
    if add:
        session.add(instance)

    # Navigate these with caution
    if add and commit:
        try:
            session.commit()
        except IntegrityError:
            session.rollback()
    elif add and flush:
        session.flush()

    return instance
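If you only need the truncation step, it can be pulled out of the session logic and tested on its own. A minimal sketch, assuming you already have a mapping from column name to max length (with SQLAlchemy you would read it from column.type.length as above); truncate_to_lengths and max_lengths are hypothetical names.

```python
def truncate_to_lengths(kwargs, max_lengths):
    """Return a copy of kwargs with string values cut to their column's max length.

    max_lengths maps column name -> max length (None means unlimited).
    Non-string values and keys missing from max_lengths pass through unchanged.
    """
    result = dict(kwargs)
    for key, value in kwargs.items():
        max_len = max_lengths.get(key)
        if isinstance(value, str) and max_len and len(value) > max_len:
            result[key] = value[:max_len]
    return result


lookup = truncate_to_lengths(
    {"name": "x" * 50, "code": "abc", "n": 12345},
    {"name": 40, "code": None},
)
```

Using the same truncated kwargs for both the lookup and the constructor means a second call with the same long string finds the row the first call created.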
I want to get an object from the database if it already exists (based on provided parameters) or create it if it does not.
Django's get_or_create (or source) does this. Is there an equivalent shortcut in SQLAlchemy?
I'm currently writing it out explicitly like this:
def get_or_create_instrument(session, serial_number):
instrument = session.query(Instrument).filter_by(serial_number=serial_number).first()
if instrument:
return instrument
else:
instrument = Instrument(serial_number)
session.add(instrument)
return instrument
Following the solution of #WoLpH, this is the code that worked for me (simple version):
def get_or_create(session, model, **kwargs):
instance = session.query(model).filter_by(**kwargs).first()
if instance:
return instance
else:
instance = model(**kwargs)
session.add(instance)
session.commit()
return instance
With this, I'm able to get_or_create any object of my model.
Suppose my model object is :
class Country(Base):
__tablename__ = 'countries'
id = Column(Integer, primary_key=True)
name = Column(String, unique=True)
To get or create my object I write :
myCountry = get_or_create(session, Country, name=countryName)
That's basically the way to do it, there is no shortcut readily available AFAIK.
You could generalize it ofcourse:
def get_or_create(session, model, defaults=None, **kwargs):
instance = session.query(model).filter_by(**kwargs).one_or_none()
if instance:
return instance, False
else:
params = {k: v for k, v in kwargs.items() if not isinstance(v, ClauseElement)}
params.update(defaults or {})
instance = model(**params)
try:
session.add(instance)
session.commit()
except Exception: # The actual exception depends on the specific database so we catch all exceptions. This is similar to the official documentation: https://docs.sqlalchemy.org/en/latest/orm/session_transaction.html
session.rollback()
instance = session.query(model).filter_by(**kwargs).one()
return instance, False
else:
return instance, True
2020 update (Python 3.9+ ONLY)
Here is a cleaner version with Python 3.9's the new dict union operator (|=)
def get_or_create(session, model, defaults=None, **kwargs):
instance = session.query(model).filter_by(**kwargs).one_or_none()
if instance:
return instance, False
else:
kwargs |= defaults or {}
instance = model(**kwargs)
try:
session.add(instance)
session.commit()
except Exception: # The actual exception depends on the specific database so we catch all exceptions. This is similar to the official documentation: https://docs.sqlalchemy.org/en/latest/orm/session_transaction.html
session.rollback()
instance = session.query(model).filter_by(**kwargs).one()
return instance, False
else:
return instance, True
Note:
Similar to the Django version this will catch duplicate key constraints and similar errors. If your get or create is not guaranteed to return a single result it can still result in race conditions.
To alleviate some of that issue you would need to add another one_or_none() style fetch right after the session.commit(). This still is no 100% guarantee against race conditions unless you also use a with_for_update() or serializable transaction mode.
I've been playing with this problem and have ended up with a fairly robust solution:
def get_one_or_create(session,
model,
create_method='',
create_method_kwargs=None,
**kwargs):
try:
return session.query(model).filter_by(**kwargs).one(), False
except NoResultFound:
kwargs.update(create_method_kwargs or {})
created = getattr(model, create_method, model)(**kwargs)
try:
session.add(created)
session.flush()
return created, True
except IntegrityError:
session.rollback()
return session.query(model).filter_by(**kwargs).one(), False
I just wrote a fairly expansive blog post on all the details, but a few quite ideas of why I used this.
It unpacks to a tuple that tells you if the object existed or not. This can often be useful in your workflow.
The function gives the ability to work with #classmethod decorated creator functions (and attributes specific to them).
The solution protects against Race Conditions when you have more than one process connected to the datastore.
EDIT: I've changed session.commit() to session.flush() as explained in this blog post. Note that these decisions are specific to the datastore used (Postgres in this case).
EDIT 2: I’ve updated using a {} as a default value in the function as this is typical Python gotcha. Thanks for the comment, Nigel! If your curious about this gotcha, check out this StackOverflow question and this blog post.
A modified version of erik's excellent answer
def get_one_or_create(session,
model,
create_method='',
create_method_kwargs=None,
**kwargs):
try:
return session.query(model).filter_by(**kwargs).one(), True
except NoResultFound:
kwargs.update(create_method_kwargs or {})
try:
with session.begin_nested():
created = getattr(model, create_method, model)(**kwargs)
session.add(created)
return created, False
except IntegrityError:
return session.query(model).filter_by(**kwargs).one(), True
Use a nested transaction to only roll back the addition of the new item instead of rolling back everything (See this answer to use nested transactions with SQLite)
Move create_method. If the created object has relations and it is assigned members through those relations, it is automatically added to the session. E.g. create a book, which has user_id and user as corresponding relationship, then doing book.user=<user object> inside of create_method will add book to the session. This means that create_method must be inside with to benefit from an eventual rollback. Note that begin_nested automatically triggers a flush.
Note that if using MySQL, the transaction isolation level must be set to READ COMMITTED rather than REPEATABLE READ for this to work. Django's get_or_create (and here) uses the same stratagem, see also the Django documentation.
This SQLALchemy recipe does the job nice and elegant.
The first thing to do is to define a function that is given a Session to work with, and associates a dictionary with the Session() which keeps track of current unique keys.
def _unique(session, cls, hashfunc, queryfunc, constructor, arg, kw):
cache = getattr(session, '_unique_cache', None)
if cache is None:
session._unique_cache = cache = {}
key = (cls, hashfunc(*arg, **kw))
if key in cache:
return cache[key]
else:
with session.no_autoflush:
q = session.query(cls)
q = queryfunc(q, *arg, **kw)
obj = q.first()
if not obj:
obj = constructor(*arg, **kw)
session.add(obj)
cache[key] = obj
return obj
An example of utilizing this function would be in a mixin:
class UniqueMixin(object):
#classmethod
def unique_hash(cls, *arg, **kw):
raise NotImplementedError()
#classmethod
def unique_filter(cls, query, *arg, **kw):
raise NotImplementedError()
#classmethod
def as_unique(cls, session, *arg, **kw):
return _unique(
session,
cls,
cls.unique_hash,
cls.unique_filter,
cls,
arg, kw
)
And finally creating the unique get_or_create model:
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
engine = create_engine('sqlite://', echo=True)
Session = sessionmaker(bind=engine)
class Widget(UniqueMixin, Base):
    __tablename__ = 'widget'
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, nullable=False)
    @classmethod
    def unique_hash(cls, name):
        return name
    @classmethod
    def unique_filter(cls, query, name):
        return query.filter(Widget.name == name)
Base.metadata.create_all(engine)
session = Session()
w1, w2, w3 = Widget.as_unique(session, name='w1'), \
    Widget.as_unique(session, name='w2'), \
    Widget.as_unique(session, name='w3')
w1b = Widget.as_unique(session, name='w1')
assert w1 is w1b
assert w2 is not w3
assert w2 is not w1
session.commit()
The recipe goes deeper into the idea and provides different approaches, but I've used this one with great success.
The closest semantically is probably:
def get_or_create(model, **kwargs):
    """SqlAlchemy implementation of Django's get_or_create.
    """
    session = Session()
    instance = session.query(model).filter_by(**kwargs).first()
    if instance:
        return instance, False
    else:
        instance = model(**kwargs)
        session.add(instance)
        session.commit()
        return instance, True
I'm not sure how kosher it is to rely on a globally defined Session in SQLAlchemy, but the Django version doesn't take a connection, so...
The tuple returned contains the instance and a boolean indicating if the instance was created (i.e. it's False if we read the instance from the db).
Django's get_or_create is often used to make sure that global data is available, so I'm committing at the earliest point possible.
I slightly simplified @Kevin's solution to avoid wrapping the whole function in an if/else statement. This way there's only one return, which I find cleaner:
def get_or_create(session, model, **kwargs):
    instance = session.query(model).filter_by(**kwargs).first()
    if not instance:
        instance = model(**kwargs)
        session.add(instance)
    return instance
There is a Python package that has @erik's solution as well as a version of update_or_create(). https://github.com/enricobarzetti/sqlalchemy_get_or_create
Depending on the isolation level you use, none of the above solutions may be safe when several transactions run concurrently.
The best solution I have found is raw SQL of the following form:
INSERT INTO table(f1, f2, unique_f3)
SELECT 'v1', 'v2', 'v3'
WHERE NOT EXISTS (SELECT 1 FROM table WHERE unique_f3 = 'v3')
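This performs the existence check and the insert in a single statement, which is much safer than a separate SELECT followed by an INSERT. Under heavy parallelism, though, two transactions can still both pass the NOT EXISTS check, so a UNIQUE constraint on the column is what ultimately guarantees there are no duplicates.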
Beware: in order to make it efficient, it would be wise to have an INDEX for the unique column.
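The statement can be tried with Python's built-in sqlite3 module. This sketch passes the values as placeholders instead of literals and checks the cursor's rowcount to tell whether a row was created. The items table and the insert_if_absent helper are hypothetical (table itself is a reserved word); the UNIQUE constraint gives unique_f3 an index and also acts as the final guard against duplicates.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical table; the UNIQUE constraint also creates an index on unique_f3.
conn.execute("CREATE TABLE items (f1 TEXT, f2 TEXT, unique_f3 TEXT UNIQUE)")

INSERT_IF_ABSENT = (
    "INSERT INTO items (f1, f2, unique_f3) "
    "SELECT ?, ?, ? "
    "WHERE NOT EXISTS (SELECT 1 FROM items WHERE unique_f3 = ?)"
)


def insert_if_absent(conn, f1, f2, f3):
    """Insert the row unless unique_f3 already exists; return True if inserted."""
    cur = conn.execute(INSERT_IF_ABSENT, (f1, f2, f3, f3))
    conn.commit()
    # rowcount is the number of rows the INSERT actually wrote (0 or 1).
    return cur.rowcount == 1
```

A second call with the same unique_f3 value inserts nothing and returns False, leaving the original row untouched.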
One problem I regularly encounter is when a field has a max length (say, STRING(40)) and you'd like to perform a get-or-create with a longer string: the above solutions will fail.
Building off of the above solutions, here's my approach:
from sqlalchemy import Column, String
from sqlalchemy.exc import IntegrityError
# Meant to live as a @classmethod on your declarative base class (see NOTE below).
def get_or_create(cls, session, add=True, flush=True, commit=False, **kwargs):
    """
    Get an entity based on the kwargs, or create an entity with those kwargs.
    Params:
        session: the Session to query and (optionally) add the new instance to.
        add: (default True) should the instance be added to the session?
        flush: (default True) flush the instance to the session?
        commit: (default False) commit the session?
        kwargs: key, value pairs of parameters to look up/create.
    Ex: SocialPlatform.get_or_create(session, **{'name': 'facebook'})
        returns --> existing record or, will create a new record
    ---------
    NOTE: I like to add this as a classmethod in the base class of my tables, so that
    all data models inherit the base class --> functionality is transmitted across
    all orm defined models.
    """
    # Truncate values if necessary
    for key, value in kwargs.items():
        # Only use strings
        if not isinstance(value, str):
            continue
        # Only use if it's a column; .get() returns None for non-column keys
        my_col = cls.__table__.columns.get(key)
        if not isinstance(my_col, Column):
            continue
        # Skip non-string column types
        if not isinstance(my_col.type, String):
            continue
        # Get the max length
        max_len = my_col.type.length
        if value and max_len and len(value) > max_len:
            # Overwriting an existing key is safe while iterating
            kwargs[key] = value[:max_len]
    # -------------------------------------------------
    # Make the query...
    instance = session.query(cls).filter_by(**kwargs).first()
    if instance:
        return instance
    else:
        # Max length isn't accounted for here.
        # The assumption is that auto-truncation will happen on the child model
        # or directly in the db
        instance = cls(**kwargs)
        # You'll usually want to add to the session
        if add:
            session.add(instance)
        # Navigate these with caution
        if add and commit:
            try:
                session.commit()
            except IntegrityError:
                # After a rollback the returned instance is not persisted
                session.rollback()
        elif add and flush:
            session.flush()
        return instance
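Stripped of the SQLAlchemy specifics, the truncation step boils down to the pure-Python sketch below, which builds a new dict rather than mutating kwargs in place. The truncate_kwargs name and the max_lengths mapping are hypothetical; in the method above the lengths come from Column.type.length.

```python
def truncate_kwargs(kwargs, max_lengths):
    """Return a copy of kwargs with string values cut to their column's max length.

    `max_lengths` is a hypothetical mapping of column name -> max length
    (None meaning unlimited), standing in for `Column.type.length`.
    Keys missing from the mapping and non-string values pass through unchanged.
    """
    truncated = dict(kwargs)
    for key, value in kwargs.items():
        max_len = max_lengths.get(key)
        if isinstance(value, str) and max_len and len(value) > max_len:
            truncated[key] = value[:max_len]
    return truncated
```

Truncating before the lookup means the query and the insert use the same values, so a second call with the same over-long string finds the row created by the first instead of failing or creating a duplicate.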
I'm new to Django and I'm trying to write a unit test where I want to compare a QuerySet before and after a batch-editing function call.
def test_batchEditing_9(self):
    reset()  # reset database for test
    query = Game.objects.all()
    query_old = Game.objects.all()
    dict_value = {'game_code': '001'}
    Utility.batchEditing(Game, query, dict_value)
    query_new = Game.objects.all()
    self.assertTrue(compareQuerySet(query_old, query_new))
My problem is that query_old gets updated after batchEditing is called. Therefore, both querysets are the same.
It seems that a QuerySet is bound to the current state of the database.
Is this normal?
Is there a way to unbind a QuerySet from the database?
I have tried queryset.values() and list(queryset), but the values still get updated.
I'm actually thinking about iterating on the queryset and creating a list of dictionaries by myself, but I want to know if there is an easier way.
Here is batchEditing (I didn't paste the input validity check):
def batchEditing(model, query, values):
    for item in query:
        if isinstance(item, model):
            for field, val in values.items():
                if val is not None:
                    setattr(item, field, val)
            item.save()
Here is compareQuerySet
def compareQuerySet(object1, object2):
    list_val1 = object1.values_list()
    list_val2 = object2.values_list()
    for i in range(len(list_val1)):
        if list_val1[i] != list_val2[i]:
            return False
    return True
A QuerySet essentially just builds SQL; the database is only hit when the QuerySet is evaluated, for example when you iterate over it. For instance,
gamescache = list(Game.objects.all())
or
for g in Game.objects.all():
    ...
hit the database.
The following code should work:
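def test_batchEditing_9(self):
    reset()  # reset database for test
    query = Game.objects.all()
    # Model instances compare by primary key only, so snapshot the field values.
    # values_list() returns a new QuerySet, leaving `query` itself unevaluated.
    query_old = set(query.values_list())
    dict_value = {'game_code': '001'}
    Utility.batchEditing(Game, query, dict_value)
    query_new = set(query.values_list())
    self.assertEqual(query_old, query_new)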
This works because Game.objects.all() does not hit the database; it just creates an object that stores the query parameters.
BTW, if you use order_by in the query and the order matters, you can use a list rather than a set.
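Without Django at hand, the difference between a lazy query and an evaluated snapshot can be sketched with the built-in sqlite3 module. LazyRows and the games table are hypothetical stand-ins for a QuerySet and the Game model; the point is that a snapshot of plain values stays fixed when the rows change, while running the query again reflects the new data.

```python
import sqlite3


class LazyRows:
    """Toy stand-in for a QuerySet: it stores the SQL and runs it only when
    iterated. Unlike a real QuerySet it does not cache results, so every
    iteration re-reads the database."""

    def __init__(self, conn, sql):
        self.conn = conn
        self.sql = sql

    def __iter__(self):
        return iter(self.conn.execute(self.sql).fetchall())


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE games (id INTEGER PRIMARY KEY, game_code TEXT)")
conn.executemany("INSERT INTO games (game_code) VALUES (?)", [("A",), ("B",)])

games = LazyRows(conn, "SELECT id, game_code FROM games ORDER BY id")

# Evaluate now: a list of plain tuples, no longer tied to the database.
snapshot_before = list(games)

# The "batch edit" changes every row.
conn.execute("UPDATE games SET game_code = ?", ("001",))

# Iterating the lazy object again runs the query against the updated rows.
rows_after = list(games)
```

A real QuerySet does cache its rows after the first evaluation, and the model instances in that cache can be mutated in place (as batchEditing does with setattr), which is one reason to snapshot plain values rather than instances.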