Django writing mock test - python

I am having trouble understanding how mock works and how to write unit tests with mock objects. I want to mock external API calls so I can write unit tests for these functions and for the functions that use these calls.
I tried to mock check_sms_request first; later I need to do something with it to cover check_delivery_status for an object.
How do I write tests for these cases?
Function to mock
def check_sms_request(remote_id, phone):
    """Ask the SMS gateway about a previously sent message.

    Returns None without contacting the gateway when ``can_sms(phone)`` is
    falsy; otherwise returns whatever ``CheckSmsRequest`` returns.
    """
    if can_sms(phone):
        payload = {'phone': phone, 'remoteId': remote_id}
        gateway = Client(settings.SMS_API_URL)
        return gateway.service.CheckSmsRequest(
            settings.SMS_API_LOGIN, settings.SMS_API_PASSWORD, payload
        )
    return None
class SMS(models.Model):
    """An outgoing SMS and its delivery state (other fields are elided in this excerpt)."""

    sms_response = models.IntegerField(choices=SMS_SEND_STATUS, null=True, blank=True)
    delivery_status = models.IntegerField(choices=DELIVERY_STATUS_CHOICES, null=True, blank=True)
    has_final_status = models.BooleanField(default=False)

    def check_delivery_status(self):
        """Poll the gateway unless a final status is known, then return delivery_status.

        When the gateway answers, a history row is written, delivery_status is
        updated and has_final_status is set if the answer is in FINAL_STATUSES.
        The instance is saved whenever a poll is made.
        """
        if self.has_final_status:
            return self.delivery_status
        reply = check_sms_request(self.id, self.phone)
        if reply is not None:
            code = reply.response
            self.history.create(fbs_status=self.sms_response, delivery_status=code)
            if code in FINAL_STATUSES:
                self.has_final_status = True
            self.delivery_status = code
        self.save()
        return self.delivery_status
My test:
# NOTE(review): the next line looks like an @override_settings decorator whose
# "@" was lost; as written it is only a comment and has no effect.
#override_settings(SMS_WHITELIST=['', ], SMS_ENABLE=True)
def test_soap_check_sms_request(self):
    """Question's attempt at testing SMS.check_delivery_status with a mock.

    As written, the model never sees the fake: the Mock below is bound to a
    name local to this test function, so the module that defines SMS still
    calls its own check_sms_request.
    """
    # Binds a *local* name only; nothing in the models module is patched.
    check_sms_request = Mock(return_value=True)
    # Sets .response on the Mock itself, not on its return value (True), so
    # code reading check_sms_request(...).response would not see it.
    check_sms_request.response = SMS_SENT_AND_RECIEVED
    # Tautology: asserts the value assigned on the previous line.
    self.assertEqual(check_sms_request.response, SMS_SENT_AND_RECIEVED)
    obj = SMS.objects.create(**{
        'phone': self.user.phone,
        'text': u"Hello",
        'user': self.user,
        'site': self.site2,
    })
    # NOTE(review): the SMS model shown defines check_delivery_status, not
    # check_sms_status — confirm which method is meant.
    obj.check_sms_status()

You could monkey-patch the function for the test, like so:
#override_settings(SMS_WHITELIST=['', ], SMS_ENABLE=True)
def test_soap_check_sms_request(self):
    """Check SMS.check_delivery_status against a faked gateway call.

    check_delivery_status looks up ``check_sms_request`` as a global of the
    module that defines SMS, so that module attribute is what gets replaced.
    ``mock.patch`` derives the module from ``SMS.__module__`` instead of
    guessing its import path. It also restores the real function even if the
    call fails; the old manual assign/restore leaked the fake into later
    tests on failure.
    """
    from unittest import mock

    # check_delivery_status reads ``check_sms_request(...).response``, so the
    # status must live on the mock's *return value*, not on the mock itself.
    fake_reply = mock.Mock(response=SMS_SENT_AND_RECIEVED)
    obj = SMS.objects.create(**{
        'phone': self.user.phone,
        'text': u"Hello",
        'user': self.user,
        'site': self.site2,
    })
    with mock.patch(SMS.__module__ + '.check_sms_request',
                    return_value=fake_reply) as check_mock:
        # The model method is check_delivery_status (check_sms_status does not exist).
        status = obj.check_delivery_status()
    check_mock.assert_called_once_with(obj.id, obj.phone)
    self.assertEqual(status, SMS_SENT_AND_RECIEVED)

test check_sms_request first
mock Client
retrieve the service and CheckSmsRequest mocks through Client's return_value
call check_sms_request
check that Client and CheckSmsRequest were each called once with the right args
check_sms_request should be a method of the class: move it inside the SMS model, or add a method that calls the function
Mock this model method when testing check_delivery_status
Model:
class SMS(models.Model):
    """An outgoing SMS whose gateway lookup is a method, so tests can patch it per class."""

    sms_response = models.IntegerField(choices=SMS_SEND_STATUS, null=True, blank=True)
    delivery_status = models.IntegerField(choices=DELIVERY_STATUS_CHOICES, null=True, blank=True)
    has_final_status = models.BooleanField(default=False)

    def check_sms_request(self, remote_id, phone):
        """Ask the gateway about message ``remote_id`` sent to ``phone``.

        Returns None without contacting the gateway when ``can_sms(phone)`` is
        falsy; otherwise returns whatever ``CheckSmsRequest`` returns.
        """
        # ``self`` is required: check_delivery_status calls
        # self.check_sms_request(id, phone), which passes three arguments.
        if not can_sms(phone):
            return None
        client = Client(settings.SMS_API_URL)
        base_kwargs = {
            'phone': phone,
            'remoteId': remote_id,
        }
        return client.service.CheckSmsRequest(
            settings.SMS_API_LOGIN, settings.SMS_API_PASSWORD, base_kwargs)

    def check_delivery_status(self):
        """Poll the gateway unless a final status is known, then return delivery_status.

        When the gateway answers, a history row is written, delivery_status is
        updated and has_final_status is set if the answer is in FINAL_STATUSES.
        The instance is saved whenever a poll is made.
        """
        if not self.has_final_status:
            response_status = self.check_sms_request(self.id, self.phone)
            if response_status is not None:
                self.history.create(fbs_status=self.sms_response,
                                    delivery_status=response_status.response)
                if response_status.response in FINAL_STATUSES:
                    self.has_final_status = True
                self.delivery_status = response_status.response
            self.save()
        return self.delivery_status
Test:
class SMSModelTestCase(TestCase):
    """Unit tests for SMS that never reach the real SMS gateway or the database."""

    def test_check_sms_request(self):
        """check_sms_request builds a Client from settings and forwards login, password and payload."""
        from unittest import mock

        # Patch names where the model module looks them up.
        models_path = SMS.__module__
        sms_model = SMS()  # don't save
        with mock.patch(models_path + '.can_sms', return_value=True), \
                mock.patch(models_path + '.Client') as client_cls, \
                self.settings(SMS_API_URL='http://example.com',
                              SMS_API_LOGIN='kanata',
                              SMS_API_PASSWORD='izumi'):
            result = sms_model.check_sms_request(101, '+11111111')
        # ``service`` is an attribute of the client, not a call, so there is
        # no ``.return_value`` between ``service`` and ``CheckSmsRequest``.
        check_mock = client_cls.return_value.service.CheckSmsRequest
        client_cls.assert_called_once_with('http://example.com')
        # The payload is passed as one dict, not as separate arguments.
        check_mock.assert_called_once_with(
            'kanata', 'izumi', {'phone': '+11111111', 'remoteId': 101})
        self.assertIs(result, check_mock.return_value)

    def test_check_delivery_status(self):
        """A final gateway status is stored, written to history and marks the SMS final."""
        from unittest import mock

        final_code = 42  # arbitrary; FINAL_STATUSES is patched to contain it
        reply = mock.Mock(response=final_code)
        sms_model = SMS(phone='+11111111')  # unsaved; save() is patched below
        with mock.patch.object(SMS, 'check_sms_request', return_value=reply) as check_mock, \
                mock.patch(SMS.__module__ + '.FINAL_STATUSES', (final_code,)), \
                mock.patch.object(SMS, 'history') as history_mock, \
                mock.patch.object(SMS, 'save') as save_mock:
            status = sms_model.check_delivery_status()
        # Class-level MagicMocks are not bound, so ``self`` is not in the call args.
        check_mock.assert_called_once_with(sms_model.id, sms_model.phone)
        history_mock.create.assert_called_once_with(
            fbs_status=sms_model.sms_response, delivery_status=final_code)
        save_mock.assert_called_once_with()
        self.assertEqual(status, final_code)
        self.assertTrue(sms_model.has_final_status)

Related

Django testing fails with object not found in response.context even though it works when actually running

I'm trying to test if my PlayerPoint model can give me the top 5 players in regards to their points.
This is the Player model:
class Player(AbstractUser):
    """Custom user that additionally stores a unique phone number."""

    phone_number = models.CharField(
        help_text="Please ensure +251 is included",
        unique=True,
        max_length=14,
    )
and this is the PlayerPoint model:
class PlayerPoint(models.Model):
    """One change to a player's points, with the resulting balance stored in ``points``."""

    OPERATION_CHOICES = (
        ('ADD', 'ADD'),
        ('SUB', 'SUBTRACT'),
        ('RMN', 'REMAIN'),
    )

    points = models.IntegerField(default=0, null=False)
    operation = models.CharField(
        choices=OPERATION_CHOICES,
        default=OPERATION_CHOICES[2][0],  # 'RMN'
        max_length=3,
        null=False,
    )
    operation_amount = models.IntegerField(null=False)
    operation_reason = models.CharField(max_length=1500, null=False)
    player = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.PROTECT,
        related_name="player_points",
        to_field="phone_number",
        null=False,
    )
    points_ts = models.DateTimeField(null=False, auto_now_add=True)

    class Meta:
        # latest() with no argument orders by pk, then points_ts.
        get_latest_by = ['pk', 'points_ts']
I also have a pre-save signal handler:
#receiver(signals.pre_save, sender=PlayerPoint)
def pre_save_PlayerPoint(sender, instance, **_):
    """Derive ``instance.points`` from the player's latest PlayerPoint before it is saved.

    - No earlier entry and "new player" in operation_reason: the amount is
      forced to 100 and added to ``instance.points``.
    - No earlier entry otherwise: DoesNotExist is re-raised.
    - ADD / SUB: latest balance plus / minus operation_amount; SUB raises
      ValidationError when the balance is smaller than the amount.
    - Any other operation leaves ``instance.points`` untouched.
      NOTE(review): so RMN keeps the field default rather than carrying the
      latest balance forward — confirm that is intended.
    """
    if sender is not PlayerPoint:
        return
    try:
        current_point = PlayerPoint.objects.filter(player=instance.player).latest()
    except PlayerPoint.DoesNotExist as missing:
        if "new player" not in instance.operation_reason.lower():
            raise missing
        print(f"{missing} {instance.player} must be a new")
        instance.operation_amount = 100
        instance.points = int(instance.points) + int(instance.operation_amount)
    except Exception as err:
        print(f"{err} while trying to get current_point of the player, stopping execution")
        raise err
    else:
        add_code = PlayerPoint.OPERATION_CHOICES[0][0]
        sub_code = PlayerPoint.OPERATION_CHOICES[1][0]
        if instance.operation == add_code:
            instance.points = int(current_point.points) + int(instance.operation_amount)
        elif instance.operation == sub_code:
            balance = int(current_point.points)
            amount = int(instance.operation_amount)
            if balance < amount:
                raise ValidationError(
                    message="not enough points",
                    params={"points": current_point.points},
                    code="invalid",
                )
            instance.points = balance - amount
As you can see there is a foreign key relation.
Before running the tests, in the setUp() I create points for all the players as such:
class Top5PlayersViewTestCase(TestCase):
    """Question's test: the monthly-award view should list the top 5 players.

    NOTE(review): the failure in the traceback comes from comparing
    PlayerPoint rows (``top_5`` in the test) with ``results``, which the view
    fills with player objects; assertIn can never match across those types.
    """

    def setUp(self) -> None:
        """Create ten players, each with a "new player" entry and two ADD entries."""
        self.player_model = get_user_model()
        self.test_client = Client(raise_request_exception=True)
        self.player_list = list()
        for i in range(0, 10):
            x = self.player_model.objects.create_user(
                phone_number=f"+2517{i}{i}{i}{i}{i}{i}{i}{i}",
                # first_name="test",
                # father_name="user",
                # grandfather_name="tokko",
                # email=f"test_user@tokko7{i}.com",
                # age="29",
                password="password"
            )
            # No operation_amount here: the pre_save handler supplies it for
            # entries whose reason contains "new player".
            PlayerPoint.objects.create(
                operation="ADD",
                operation_reason="new player",
                player=x
            )
            self.player_list.append(x)
        # ``counter`` accumulates across players, so each later player (higher
        # last phone digit) is given a larger amount than the one before.
        counter = 500
        for player in self.player_list:
            counter += int(player.phone_number[-1:])
            PlayerPoint.objects.create(
                operation="ADD",
                operation_amount=counter,
                operation_reason="add for testing",
                player=player
            )
            PlayerPoint.objects.create(
                operation="ADD",
                operation_amount=counter,
                operation_reason="add for testing",
                player=player
            )
        return super().setUp()

    def test_monthly_awarding_view_displays_top5_players(self):
        """The view responds 200 with the expected template and five results."""
        for player in self.player_list:
            print(player.player_points.latest())
        # self.test_client.post("/accounts/login/", self.test_login_success_data)
        test_results = self.test_client.get("/points_app/monthly_award/", follow=True)
        self.assertEqual(test_results.status_code, 200)
        self.assertTemplateUsed(test_results, "points_app/monthlytop5players.html")
        self.assertEqual(len(test_results.context.get('results')), 5)
        # The five PlayerPoint rows with the highest running totals.
        top_5 = PlayerPoint.objects.order_by('-points')[:5]
        for pt in top_5:
            # NOTE(review): ``pt`` is a PlayerPoint but ``results`` holds player
            # objects, so this assertion cannot pass as written.
            self.assertIn(pt, test_results.context.get('results'))
The full traceback is this after running coverage run manage.py test points_app.tests.test_views.MonthlyAwardingViewTestCase.test_monthly_awarding_view_displays_top5_players -v 2:
Traceback (most recent call last):
File "/home/gadd/vscodeworkspace/websites/25X/twenty_five_X/points_app/tests/test_views.py", line 358, in test_monthly_awarding_view_displays_top5_players
self.assertIn(pt, test_results.context.get('results'))
AssertionError: <PlayerPoint: 1190 -- +251799999999> not found in [<User25X: +251700000000>, <User25X: +251711111111>, <User25X: +251722222222>, <User25X: +251733333333>, <User25X: +251744444444>]
This is the view being tested:
def get(request):
    """Question's monthly-award view: render up to five qualifying players.

    A non-staff player qualifies when their latest PlayerPoint has at least
    500 points and was recorded in ``current_year``/``current_month`` (names
    defined outside this excerpt).

    NOTE(review): ``top_5`` is never sorted, so ``[:5]`` keeps the first five
    qualifying players in whatever order the queryset yields, not the five
    with the most points — matching the traceback, which shows +2517000...
    through +2517444... rather than the highest scorers.
    """
    all_players = get_user_model().objects.filter(is_staff=False).prefetch_related('player_points')
    top_5 = list()
    for player in all_players:
        try:
            # NOTE(review): latest() probably builds a fresh query, so the
            # prefetch_related above may not help here — confirm.
            latest_points = player.player_points.latest()
        except Exception as e:
            # Report the player (console + flash message) and skip them.
            print(f"{player} -- {e}")
            messages.error(request, f"{player} {e}")
        else:
            if all(
                [
                    latest_points.points >= 500,
                    latest_points.points_ts.year == current_year,
                    latest_points.points_ts.month == current_month
                ]
            ):
                top_5.append(player)
    return render(request, "points_app/monthlytop5players.html", {"results": top_5[:5]})
What am I doing wrong?
There are 2 problems.
In your view, top_5 is an unsorted list of Player objects.
top_5 = sorted(top_5, key=lambda player: player.player_points.latest().points, reverse=True)[:5] # Add this
return render(request, "points_app/monthlytop5players.html", {"results": top_5[:5]})
In your test, top_5 is a list (actually QuerySet) of PlayerPoint objects.
# Add this: map each rendered player to their latest PlayerPoint, so the
# assertion compares PlayerPoint with PlayerPoint.
rendered_points = [
    shown.player_points.latest() for shown in test_results.context['results']
]
for pt in top_5:
    # Previously: self.assertIn(pt, test_results.context.get('results'))
    self.assertIn(pt, rendered_points)
I think your problem is with this line:
# Raises PlayerPoint.DoesNotExist when this player has no points entries.
latest_points = player.player_points.latest()
Specifically, latest(). Like get(), earliest() and latest() raise DoesNotExist if there is no object with the given parameters.
You may need to add get_latest_by to your model's Meta class. Maybe try this:
class PlayerPoint(models.Model):
    ...

    class Meta:
        # The model's timestamp field is ``points_ts`` (there is no
        # ``joined_ts``); pk breaks ties between entries with equal timestamps.
        get_latest_by = ['points_ts', 'pk']
If you don't want to add this to your model, you could just do it directly:
# Newest entry by timestamp (pk breaks ties). Do not prefix with "-": that
# inverts the ordering and latest() would return the *oldest* entry.
latest_points = player.player_points.latest('points_ts', 'pk')
(That is, if this is indeed the problem.)

Why do I get NameError for the column that is defined in SQLAlchemy model

I am developing REST APIs with Flask. One of the tables is modeled as follows:
class AudioSessionModel(db.Model):
    """SQLAlchemy model for the ``audio_session`` table."""

    __tablename__ = 'audio_session'

    id = db.Column('audio_session_id', db.Integer, primary_key=True)
    cs_id = db.Column(db.Integer)
    session_id = db.Column(db.Integer)
    facility = db.Column(db.Integer)
    description = db.Column(db.String(400))

    def __init__(self, cs_id, session_id, facility, description=None):
        """Create a session row; ``description`` is optional and can be set later."""
        self.cs_id = cs_id
        self.session_id = session_id
        self.facility = facility
        # Only assign when given, so existing three-argument callers see no change.
        if description is not None:
            self.description = description
Business logics are defined in a DAO class:
class AudioSessionDAO(object):
    """Question's data-access object for AudioSessionModel rows."""

    def update(self, data):
        """Set the description of the session identified by data['CSID'] / data['Session'].

        NOTE(review): as written this raises NameError: ``cs_id`` and
        ``session_id`` are attributes of AudioSessionModel, not names in this
        scope. Even with that fixed, ``filter(...)`` returns a query object,
        so assigning ``.description`` to it does not update a row; fetch an
        instance (e.g. with ``.first()``) before assigning.
        """
        audio = AudioSessionModel.query.filter(cs_id == data['CSID'], session_id == data['Session'])
        audio.description = data['Desc']
        db.session.commit()
        return audio
This upate function is called in my endpoint for PUT request:
# NOTE(review): the "#api..." lines look like @api.route / @api.expect
# decorators that lost their "@"; as comments they have no effect.
#api.route('/OperatorAccessment')
class OperatorAssessment(Resource):
    """Resource whose PUT hands the JSON request body to AudioSessionDAO.update."""

    #api.expect(assessment)
    def put(self):
        """Apply the update described by the request body; returns None."""
        AudioSessionDAO().update(request.json)
The model assessment looks like this:
# Request-body schema: CSID and Session identify the row, Desc is the new text.
assessment = api.model('Operator Assessment', {
    'CSID': fields.Integer(description='Central Station ID', required=True),
    'Session': fields.Integer(description='Session ID', required=True),
    'Desc': fields.String(description='Description'),
})
When I test the PUT request with the following json in request body:
{
"CSID": 1,
"Session": 1,
"Desc": "Siren"
}
I got the following error:
File "C:\Users\xxx_app\model\dao.py", line 63, in update
audio = AudioSessionModel.query.filter(cs_id == data['CSID'], session_id == data['Session'])
NameError: name 'cs_id' is not defined
Apparently, cs_id is defined. Why am I still getting this error?
You have to use the attributes of the class, i.e.
# .first() turns the query into one AudioSessionModel instance (or None), so
# assigning .description and committing actually updates that row.
AudioSessionModel.query.filter(
    AudioSessionModel.cs_id == data['CSID'],
    AudioSessionModel.session_id == data['Session']).first()
Or use filter_by with keyword arguments and a plain =:
# Keyword form of the same lookup; .first() again yields one instance or None.
AudioSessionModel.query.filter_by(
    cs_id=data['CSID'],
    session_id=data['Session']).first()
See What's the difference between filter and filter_by in SQLAlchemy?

Filter output by model property

In my model, I have a calculated property, current_tobe_payed.
I want to generate CSV report of all rows where my property current_tobe_payed is less than zero
See my view below:
def export_leaseterm_csv(request):
    """Question's CSV export of LeaseTerm rows (does not work as written).

    NOTE(review): serializers.serialize("python", ...) yields plain dicts, so
    ``obj.current_tobe_payed`` raises "'dict' object has no attribute ...".
    The prose asks for values less than zero, but this filter keeps > 0.
    """
    response = HttpResponse(content_type='text/csv')
    response['Content-Disposition'] = 'attachment; filename="leaseterm.csv"'
    writer = csv.writer(response)
    leaseterms = serializers.serialize( "python", LeaseTerm.objects.all())
    # The list built here is never assigned, so the filtering has no effect.
    [obj for obj in leaseterms if obj.current_tobe_payed > 0]
    # NOTE(review): ``obj`` is the comprehension variable (visible here only on
    # Python 2, where it leaks); this does not iterate the filtered terms.
    for leaseterm in obj:
        writer.writerow(leaseterm['fields'].values())
    return response
However, I am getting an error:
'dict' object has no attribute 'current_tobe_payed'
How can I solve this issue?
(also I want to enter only certain fields into CSV and not all the table.)
UPDATE:
See my model below:
class LeaseTerm(CommonInfo):
    """One pricing condition of a Lease, with lazily computed payment figures.

    The ``current_*`` and ``total_*`` properties share one computation,
    ``_get_total()``, whose results are cached on the instance the first time
    any of them is read.
    """

    version = IntegerVersionField()
    start_period = models.ForeignKey(Period, related_name='start_period')
    end_period = models.ForeignKey(Period, related_name='end_period')
    lease = models.ForeignKey(Lease)
    increase = models.DecimalField(max_digits=7, decimal_places=2)
    amount = models.DecimalField(max_digits=7, decimal_places=2)
    is_terminated = models.BooleanField(default=False)

    # Per-instance cache written by _get_total(); None means "not computed yet".
    _current_period = None
    _total_current = None
    _total_payment = None
    _total_current_payment = None
    _total_discount = None
    _total_current_discount = None
    _current_tobe_payed = None
    _current_balance = None

    @staticmethod
    def _amount_sum(queryset):
        """Return Sum('amount') over ``queryset``, or 0 when there is nothing to sum."""
        return queryset.aggregate(Sum('amount'))['amount__sum'] or 0

    def _get_total(self):
        """Compute payments, discounts, current period and balance; cache them on self.

        Issues several queries per call: aggregates over payments, discounts
        and periods, plus possibly loading start_period.
        """
        from payment.models import LeasePayment
        from conditions.models import LeaseDiscount

        total_payment = self._amount_sum(
            LeasePayment.objects.filter(leaseterm_id=self.id, is_active=True))
        total_discount = self._amount_sum(
            LeaseDiscount.objects.filter(leaseterm_id=self.id, is_active=True))

        current_date = datetime.datetime.now().date()
        current_period = Period.objects.filter(
            start_date__lte=current_date, end_date__gte=current_date, is_active=True,
        ).aggregate(Max('order_value'))['order_value__max'] or 0

        # NOTE(review): period_date is compared with a Period instance and with
        # an order_value number here — confirm these are the intended bounds.
        current_discount = self._amount_sum(LeaseDiscount.objects.filter(
            leaseterm_id=self.id, is_active=True,
            period_date__gte=self.start_period,
            period_date__lte=current_period))

        current_periods_number = current_period - self.start_period.order_value + 1
        current_tobe_payed = current_periods_number * self.amount - current_discount

        self._current_period = current_period
        self._total_payment = total_payment
        self._total_discount = total_discount
        self._current_tobe_payed = current_tobe_payed
        self._current_balance = total_payment - current_tobe_payed

    def _cached_total(self, attr_name):
        """Return cached figure ``attr_name``, computing all figures on first use."""
        if getattr(self, attr_name) is None:
            self._get_total()
        return getattr(self, attr_name)

    # Real properties, so callers can write ``term.current_balance < 0``.
    @property
    def current_tobe_payed(self):
        """Amount due up to the current period, net of discounts in that range."""
        return self._cached_total('_current_tobe_payed')

    @property
    def current_balance(self):
        """Total active payments minus current_tobe_payed."""
        return self._cached_total('_current_balance')

    @property
    def current_period(self):
        """Highest order_value among active periods containing today (0 if none)."""
        return self._cached_total('_current_period')

    @property
    def total_payment(self):
        """Sum of this term's active LeasePayment amounts."""
        return self._cached_total('_total_payment')

    @property
    def total_discount(self):
        """Sum of this term's active LeaseDiscount amounts."""
        return self._cached_total('_total_discount')

    def clean(self):
        """Reject an active, non-terminated term whose lease already has another active term."""
        # Short-circuit before touching self.lease, which may be unset.
        if not self.lease_id or self.is_terminated or not self.is_active:
            return
        others = self.__class__.objects.filter(
            lease=self.lease, is_active=True).exclude(id=self.id)
        # exists() rather than count() == 1, so two or more active terms also block.
        if others.exists():
            raise ValidationError('!Lease has a active condition already, Terminate prior to creation of new one')

    def save(self, *args, **kwargs):
        """Validate (including clean()) before every save."""
        self.full_clean()
        return super(LeaseTerm, self).save(*args, **kwargs)

    def __unicode__(self):
        # %s rather than %i so an unsaved term (id None) can still be displayed.
        return u'%s %s %s %s ' % ("term:", self.id, self.start_period, self.end_period)
That's a rather lengthy calculation in your _get_total method. I count five queries inside it, so the following bit of code will execute those five queries for each row in your table.
# Each obj.current_tobe_payed read runs _get_total(), i.e. several queries per row.
[obj for obj in leaseterms if obj.current_tobe_payed > 0]
So you are doing 5000 queries if you have just 1000 rows in your table. With 10,000 rows, this list comprehension would take a very long time to run.
Solution. Convert your property to a model field.
# Denormalised copy of the computed amount due. It is filled in save(), so it
# has a default (existing rows can be migrated) and is not user-editable.
to_be_payed = models.DecimalField(max_digits=7, decimal_places=2, default=0, editable=False)
I often tell devs not to save the results of simple calculations into a DB column, but yours isn't a simple calculation — it's a complex one, so it deserves a field. You can update this field in the save method:
def save(self, *args, **kwargs):
    """Refresh the stored to_be_payed value, validate, then save."""
    # _get_total() returns nothing (and there is no get_total()); it caches its
    # results on the instance, so read the amount from _current_tobe_payed.
    self._get_total()
    self.to_be_payed = self._current_tobe_payed
    # Keep the model's existing full_clean() validation on every save.
    self.full_clean()
    return super(LeaseTerm, self).save(*args, **kwargs)
If, as you say, the amount to be paid depends on changes to a Payment instance, you can add a post_save signal on the Payment model that triggers an update of the related LeaseTerm object(s). Doing such an update would still be cheaper than doing this calculation 5000 times.
You are using a serializer which returns a python dictionary object. It is not an instance of a model. I suggest the following:
EDITED SOLUTION
def export_leaseterm_csv(request):
    """Return a CSV of every field of each LeaseTerm whose current_tobe_payed is positive."""
    response = HttpResponse(content_type='text/csv')
    response['Content-Disposition'] = 'attachment; filename="leaseterm.csv"'
    # Filter on model instances first: the property is not available on the
    # dicts produced by the serializer.
    due_terms = [term for term in LeaseTerm.objects.all() if term.current_tobe_payed > 0]
    serialized = serializers.serialize("python", due_terms)
    csv.writer(response).writerows(entry['fields'].values() for entry in serialized)
    return response
In the end I did it without a signal and without the serializer.
The number of records in this table will never grow beyond about 100 thousand, and this report is run by only one person in the company once a week. If performance turns out to be insufficient during testing I will denormalize; other than that, I prefer to keep it normalized as long as I can.
def export_leaseterm_csv(request):
    """Return a CSV of active, non-terminated lease terms whose current balance is negative."""
    response = HttpResponse(content_type='text/csv')
    response['Content-Disposition'] = 'attachment; filename="leaseterm.csv"'
    writer = csv.writer(response)
    header = ["lease", "tenant", "amount", "current balance"]
    writer.writerow(header)
    active_terms = LeaseTerm.objects.filter(is_terminated=False, is_active=True)
    overdue = [t for t in active_terms if t.current_balance < 0]
    # NOTE(review): ``tenant`` is not among the LeaseTerm fields shown; it
    # presumably comes from CommonInfo — confirm.
    writer.writerows(
        [t.lease, t.tenant, t.amount, t.current_balance] for t in overdue
    )
    return response

Django Rest Framework: invalidating a single part of a multi serialized POST

Working with django-rest-framework I'm using a serializer with many=True, checking for items which already exist and invalidating them.
The problem is:
When part of a request is invalid, the whole request is rejected without creating the valid objects.
Sample Payload:
[{'record_timestamp': '2016-03-04T09:46:04', 'reader_serial': u'00000000f9b320ac', 'card_serial': u'048EC71A0F3382', 'gps_latitude': None, 'gps_longitude': None, 'salt': 34, 'reader_record_id': 1063},
{'record_timestamp': '2016-03-04T09:46:06', 'reader_serial': u'00000000f9b320ac', 'card_serial': u'04614B1A0F3382', 'gps_latitude': None, 'gps_longitude': None, 'salt': 34, 'reader_record_id': 1064}]
Sample response:
[{"last_record_id":[2384],"error":["This record already exists"]},{}]
Ideal response:
[{"last_record_id":[2384],"error":["This record already exists"]},{'reader': 10, 'card': 12, 'gps_latitude': None, 'gps_longitude': None, 'reader_record_id': 1064}]
I'd like the first record to provide the error, but the second record to be correctly created, with the response being the object created.
class CardRecordInputSerializer(serializers.ModelSerializer):
    """Question's serializer for raw card-reader records (used with many=True).

    validate() rejects a record whose hash already exists; with many=True a
    single such record makes the whole request invalid, which is the
    behaviour the question asks about.
    """
    class Meta:
        model = CardRecord
        fields = ('card', 'reader', 'bus', 'park', 'company', 'client',
                  'record_timestamp', 'reader_record_id')
        # NOTE(review): no comma after 'company', so Python joins the last two
        # literals into 'companyclient'; company and client are not read-only.
        read_only_fields = ('card', 'reader', 'bus', 'park', 'company'
                            'client')

    def validate(self, data):
        """
        Check that the record is unique
        """
        #import ipdb; ipdb.set_trace()
        hash_value = data.get("hash_value", None)
        # Any stored row with the same hash marks this record as a duplicate.
        if CardRecord.objects.filter(hash_value=hash_value):
            raise ValidationError(
                detail={"error":"This record already exists",
                        "last_record_id":data.get("reader_record_id", None)})
        else:
            return data

    def to_internal_value(self, data):
        """Validate declared fields, then add serials, a dedup hash and the timestamp."""
        internal_value = super(CardRecordInputSerializer, self)\
            .to_internal_value(data)
        # NOTE(review): .upper() on None raises AttributeError if a key is missing.
        card_serial = data.get("card_serial", None).upper()
        reader_serial = data.get('reader_serial', None).upper()
        record_timestamp = data.get('record_timestamp', None)
        date_altered = False
        record_date = dateutil.parser.parse(record_timestamp)
        #check if clock has reset to 1970
        # NOTE(review): if the timestamp carries a UTC offset, comparing it with
        # this naive datetime raises TypeError.
        if record_date < datetime.datetime(2014, 4, 24):
            # NOTE(review): isoformat() makes this a *string*; create() later
            # reads .tzinfo on it, which raises AttributeError there.
            record_date = datetime.datetime.now().isoformat()
            date_altered = True
        #create a hash to check that this record is unique
        salt = data.get('salt', None)
        # NOTE(review): on Python 3 hashlib needs bytes; these str arguments
        # raise TypeError (they work on Python 2).
        hash_generator = hashlib.sha1()
        hash_generator.update(card_serial)
        hash_generator.update(reader_serial)
        hash_generator.update(str(record_timestamp))
        hash_generator.update(str(salt))
        hash_value = str(hash_generator.hexdigest())
        internal_value.update({
            "card_serial": card_serial,
            "reader_serial": reader_serial,
            "salt": salt,
            "hash_value": hash_value,
            "record_timestamp": record_date,
            "date_altered": date_altered
        })
        return internal_value

    def create(self, validated_data):
        #import ipdb; ipdb.set_trace()
        '''
        Create a new card transaction record
        '''
        try:
            card_serial = validated_data.get('card_serial', None)
            card = Card.objects.filter(uid=card_serial).last()
            reader_serial = validated_data.get('reader_serial', None)
            reader = Reader.objects.filter(serial=reader_serial).last()
            #if we havent seen this reader before, add it to the list
            if not reader:
                reader = Reader.objects.create(serial=reader_serial)
            company = card.company
            client = reader.client
            park = reader.park
            record_timestamp = validated_data.get('record_timestamp', None)
            reader_record_id = validated_data.get('reader_record_id', None)
            #if datetime is naive, set it to utc
            # NOTE(review): ``d`` is not defined anywhere; this raises NameError
            # whenever tzinfo is set. utcoffset() expects the datetime itself.
            if record_timestamp.tzinfo is None \
                    or record_timestamp.tzinfo.utcoffset(d) is None:
                record_timestamp = pytz.utc.localize(record_timestamp)
            hash_value = validated_data.get('hash_value', None)
            date_altered = validated_data.get('date_altered', None)
            return CardRecord.objects.create(card = card,
                                             reader = reader,
                                             company = company,
                                             client = client,
                                             park = park,
                                             record_timestamp = record_timestamp,
                                             reader_record_id = reader_record_id,
                                             hash_value = hash_value,
                                             date_altered = date_altered)
        #Usually a card that doesn't have company
        except AttributeError:
            # NOTE(review): returns a plain dict instead of a CardRecord instance.
            return {
                'status': 'Bad Request',
                'message': 'One of the values was malformed or does not exist.'
            }
How can I create valid objects and provide errors for the invalid ones?
I ended up skipping the validation.
Then in my create method if the object already exists I just return it, if it doesn't exist I create it and return it.
The client no longer knows whether the server already had that record, but that's fine for my use case.
I also swapped to using PUT to reflect the fact that the method is idempotent.
I feel like the validator is the place to do the check but this works.
class CardRecordInputSerializer(serializers.ModelSerializer):
    """Idempotently turn raw card-reader records into CardRecord rows.

    Duplicates are not rejected during validation. Instead, create() returns
    the already-stored CardRecord with the same hash, so a many=True batch
    that mixes new and previously seen records is accepted as a whole.
    """

    class Meta:
        model = CardRecord
        fields = ('card', 'reader', 'bus', 'park', 'company', 'client',
                  'record_timestamp', 'reader_record_id')
        # The comma after 'company' matters: without it the two literals are
        # joined into a single 'companyclient' entry.
        read_only_fields = ('card', 'reader', 'bus', 'park', 'company',
                            'client')

    def validate(self, data):
        """Accept the record as-is; duplicate handling happens in create()."""
        return data

    @staticmethod
    def _upper_serial(data, key):
        """Return ``data[key]`` upper-cased.

        Raises serializers.ValidationError (keyed by ``key``) when the value is
        absent, instead of failing with AttributeError on ``None.upper()``.
        """
        value = data.get(key, None)
        if value is None:
            raise serializers.ValidationError({key: 'This field is required.'})
        return value.upper()

    def to_internal_value(self, data):
        """Validate declared fields, then add serials, a dedup hash and the timestamp.

        Adds ``card_serial``, ``reader_serial``, ``salt``, ``hash_value``,
        ``record_timestamp`` (always a datetime) and ``date_altered``.
        """
        internal_value = super(CardRecordInputSerializer, self).to_internal_value(data)
        card_serial = self._upper_serial(data, 'card_serial')
        reader_serial = self._upper_serial(data, 'reader_serial')
        record_timestamp = data.get('record_timestamp', None)
        date_altered = False
        record_date = dateutil.parser.parse(record_timestamp)
        # Check whether the reader's clock reset (e.g. to 1970). tzinfo is
        # dropped for the comparison so offset-aware timestamps don't raise
        # TypeError against the naive cutoff.
        if record_date.replace(tzinfo=None) < datetime.datetime(2014, 4, 24):
            # Keep a naive UTC datetime rather than an ISO string: create()
            # inspects .tzinfo and localises naive values as UTC.
            record_date = datetime.datetime.utcnow()
            date_altered = True
        salt = data.get('salt', None)
        # Hash the raw timestamp, not the restamped one, so resending the same
        # record always yields the same hash. hashlib needs bytes.
        hash_generator = hashlib.sha1()
        for part in (card_serial, reader_serial, str(record_timestamp), str(salt)):
            hash_generator.update(part.encode('utf-8'))
        hash_value = str(hash_generator.hexdigest())
        internal_value.update({
            "card_serial": card_serial,
            "reader_serial": reader_serial,
            "salt": salt,
            "hash_value": hash_value,
            "record_timestamp": record_date,
            "date_altered": date_altered
        })
        return internal_value

    def create(self, validated_data):
        """Return the CardRecord matching ``hash_value``, creating it if needed.

        Raises serializers.ValidationError when the card is unknown or another
        value is malformed (previously a plain dict was returned instead of a
        CardRecord).
        """
        hash_value = validated_data.get('hash_value', None)
        # Idempotency check first, so a duplicate never registers a new Reader.
        record = CardRecord.objects.filter(hash_value=hash_value).last()
        if record:
            return record
        try:
            card_serial = validated_data.get('card_serial', None)
            card = Card.objects.filter(uid=card_serial).last()
            reader_serial = validated_data.get('reader_serial', None)
            reader = Reader.objects.filter(serial=reader_serial).last()
            # If we haven't seen this reader before, register it.
            if not reader:
                reader = Reader.objects.create(serial=reader_serial)
            company = card.company  # AttributeError when no card matches
            client = reader.client
            park = reader.park
            record_timestamp = validated_data.get('record_timestamp', None)
            reader_record_id = validated_data.get('reader_record_id', None)
            # Treat naive datetimes as UTC. utcoffset() takes the datetime
            # itself (the old code passed an undefined ``d``).
            if record_timestamp.tzinfo is None \
                    or record_timestamp.tzinfo.utcoffset(record_timestamp) is None:
                record_timestamp = pytz.utc.localize(record_timestamp.replace(tzinfo=None))
            date_altered = validated_data.get('date_altered', None)
        # Usually a card that doesn't have a company.
        except AttributeError:
            raise serializers.ValidationError({
                'status': 'Bad Request',
                'message': 'One of the values was malformed or does not exist.'
            })
        return CardRecord.objects.create(
            card=card,
            reader=reader,
            company=company,
            client=client,
            park=park,
            record_timestamp=record_timestamp,
            reader_record_id=reader_record_id,
            hash_value=hash_value,
            date_altered=date_altered)

Python: issues understanding magicmock with unittests

Here is my class:
class WorkflowsCloudant(cloudant.Account):
    """Cloudant account bound to the configured database and one account id."""

    def __init__(self, account_id):
        super(WorkflowsCloudant, self).__init__(
            settings.COUCH_DB_ACCOUNT_NAME,
            auth=(settings.COUCH_PUBLIC_KEY, settings.COUCH_PRIVATE_KEY))
        self.db = self.database(settings.COUCH_DB_NAME)
        self.account_id = account_id

    def get_by_id(self, key, design='by_workflow_id', view='by_workflow_id', limit=None):
        """Query ``design``/``view`` for ``key`` with include_docs.

        With ``limit == 1``, return the single workflow document. Raises
        InvalidAccount if its account_id differs from self.account_id, and
        NotFound if the view returned no rows. Otherwise return the raw view
        result.
        """
        params = dict(key=key, include_docs=True, limit=limit)
        docs = self.db.design(design).view(view, params=params)
        # ``==`` not ``is``: identity of int objects is an interpreter detail.
        if limit == 1:
            matches = [row['doc'] for row in docs]
            if not matches:
                raise NotFound("Autoresponder Cannot Be Found")
            workflow = matches[0]
            if workflow.get("account_id") != self.account_id:
                raise InvalidAccount("Invalid Account")
            return workflow
        return docs
Here is my test:
def test_get_by_id_single_invalid_account(self):
    """Question's test: get_by_id(limit=1) should raise InvalidAccount for a foreign doc.

    NOTE(review): assertRaises expects an exception *class* (InvalidAccount),
    not an instance. NOTE(review): self.klass is created outside this excerpt;
    if it is a normally constructed WorkflowsCloudant, its __init__ already
    calls self.database(...) — confirm whether that is the "real call out".
    """
    self.klass.account_id = 200
    self.klass.db = mock.MagicMock()
    # MagicMock creates child mocks on attribute access/call, so this single
    # line configures what self.klass.db.design(...).view(...) returns.
    self.klass.db.design.return_value.view.return_value = [{
        'doc': test_workflow()
    }]
    # wc.get_by_id = mock.MagicMock(side_effect=InvalidAccount("Invalid Account"))
    with self.assertRaises(InvalidAccount("Invalid Account")) as context:
        self.klass.get_by_id('workflow_id', limit=1)
    self.assertEqual('Invalid Account', str(context.exception))
I'm trying to get the above test to simply raise the InvalidAccount exception, but I'm unsure how to mock out the self.db.design.view response. That's what's causing my test to fail: it's trying to make a real call out.
I think this is what you want.
def test_get_by_id_single_invalid_account(self):
    """get_by_id(limit=1) raises InvalidAccount when the doc's account_id differs.

    Assumes test_workflow() returns a document whose account_id is not 200.
    """
    self.klass.account_id = 200
    # MagicMock builds db.design(...).view on demand; naming the view mock just
    # lets the test configure and inspect it.
    self.klass.db = mock.MagicMock()
    view_mock = self.klass.db.design.return_value.view
    view_mock.return_value = [{
        'doc': test_workflow()
    }]
    # assertRaises takes the exception class; passing an instance is a TypeError.
    with self.assertRaises(InvalidAccount) as context:
        self.klass.get_by_id('workflow_id', limit=1)
    self.assertEqual('Invalid Account', str(context.exception))
    self.klass.db.design.assert_called_once_with('by_workflow_id')
    view_mock.assert_called_once_with(
        'by_workflow_id',
        params={'key': 'workflow_id', 'include_docs': True, 'limit': 1})

Categories