assertion error omitting two identical items - python

As far as I can tell the assert calls for id and goal_id AND my code provides them both...
enter image description here
goal_routes.py
from datetime import datetime from typing import OrderedDict from
urllib.request import OpenerDirector from flask import Blueprint,
jsonify, request, make_response, abort from app import db from
app.models.goal import Goal from app.models.task import Task from
app.task_routes import validate_task
goal_bp = Blueprint("goal_bp", __name__, url_prefix="/goals")


# Create a Goal
@goal_bp.route("", methods=["POST"])
def create_goals():
    """Create a goal from the JSON request body.

    Returns:
        ``({"goal": {...}}, 201)`` on success, or
        ``({"details": "Invalid data"}, 400)`` when the body is not a JSON
        object containing "title".
    """
    request_body = request.get_json()
    # get_json() may yield None or a non-object value; the membership test
    # below is only meaningful on a dict, so reject everything else up front.
    if not isinstance(request_body, dict) or "title" not in request_body:
        return jsonify({"details": "Invalid data"}), 400

    new_goal = Goal(title=request_body["title"])
    db.session.add(new_goal)
    db.session.commit()

    return jsonify({"goal": new_goal.to_dictionary()}), 201
# Get Goals
@goal_bp.route("", methods=["GET"])
def get_goals():
    """Return every goal as a JSON list, optionally sorted by title.

    The optional ``sort`` query parameter accepts "asc" or "desc"; any other
    value (or none) applies no explicit ordering.
    """
    sort = request.args.get("sort")
    if sort == "asc":
        goals = Goal.query.order_by(Goal.title)
    elif sort == "desc":
        goals = Goal.query.order_by(Goal.title.desc())
    else:
        goals = Goal.query.all()

    # With no saved goals this still answers 200 with an empty list.
    return jsonify([goal.to_dictionary() for goal in goals]), 200
# Get One Goal: One Saved Goal
@goal_bp.route("/<goal_id>", methods=["GET"])
def get_one_goal(goal_id):
    """Return ``{"goal": {...}}`` for ``goal_id``; validate_goal aborts on a bad id."""
    goal = validate_goal(goal_id)
    return jsonify({"goal": goal.to_dictionary()}), 200
# Update Goal
@goal_bp.route("/<goal_id>", methods=["PUT"])
def update_goal(goal_id):
    """Replace the title of an existing goal.

    Returns:
        ``({"goal": {...}}, 200)`` on success, or
        ``({"details": "Invalid data"}, 400)`` when the body has no "title"
        (the same response create_goals gives for the same problem).
    """
    goal = validate_goal(goal_id)
    request_body = request.get_json()
    # Without this guard a missing key surfaced as an unhandled KeyError (500).
    if not isinstance(request_body, dict) or "title" not in request_body:
        return jsonify({"details": "Invalid data"}), 400

    goal.title = request_body["title"]
    db.session.commit()
    return jsonify({"goal": goal.to_dictionary()}), 200
# Goal Complete
@goal_bp.route("/<goal_id>/mark_complete", methods=["PATCH"])
def goal_complete(goal_id):
    """Stamp the goal with the current UTC time as its completion time."""
    goal = validate_goal(goal_id)
    # NOTE(review): the Goal model shown in goal.py declares only goal_id,
    # title and tasks - no completed_at column - so this value may not be
    # persisted by the commit. Confirm against the real model.
    goal.completed_at = datetime.utcnow()
    db.session.commit()
    return jsonify({"goal": goal.to_dictionary()}), 200
# Goal Incomplete
@goal_bp.route("/<goal_id>/mark_incomplete", methods=["PATCH"])
def goal_incomplete(goal_id):
    """Clear the goal's completion time."""
    goal = validate_goal(goal_id)
    # NOTE(review): the Goal model shown in goal.py declares no completed_at
    # column - confirm this attribute is actually mapped.
    goal.completed_at = None
    db.session.commit()
    return jsonify({"goal": goal.to_dictionary()}), 200
# Delete Goal: Deleting a Goal
@goal_bp.route("/<goal_id>", methods=["DELETE"])
def delete_goal(goal_id):
    """Delete the goal and confirm with a human-readable message."""
    goal = validate_goal(goal_id)
    # Build the message while the instance is still attached to the session,
    # rather than reading attributes of an object that was just deleted.
    response = {"details": f"Goal {goal.goal_id} \"{goal.title}\" successfully deleted"}
    db.session.delete(goal)
    db.session.commit()
    return jsonify(response), 200
Validate that a matching Goal exists (used by Get, Update, and Delete):
def validate_goal(goal_id):
    """Return the Goal with the given id, or abort the request.

    Aborts with 400 when ``goal_id`` cannot be converted to an integer and
    with 404 when no goal has that id.
    """
    try:
        goal_id = int(goal_id)
    # Only conversion failures mean "invalid id"; a bare except would also
    # swallow unrelated errors such as KeyboardInterrupt.
    except (TypeError, ValueError):
        abort(make_response({"message": f"Goal {goal_id} is invalid"}, 400))

    goal = Goal.query.get(goal_id)
    if not goal:
        abort(make_response({"message": f"Goal {goal_id} not found"}, 404))
    return goal
@goal_bp.route("/<goal_id>/tasks", methods=["POST"])
def post_task_ids_to_goal(goal_id):
    """Attach the tasks listed in the body's "task_ids" to the goal.

    Returns:
        ``({"id": <goal id>, "task_ids": [...]}, 200)`` on success,
        400 when the body has no "task_ids", 404 when a task id is unknown.
    """
    goal = validate_goal(goal_id)
    request_body = request.get_json()
    if not isinstance(request_body, dict) or "task_ids" not in request_body:
        return jsonify({"details": "Invalid data"}), 400

    for task_id in request_body["task_ids"]:
        task = Task.query.get(task_id)
        # Query.get returns None for an unknown id; fail with 404 instead of
        # an AttributeError on the assignments below.
        if not task:
            abort(make_response({"message": f"Task {task_id} not found"}, 404))
        # Use the validated integer key, not the raw string from the URL.
        task.goal_id = goal.goal_id
        task.goal = goal
    db.session.commit()

    return jsonify({"id": goal.goal_id, "task_ids": request_body["task_ids"]}), 200
@goal_bp.route("/<goal_id>/tasks", methods=["GET"])
def get_tasks_for_goal(goal_id):
    """Return the goal's dictionary with "tasks" expanded to full task dictionaries."""
    goal = validate_goal(goal_id)
    goal_dict = goal.to_dictionary()
    # to_dictionary() lists task ids only (and only when there are tasks);
    # this endpoint always reports the complete task payloads instead.
    goal_dict["tasks"] = [task.to_dictionary() for task in goal.tasks]
    return jsonify(goal_dict)
goal.py
from app import db
class Goal(db.Model):
    """A goal with a title and a collection of related tasks."""

    goal_id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String, nullable=False)
    # NOTE(review): back_populates="goals" has to name the matching
    # relationship attribute on Task - verify against the Task model.
    tasks = db.relationship("Task", back_populates="goals", lazy=True)

    def to_dictionary(self):
        """Serialize the goal; "tasks" (a list of task ids) appears only when tasks exist."""
        serialized = {"id": self.goal_id, "title": self.title}
        if self.tasks:
            serialized["tasks"] = [task.task_id for task in self.tasks]
        return serialized

Related

How to go to the last page of paginated Flask-Admin view by default.. without sorting descending

In Python's Flask-Admin for database table viewing/administrating, I need the view to open automatically to the last page of the paginated data.
Important: I cannot simply sort the records descending so the last record shows first.
Here's what my simple example below looks like. I'd like it to start on the last page, as pictured.
Here's some example code to reproduce my model:
import os
import os.path as op
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
import flask_admin as admin
from flask_admin.contrib.sqla import ModelView
# Create application
app = Flask(__name__)
# Create a dummy secret key so we can use sessions (example value only)
app.config['SECRET_KEY'] = '123456790'
# Configure a file-backed SQLite database (the URI below points at DATABASE_FILE)
app.config['DATABASE_FILE'] = 'sample_db2.sqlite'
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + app.config['DATABASE_FILE']
# Echo every SQL statement SQLAlchemy emits
app.config['SQLALCHEMY_ECHO'] = True
db = SQLAlchemy(app)
# Models
class User(db.Model):
    """Example user record with a display name and an e-mail address."""

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.Unicode(64))
    email = db.Column(db.Unicode(64))

    def __unicode__(self):
        """Python 2 text representation: the user's name."""
        return self.name

    def __str__(self):
        """Python 3 never calls __unicode__, so expose the name via __str__ too."""
        return self.name
class UserAdmin(ModelView):
    """Admin list view for User: searchable and inline-editable name/email columns."""

    # Pagination: five entries per page, with a user-selectable page size.
    page_size = 5
    can_set_page_size = True

    # Columns
    column_display_pk = True
    column_searchable_list = ('name', 'email')
    column_editable_list = ('name', 'email')
# Flask views
@app.route('/')
def index():
    """Landing page that links to the admin interface."""
    # The anchor markup was stripped from the extracted text; without it the
    # "Click me" page offers nothing to click.
    return '<a href="/admin/">Click me to get to Admin!</a>'
# NOTE: this rebinds `admin` from the imported flask_admin module alias to the
# Admin instance; the module is no longer reachable under that name afterwards.
admin = admin.Admin(app, 'Example', template_mode='bootstrap3')
# Add views
admin.add_view(UserAdmin(User, db.session))
def build_sample_db():
    """
    Populate a small db with some example entries.

    Drops and recreates every table, then inserts one user per
    (first name, last name) pair.
    """
    db.drop_all()
    db.create_all()

    first_names = [
        'Harry', 'Amelia', 'Oliver', 'Jack', 'Isabella', 'Charlie', 'Sophie', 'Mia',
        'Jacob', 'Thomas', 'Emily', 'Lily', 'Ava', 'Isla', 'Alfie', 'Olivia', 'Jessica',
        'Riley', 'William', 'James', 'Geoffrey', 'Lisa', 'Benjamin', 'Stacey', 'Lucy',
    ]
    last_names = [
        'Brown', 'Smith', 'Patel', 'Jones', 'Williams', 'Johnson', 'Taylor', 'Thomas',
        'Roberts', 'Khan', 'Lewis', 'Jackson', 'Clarke', 'James', 'Phillips', 'Wilson',
        'Ali', 'Mason', 'Mitchell', 'Rose', 'Davis', 'Davies', 'Rodriguez', 'Cox', 'Alexander',
    ]

    for first_name, last_name in zip(first_names, last_names):
        user = User()
        user.name = first_name + " " + last_name
        # "@" had been garbled to "#", which produced invalid addresses.
        user.email = first_name.lower() + "@example.com"
        db.session.add(user)

    db.session.commit()
if __name__ == '__main__':
    # Build a sample db on the fly, if one does not exist yet.
    app_dir = op.realpath(os.path.dirname(__file__))
    database_path = op.join(app_dir, app.config['DATABASE_FILE'])
    if not os.path.exists(database_path):
        build_sample_db()
    # Start app
    # NOTE(review): debug=True combined with host='0.0.0.0' exposes the
    # interactive debugger on every network interface - development use only.
    app.run(debug=True, host='0.0.0.0')
what about the following idea:
# Flask views
@app.route('/')
def index():
    """Landing page whose link opens the User admin list on its last page."""
    # Count in the database instead of loading every row just to call len().
    n = User.query.count()
    # Pages are zero-based; clamp so an empty table maps to page 0.
    last_page = max(n - 1, 0) // UserAdmin.page_size
    # The anchor markup was stripped from the extracted text, leaving
    # last_page unused. The path assumes the default Flask-Admin URL for the
    # User view - confirm.
    return f'<a href="/admin/user/?page={last_page}">Click me to get to Admin!</a>'
I've found an ideal solution to my problem, overriding the index_view() template-rendering view method in the flask_admin.contrib.sqla.ModelView class.
from math import ceil

from flask import request
from flask_admin import expose


class UserAdmin(ModelView):
    """User list view that opens on the LAST page when no URL arguments are given."""

    column_searchable_list = ('name', 'email')
    column_editable_list = ('name', 'email')
    column_display_pk = True
    can_set_page_size = True
    # the number of entries to display in the list view
    page_size = 5

    # Now to override the index_view method
    @expose('/')
    def index_view(self):
        """List view overridden to DEFAULT to the last page,
        if no other request args have been submitted.

        Apart from the marked custom section, the body mirrors the list view
        this method overrides, so paging, sorting, search and filters behave
        as before whenever any query argument is present.
        """
        if self.can_delete:
            delete_form = self.delete_form()
        else:
            delete_form = None

        # Grab parameters from URL
        view_args = self._get_list_extra_args()

        # Map column index to column name
        sort_column = self._get_column_by_idx(view_args.sort)
        if sort_column is not None:
            sort_column = sort_column[0]

        # Get page size
        page_size = view_args.page_size or self.page_size

        #####################################################################
        # Custom functionality to start on the last page instead of the first
        if not request.args:
            # Standard request for the first page (no additional args)
            count_query = self.get_count_query() if not self.simple_list_pager else None

            # Calculate number of rows if necessary
            count = count_query.scalar() if count_query is not None else None

            # Calculate number of pages
            if count is not None and page_size:
                num_pages = int(ceil(count / float(page_size)))
                # Jump to the last page (zero-based). Clamp at 0 so an empty
                # table does not request page -1.
                view_args.page = max(num_pages - 1, 0)
        # End of custom code. The rest below is from the Flask-Admin package
        #####################################################################

        # Get database data for the view_args.page requested
        count, data = self.get_list(view_args.page, sort_column, view_args.sort_desc,
                                    view_args.search, view_args.filters, page_size=page_size)

        list_forms = {}
        if self.column_editable_list:
            for row in data:
                list_forms[self.get_pk_value(row)] = self.list_form(obj=row)

        # Calculate number of pages
        if count is not None and page_size:
            num_pages = int(ceil(count / float(page_size)))
        elif not page_size:
            num_pages = 0  # hide pager for unlimited page_size
        else:
            num_pages = None  # use simple pager

        # Various URL generation helpers
        def pager_url(p):
            # Do not add page number if it is first page
            if p == 0:
                p = None
            return self._get_list_url(view_args.clone(page=p))

        def sort_url(column, invert=False, desc=None):
            if not desc and invert and not view_args.sort_desc:
                desc = 1
            return self._get_list_url(view_args.clone(sort=column, sort_desc=desc))

        def page_size_url(s):
            if not s:
                s = self.page_size
            return self._get_list_url(view_args.clone(page_size=s))

        # Actions
        actions, actions_confirmation = self.get_actions_list()
        if actions:
            action_form = self.action_form()
        else:
            action_form = None

        clear_search_url = self._get_list_url(view_args.clone(page=0,
                                                              sort=view_args.sort,
                                                              sort_desc=view_args.sort_desc,
                                                              search=None,
                                                              filters=None))

        return self.render(
            self.list_template,
            data=data,
            list_forms=list_forms,
            delete_form=delete_form,
            action_form=action_form,

            # List
            list_columns=self._list_columns,
            sortable_columns=self._sortable_columns,
            editable_columns=self.column_editable_list,
            list_row_actions=self.get_list_row_actions(),

            # Pagination
            count=count,
            pager_url=pager_url,
            num_pages=num_pages,
            can_set_page_size=self.can_set_page_size,
            page_size_url=page_size_url,
            page=view_args.page,
            page_size=page_size,
            default_page_size=self.page_size,

            # Sorting
            sort_column=view_args.sort,
            sort_desc=view_args.sort_desc,
            sort_url=sort_url,

            # Search
            search_supported=self._search_supported,
            clear_search_url=clear_search_url,
            search=view_args.search,
            search_placeholder=self.search_placeholder(),

            # Filters
            filters=self._filters,
            filter_groups=self._get_filter_groups(),
            active_filters=view_args.filters,
            filter_args=self._get_filters(view_args.filters),

            # Actions
            actions=actions,
            actions_confirmation=actions_confirmation,

            # Misc
            enumerate=enumerate,
            get_pk_value=self.get_pk_value,
            get_value=self.get_list_value,
            return_url=self._get_list_url(view_args),

            # Extras
            extra_args=view_args.extra_args,
        )

Create a relationship between two db records in a form in Python / Flask / SQLAlchemy

I am desperately trying to do the following:
Capture details of 2 individuals in a Flaskform (i.e. a father and a mother). Each individual is an instance of an Individual class / model and stored in an SQLAlchemy table
When the second individual is added (could be the father or the mother), then a relationship is created with the ids of the two Individuals, in a Parents table. This creates a relationship record between the father and the mother. I'm happy that's all set up OK.
Main app code is below. I want to make sure I can grab both id's at the same time, to create the relationship (hence the print statements).
Yet the first time I add an individual, two print statements output:
New father ID is 1
New mother ID is None # this is probably expected as I haven't added the second Individual yet
Second time I add an individual, I get:
New father ID is None
New mother ID is 2 # What has happened to the first Individual's ID?
I have tried declaring global variables to write the ids back to but get errors about using a variable before they are declared.
Can anyone help?
from genealogy import app
from flask import render_template, session, redirect, url_for, request
from genealogy import db
from genealogy.models import Individual, Parents
from genealogy.individual.forms import AddIndividual
# NOTE(review): the line below reads as a comment; the decorator marker "@"
# appears to have been lost when this snippet was extracted.
#app.route("/", methods=["GET", "POST"])
def index():
    # Question version of the view, kept as posted (indentation reconstructed).
    # It shows the reported problem: each id is copied into the Flask session
    # before db.session.commit(), at a point where the snippet has neither
    # flushed nor committed the new row.
    form = AddIndividual()
    def fullname(first, last):
        return first + " " + last
    if request.method == "POST":
        new_father = Individual("",None,None)
        new_mother = Individual("",None,None)
        new_child = Individual("",None,None)
        new_partners = Parents(None,None)
        if request.form.get("addfather") == "Add":
            father_forenames = form.father_forenames.data
            father_surname = form.father_surname.data
            father_fullname = fullname(father_forenames, father_surname)
            new_father = Individual(father_surname, father_fullname, father_forenames)
            db.session.add(new_father)
            # id is read here, before the commit on the next line
            session["new_father.id"] = new_father.id
            db.session.commit()
            # db.session.flush()
            # if Parents.query.filter_by(father_id=new_father.id, mother_id=new_mother.id):
            # pass
            # else:
            # new_partners = Parents(new_father.id, new_mother.id)
            # db.session.add(new_partners)
            # db.session.commit()
        if request.form.get("addmother") == "Add":
            mother_forenames = form.mother_forenames.data
            mother_surname = form.mother_surname.data
            mother_fullname = fullname(mother_forenames, mother_surname)
            new_mother = Individual(mother_surname, mother_fullname, mother_forenames)
            db.session.add(new_mother)
            # same ordering issue as for the father above
            session["new_mother.id"] = new_mother.id
            db.session.commit()
            # db.session.flush()
            # if Parents.query.filter_by(father_id=focus_father.id, mother_id=focus_mother.id):
            # pass
            # else:
            # new_partners = Parents(focus_father.id, focus_mother.id)
            # db.session.add(new_partners)
            # db.session.commit()
        if request.form.get("addchild") == "Add":
            child_forenames = form.child_forenames.data
            child_surname = form.child_surname.data
            child_fullname = fullname(child_forenames, child_surname)
            new_child = Individual(child_surname, child_fullname, child_forenames)
            db.session.add(new_child)
            focus_person = new_child
            db.session.commit()
        print("New father ID is " + str(session["new_father.id"]))
        print("New mother ID is " + str(session["new_mother.id"]))
        return render_template("home.html", form=form)
        # return render_template("home.html", form=form, focus_father=focus_father, focus_mother=focus_mother,
        # focus_person=focus_person, focus_partners=focus_partners)
    return render_template("home.html", form=form)
if __name__ == "__main__":
    app.run(debug=True)
Thanks to @van, here's the working code:
from genealogy import app
from flask import render_template, session, redirect, url_for, request
from genealogy import db
from genealogy.models import Individual, Parents
from genealogy.individual.forms import AddIndividual
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the home form and, on POST, store the submitted individual.

    The ids of the most recently added father and mother are kept in the
    Flask session so that both are available across separate form
    submissions.
    """
    form = AddIndividual()

    def fullname(first, last):
        return first + " " + last

    def add_individual(forenames, surname):
        """Persist a new Individual and return it once its row is committed."""
        person = Individual(surname, fullname(forenames, surname), forenames)
        db.session.add(person)
        # Commit before the caller reads person.id. The extra flush() that
        # followed each commit had nothing left to send and is dropped.
        db.session.commit()
        return person

    if request.method == "POST":
        if request.form.get("addfather") == "Add":
            new_father = add_individual(form.father_forenames.data, form.father_surname.data)
            session["new_father.id"] = new_father.id

        if request.form.get("addmother") == "Add":
            new_mother = add_individual(form.mother_forenames.data, form.mother_surname.data)
            session["new_mother.id"] = new_mother.id

        if request.form.get("addchild") == "Add":
            add_individual(form.child_forenames.data, form.child_surname.data)

        # .get(): on the first submission only one of the two keys exists yet,
        # and indexing the session directly would raise KeyError.
        print("New father ID is " + str(session.get("new_father.id")))
        print("New mother ID is " + str(session.get("new_mother.id")))
        return render_template("home.html", form=form)

    return render_template("home.html", form=form)


if __name__ == "__main__":
    app.run(debug=True)

Flask-SQLALchemy update record automatically after specific time

I have a db models like this:
class Payment(db.Model):
    """A ticket payment with its usage status and departure date."""

    id = db.Column(db.Integer(), primary_key=True)
    user_id = db.Column(db.Integer(), db.ForeignKey('user.id'))
    # `default` is a Column option; passed to db.Enum(...) it never became
    # the column's default value.
    ticket_status = db.Column(db.Enum(TicketStatus, name='ticket_status'), default=TicketStatus.UNUSED)
    departure_time = db.Column(db.Date)
I want to change the value of ticket_status for every row once datetime.utcnow() has passed the date stored in departure_time.
I tried to code like this:
class TicketStatus(enum.Enum):
    """Lifecycle states of a ticket; every member's value equals its name."""

    UNUSED = 'UNUSED'
    USED = 'USED'
    EXPIRED = 'EXPIRED'

    def __repr__(self):
        # Show the bare value (e.g. UNUSED) rather than the default
        # <TicketStatus.UNUSED: 'UNUSED'> form.
        return '{}'.format(self.value)
class Payment(db.Model):
    # First attempt from the question, kept as posted (indentation reconstructed).
    id = db.Column(db.Integer(), primary_key=True)
    user_id = db.Column(db.Integer(), db.ForeignKey('user.id'))
    # NOTE(review): `default` is handed to db.Enum here, not to db.Column - confirm
    # whether the column default is actually applied.
    ticket_status = db.Column(db.Enum(TicketStatus, name='ticket_status', default=TicketStatus.UNUSED))
    departure_time = db.Column(db.Date)
    # TODO | set ticket expirations time
    # NOTE(review): __init__ only runs when a Payment object is constructed in
    # Python, so it cannot re-evaluate the status later as time passes. It also
    # compares a datetime with the Date column's value - verify.
    def __init__(self):
        if datetime.utcnow() > self.departure_time:
            self.ticket_status = TicketStatus.EXPIRED.value
        # NOTE(review): the original nesting of this try block is not recoverable
        # from the flattened text; placed at method level - confirm.
        try:
            db.session.add(self)
            db.session.commit()
        except Exception as e:
            db.session.rollback()
I also tried like this:
def ticket_expiration(self, payment_id):
    """Refresh the ticket status of one payment against today's UTC date.

    A USED ticket is left untouched. Otherwise the ticket becomes EXPIRED
    when its departure date is before today and UNUSED when it is today,
    later, or not set. Unknown payment ids are ignored.

    Returns:
        The string 'ok'.
    """
    payment = Payment.query.filter_by(id=payment_id).first()
    # Check that the row exists before touching its attributes; the original
    # dereferenced payment.ticket_status first.
    if payment is not None:
        status = payment.ticket_status
        if status is None or status.value != TicketStatus.USED.value:
            # Compare date objects directly instead of formatted strings.
            today = datetime.utcnow().date()
            departure = payment.departure_time
            expired = departure is not None and departure < today
            new_status = TicketStatus.EXPIRED if expired else TicketStatus.UNUSED
            payment.ticket_status = new_status.value
    try:
        db.session.commit()
    except Exception:
        db.session.rollback()
    return 'ok'
But this seems to have no effect once datetime.utcnow() has passed the date stored in departure_time.
So the point of my question is: how can I change the value of a row automatically after a set amount of time?
I finally figured this out by using flask_apscheduler; here is the snippet of my code that solved the problem:
Install flask_apscheduler:
pip3 install flask_apscheduler
create new module tasks.py
from datetime import datetime
from flask_apscheduler import APScheduler
from app import db
from app.models import Payment, TicketStatus
scheduler = APScheduler()


def ticket_expiration():
    """Scheduled job: refresh the status of every payment that is not USED.

    A ticket becomes EXPIRED when its departure date is before today (UTC)
    and UNUSED when it is today, later, or not set.

    Returns:
        The string 'ok'.
    """
    today = datetime.utcnow().date()
    app = scheduler.app
    # The job runs outside a request, so an application context is needed
    # for the database session.
    with app.app_context():
        for payment in Payment.query.all():
            status = payment.ticket_status
            # Handle a missing status explicitly rather than relying on a
            # broad per-row `except Exception` to hide the AttributeError.
            if status is not None and status.value == TicketStatus.USED.value:
                continue
            departure = payment.departure_time
            expired = departure is not None and departure < today
            new_status = TicketStatus.EXPIRED if expired else TicketStatus.UNUSED
            payment.ticket_status = new_status.value
        try:
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            # Report the failure; it used to be swallowed silently.
            print(str(e))
    return 'ok'
and then register the package with the flask app in the __init__.py
def create_app(config_class=Config):
    """Application factory: build the Flask app and start the scheduler.

    Args:
        config_class: configuration object to load; defaults to Config.
    """
    app = Flask(__name__)
    # Load the configuration that was passed in. The original always loaded
    # Config, which silently ignored the config_class argument.
    app.config.from_object(config_class)
    # The other packages...
    scheduler.init_app(app)
    scheduler.start()
    return app
# import from other_module...
# To avoid SQLAlchemy circular import, do the import at the bottom.
from app.tasks import scheduler
And here is for the config.py:
class Config(object):
    # The others config...
    # The others config...
    # Flask-apscheduler
    # One interval job that runs app.tasks:ticket_expiration.
    JOBS = [
        {
            'id': 'ticket_expiration',
            'func': 'app.tasks:ticket_expiration',
            'trigger': 'interval',
            'hours': 1, # call the task function every 1 hours
            'replace_existing': True
        }
    ]
    # NOTE(review): SQLAlchemyJobStore is not imported anywhere in this snippet;
    # it presumably comes from apscheduler.jobstores.sqlalchemy - confirm.
    SCHEDULER_JOBSTORES = {
        'default': SQLAlchemyJobStore(url='sqlite:///flask_context.db')
    }
    SCHEDULER_API_ENABLED = True
With the config above, the function that updates the database is called every hour; the interval can be set in hours, seconds, or other units to suit your case. See the APScheduler interval-trigger documentation for the available options.
I hope this answer helps anyone who faces this problem in the future.
You may replace your status column with just "used" column which will contain Boolean value and make a hybrid attribute for state. https://docs.sqlalchemy.org/en/13/orm/extensions/hybrid.html
class Payment(db.Model):
    """Payment whose status is derived from `used` and `departure_time`."""

    id = db.Column(db.Integer(), primary_key=True)
    user_id = db.Column(db.Integer(), db.ForeignKey('user.id'))
    used = db.Column(db.Boolean(), default=False)
    departure_time = db.Column(db.Date)

    @hybrid_property
    def status(self):
        """Return "USED", "EXPIRED" or "UNUSED" for this payment.

        A used ticket stays "USED" even after departure, matching the
        question's rule that used tickets are never re-labelled.
        """
        if self.used:
            return "USED"
        departure = self.departure_time
        # departure_time is a Date column: compare it with today's date, not
        # with a datetime (mixing the two raises TypeError), and tolerate NULL.
        if departure is not None and departure < datetime.utcnow().date():
            return "EXPIRED"
        return "UNUSED"

TypeError: Model constructor takes no positional arguments

While creating an answer for a question in the API (see code below), I got the following error:
TypeError: Model constructor takes no positional arguments
Can someone tell me how to solve this? I am using ndb model
import webapp2
import json
from google.appengine.ext import ndb
class AnswerExchange(ndb.Model):
    # Answer text: unindexed, defaults to "No Message".
    answer=ndb.StringProperty(indexed=False,default="No Message")
class AnswerHandler(webapp2.RequestHandler):
    # Question code, kept as posted (indentation reconstructed).
    def create_answer(self,question_id):
        # Builds an AnswerExchange from the request body's "answer" field.
        try:
            query = StackExchange.query(StackExchange.questionId == question_id)
            # NOTE(review): fetch(1) presumably returns a list, yet `questions`
            # is used below as if it were a single entity - verify.
            questions = query.fetch(1)
            ans_json = json.loads(self.request.body)
            answerObj = AnswerExchange(answer=ans_json["answer"])
            answerObj.put()
            questions.answerName=answerObj
            questions.put()
        # The bare except turns every failure into a 400 response.
        except:
            raise webapp2.exc.HTTPBadRequest()
class StackExchange(ndb.Model):
    # A question with its repeated structured answers.
    questionId=ndb.StringProperty(indexed=True)
    questionName=ndb.StringProperty(indexed=False)
    #answerID=ndb.StringProperty(indexed=True)
    # Declared repeated=True, i.e. this property holds multiple AnswerExchange values.
    answerName=ndb.StructuredProperty(AnswerExchange,repeated=True)
class StackHandler(webapp2.RequestHandler):
    # Question code, kept as posted (indentation reconstructed).
    def get_questions(self):
        # Writes up to six questions as a JSON list.
        #p = self.request.get('p') #get QueryString Parameter
        #self.response.write(p)
        query = StackExchange.query()
        questions = query.fetch(6)
        self.response.content_type = 'application/json'
        self.response.write(json.dumps([d.to_dict() for d in questions]))
    def get_question(self, question_id):
        # Writes the question matching question_id as a JSON list (at most one item).
        query = StackExchange.query(StackExchange.questionId == question_id)
        questions = query.fetch(1)
        self.response.content_type = 'application/json'
        self.response.write(json.dumps([d.to_dict() for d in questions]))
    def create_question(self):
        # Creates the question, or renames it when the questionId already exists.
        try:
            question_json = json.loads(self.request.body)
            # if id in dept_json:
            questionNo = question_json["questionId"]
            questionToUpdate = StackExchange.query(StackExchange.questionId == questionNo);
            #print deptToUpdate.count()
            if questionToUpdate.count() == 0:
                question = StackExchange(questionId=question_json["questionId"], questionName=question_json["questionName"])
            else:
                question = questionToUpdate.get()
                question.questionName = question_json["questionName"]
            #key = dept.put()
            question.put()
        # The bare except turns every failure into a 400 response.
        except:
            raise webapp2.exc.HTTPBadRequest()
        self.response.headers['Location'] = webapp2.uri_for('get_question', question_id=questionNo)
        self.response.status_int = 201
        self.response.content_type = 'application/json'
        #self.response.write(json.dumps())
    def create_answer(self,question_id):
        # Handler that raises the TypeError discussed in this question.
        #try:
        question_json = json.loads(self.request.body)
        questionNo = question_json["questionId"]
        query = StackExchange.query(StackExchange.questionId == questionNo)
        questions = query.fetch(1)
        ans_json = json.loads(self.request.body)
        # The value is passed positionally here; the reported error says the
        # model constructor takes no positional arguments.
        answerObj = AnswerExchange(ans_json["answer"])
        #answerObj.put()
        #self.response.write(ans_json["answerName"])
        # NOTE(review): `questions` comes from fetch(1), presumably a list, and
        # answerName is declared repeated=True - assigning one object and
        # calling put() on `questions` looks wrong. Verify.
        questions.answerName = answerObj
        questions.put();
        #except:
        # raise webapp2.exc.HTTPBadRequest()
    def get_answer(self, question_id):
        # Writes the question matching question_id as a JSON list.
        query = StackExchange.query(StackExchange.questionId == question_id)
        questions = query.fetch(1)
        self.response.content_type = 'application/json'
        self.response.write(json.dumps([d.to_dict() for d in questions]))
# Route table: every route dispatches to a method of api.StackHandler.
app = webapp2.WSGIApplication([
    webapp2.Route(r'/api/v1/questions/<question_id>', methods=['GET'], handler='api.StackHandler:get_question', name='get_question'),
    webapp2.Route(r'/api/v1/questions', methods=['GET'], handler='api.StackHandler:get_questions', name='get_questions'),
    webapp2.Route(r'/api/v1/questions', methods=['POST'], handler='api.StackHandler:create_question', name='create_question'),
    webapp2.Route(r'/api/v1/questions/<question_id>/answers', methods=['POST'], handler='api.StackHandler:create_answer', name='create_answer'),
    webapp2.Route(r'/api/v1/questions/<question_id>/answers', methods=['GET'], handler='api.StackHandler:get_answer', name='get_answer')
], debug=True)
Change,
answerObj = AnswerExchange(ans_json["answer"])
in create_answer method StackHandler class
to
answerObj = AnswerExchange(answer=ans_json["answer"])

Python: issues understanding magicmock with unittests

Here is my class:
class WorkflowsCloudant(cloudant.Account):
    """Cloudant account wrapper scoped to one account id."""

    def __init__(self, account_id):
        super(WorkflowsCloudant, self).__init__(settings.COUCH_DB_ACCOUNT_NAME,
                                                auth=(settings.COUCH_PUBLIC_KEY, settings.COUCH_PRIVATE_KEY))
        self.db = self.database(settings.COUCH_DB_NAME)
        self.account_id = account_id

    def get_by_id(self, key, design='by_workflow_id', view='by_workflow_id', limit=None):
        """Query a view by key.

        With ``limit=1`` the single matching document is returned after
        checking that it belongs to this account; with any other limit the
        raw view result is returned.

        Raises:
            NotFound: ``limit=1`` and the view returned no rows.
            InvalidAccount: ``limit=1`` and the document's account_id differs
                from this instance's account_id.
        """
        params = dict(key=key, include_docs=True, limit=limit)
        docs = self.db.design(design).view(view, params=params)

        # Compare by value: `limit is 1` tests object identity, which only
        # works by accident of integer caching and warns on Python 3.8+.
        if limit != 1:
            return docs

        matches = [row['doc'] for row in docs]
        if not matches:
            raise NotFound("Autoresponder Cannot Be Found")

        workflow = matches[0]
        if workflow.get("account_id") != self.account_id:
            raise InvalidAccount("Invalid Account")
        return workflow
Here is my test:
def test_get_by_id_single_invalid_account(self):
    # Question version of the test, kept as posted (indentation reconstructed).
    self.klass.account_id = 200
    # Stand-in for the database handle; design(...).view(...) yields one row.
    self.klass.db = mock.MagicMock()
    self.klass.db.design.return_value.view.return_value = [{
        'doc': test_workflow()
    }]
    # wc.get_by_id = mock.MagicMock(side_effect=InvalidAccount("Invalid Account"))
    # NOTE(review): assertRaises is given an exception instance here; it
    # expects the exception class - verify.
    with self.assertRaises(InvalidAccount("Invalid Account")) as context:
        self.klass.get_by_id('workflow_id', limit=1)
    # NOTE(review): original nesting unknown; placed after the with-block - confirm.
    self.assertEqual('Invalid Account', str(context.exception))
I'm trying to get the above test to simply raise the InvalidAccount exception, but I'm unsure how to mock out the self.db.design.view response. That's what's causing my test to fail, because it's trying to make a real call out.
I think this is what you want.
def test_get_by_id_single_invalid_account(self):
    """get_by_id(limit=1) raises InvalidAccount for a document owned by another account."""
    self.klass.account_id = 200

    # Replace the database handle so design(...).view(...) returns canned
    # rows. MagicMock creates the intermediate `design` attribute on demand,
    # so it needs no explicit mock of its own.
    self.klass.db = mock.MagicMock()
    view_mock = mock.MagicMock()
    view_mock.return_value = [{
        'doc': test_workflow()
    }]
    self.klass.db.design.return_value.view = view_mock

    # assertRaises takes the exception class, not an instance.
    with self.assertRaises(InvalidAccount) as context:
        self.klass.get_by_id('workflow_id', limit=1)

    # Checked after the with-block: a statement placed after the raising call
    # inside the block would never run.
    self.assertEqual('Invalid Account', str(context.exception))

Categories