Problem:
I am trying to make a backend web framework as a challenge, and I'm trying to implement WebSockets, but whenever I make the handshake my browser says that it received no response and the status is 'finished'.
Things I have tried:
I have tried making the request to my own server using the requests module with the same Sec-WebSocket-Key as on another website, and the accept key my server returned was the same as in that other website's response.
I have also tried running it on a Node.js server with the same code for the client, and the whole thing worked as expected, so it's a server-side problem.
Changing browsers and updating browser versions.
Changing from localhost to my private IP and my local machine name.
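For anyone wanting to repeat the key check above without a second website, here is a minimal sketch of the accept-key calculation using only the standard library. The function name compute_accept is made up for illustration; the GUID is the same constant used in the code below, and the sample key is the one from the handshake example in RFC 6455.

```python
import base64
import hashlib

# GUID fixed by RFC 6455; it is the same constant appended to the key in ws().
WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def compute_accept(sec_websocket_key: str) -> str:
    """Return the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key."""
    digest = hashlib.sha1((sec_websocket_key + WEBSOCKET_GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")


if __name__ == "__main__":
    # Sample key taken from the RFC 6455 handshake example.
    print(compute_accept("dGhlIHNhbXBsZSBub25jZQ=="))  # s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```

If this prints the same value the server sends for the same key, the hashing step is probably not where the handshake breaks.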
code
# Ignore missing stuff, some stuff is missing to reduce the length and shouldn't be needed to debug
@app.route('/')
def ws(request: dict):
    if request.get('Connection') == 'Upgrade' and request.get('Upgrade') == 'websocket':
        # Get returns None if key doesn't exist [] raises keyerror
        print(f"Sec-Websocket-Key: {request['Sec-WebSocket-Key']}")
        # Prints correct key (:
        request['Sec-WebSocket-Key'] += '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
        SecWebSocketAccept = hashlib.sha1(request['Sec-WebSocket-Key'].encode())
        SecWebSocketAccept = SecWebSocketAccept.digest()
        SecWebSocketAccept = base64.b64encode(SecWebSocketAccept).decode()
        return HttpResponse(status_code=101,
                            message='Switching Protocols',
                            headers={
                                'Upgrade': 'websocket',
                                'Connection': 'Upgrade',
                                'Sec-WebSocket-Accept': SecWebSocketAccept,
                                'Sec-WebSocket-Protocol': 'chat'
                            }, AutoGenerateExtraHeaders=False)
    else:
        return RenderTemplate('example.html', minimize=False)
        # warning: remove minimize only for testing
# HttpResponse class
class HttpResponse:
    def __init__(
        self,
        data: str = '',
        status_code: int = 200,
        message: str = 'OK',
        headers: dict = {},
        AutoGenerateExtraHeaders=True
    ) -> None:
        """ todo: add docs on class and methods"""
        self.protocol = 'HTTP/1.1'
        self.status = f'{status_code} {message}'
        self.data = data
        self.headers = headers
        if AutoGenerateExtraHeaders:
            contentmd5 = hashlib.md5(data.encode()).digest()
            self.headers['Content-MD5'] = base64.b64encode(contentmd5)
            self.headers['Content-Length'] = len(data.encode())
            # todo: add more headers such as date,
        self.headers = str(headers)[2:].replace('{', '').replace('}', '').replace(', \'', '\r\n\'').replace('\'', '')
        # String version of headers

    def encode(self) -> bytes:
        return str(self).encode()
# Handle Connection Code
def handleConnection(self, connection: socket.socket, address) -> None:
    """handles incoming requests
    Args:
        connection (socket.socket): client socket connection
        address: address of incoming connection
    """
    request = connection.recv(1024).decode()
    # print(request)
    if not request: return
    request: dict = self.dictFromRequest(request)
    route = self.app.getRoute(request['route'])  # returns route object which is made up of a function and a URL defined by app.route(route: str) decor
    method: str = request['method']
    protocol: str = request['protocol']
    # 404 Page not found
    # Call event
    event = self.app.callEvent(RequestEvent(request))
    if event.timeOut: return
    # Handle
    if protocol == 'HTTP':
        # todo: cors and csrf
        if method == 'GET':
            response = route.call(request)
            print(f'encoded: {response.encode()}, \n\ndecoded: {response}')
            # output:
            # encoded: b'HTTP/1.1 101 Switching Protocols\r\nUpgrade:
            # websocket\r\nConnection:
            # Upgrade\r\nSec-WebSocket-Accept: TwmkkaET4SyBJad/5OzZNHxaZ/o=\r\nSec
            # -WebSocket-Protocol:
            # chat\r\n\r\n',
            #
            # decoded: HTTP/1.1 101 Switching Protocols
            # Upgrade: websocket
            # Connection: Upgrade
            # Sec-WebSocket-Accept: TwmkkaET4SyBJad/5OzZNHxaZ/o=
            # Sec-WebSocket-Protocol: chat
            #
            #
            connection.send(response.encode())
    return connection.close()

def dictFromRequest(self, headers: str) -> dict:
    # todo: just lol
    dict = {}
    dict['method'] = headers.split('\r\n\r\n')[0].split('\r\n')[0].split(' ')[0]
    dict['protocol'] = headers.split('\r\n')[0].split(' ')[-1].split('/')[0]
    dict['route'] = headers.split('\r\n\r\n')[0].split('\r\n')[0].split(' ')[1].split('?')[0]
    dict['query'] = ''.join(headers.split('\n\n')[0].split('\r\n')[0].split(' ')[1].split('?')[1:])
    dict['data'] = headers.split('\r\n\r\n')[1]
    for header in headers.split('\r\n\r\n')[0].split('\r\n')[1:]:
        dict[header.split(': ')[0]] = header.split(': ')[1]
    return dict
P.S. pls don't judge I'm only 13 lol
I'm building a REST API for a simple Todo application using Flask, with SQLAlchemy as my ORM. I am testing my API using Postman. I'm on a Windows 10 64-bit machine.
A GET request works and returns the data that I've entered into my database using python.
I'd like to try to add a task now. But when I POST my request, I receive an error.
My route in flask looks like this.
#add task
@app.route('/todo/api/v1.0/tasks', methods=['POST'])
def create_task():
    if not request.json or not 'title' in request.json:
        raise InvalidUsage('Not a valid task!', status_code=400)
    task = {
        'title': request.json['title'],
        'description': request.json['description'],
        'done': False
    }
    Todo.add_todo(task)
    return jsonify({'task': task}), 201
And the method it's calling on the Todo object looks like this.
def add_todo(_title, _description):
    new_todo = Todo(title=_title, description=_description , completed = 0)
    db.session.add(new_todo)
    db.session.commit()
What I've tried
I thought that maybe the ' characters in my Postman params were causing an issue, so I removed them. But I still get the same error.
Then I thought that maybe the way Postman was sending the POST was incorrect, so I checked that the Content-Type header was correct. It is set to application/json.
Finally, to confirm that the issue was that flask didn't like the request, I removed the check in the add task route to make sure the request had a title. So it looks like this.
if not request.json:
And I get the same error. So I think that the problem must be with how I'm actually sending the POST rather than some kind of formatting issue.
My entire code looks like this.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
import json
from flask import jsonify
from flask import request

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///todo.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
db = SQLAlchemy(app)

class Todo(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(300), unique=False, nullable=False)
    description = db.Column(db.String(), unique=False, nullable=False)
    completed = db.Column(db.Boolean, nullable=False)

    def json(self):
        return {'id': self.id,'title': self.title, 'description': self.description, 'completed': self.completed}

    def add_todo(_title, _description):
        new_todo = Todo(title=_title, description=_description , completed = 0)
        db.session.add(new_todo)
        db.session.commit()

    def get_all_tasks():
        return [Todo.json(todo) for todo in Todo.query.all()]

    def get_task(_id):
        task = Todo.query.filter_by(id=_id).first()
        if task is not None:
            return Todo.json(task)
        else:
            raise InvalidUsage('No task found', status_code=400)

    def __repr__(self):
        return f"Todo('{self.title}')"
class InvalidUsage(Exception):
    status_code = 400

    def __init__(self, message, status_code=None, payload=None):
        Exception.__init__(self)
        self.message = message
        if status_code is not None:
            self.status_code = status_code
        self.payload = payload

    def to_dict(self):
        rv = dict(self.payload or ())
        rv['message'] = self.message
        return rv
@app.route('/')
def hello_world():
    return 'Hello to the World of Flask!'

#get all tasks
@app.route('/todo/api/v1.0/tasks', methods=['GET'])
def get_tasks():
    return_value = Todo.get_all_tasks()
    return jsonify({'tasks': return_value})

#get specific task
@app.route('/todo/api/v1.0/tasks/<int:task_id>', methods=['GET'])
def get_task(task_id):
    task = Todo.get_task(task_id)
    #if len(task) == 0:
    #    raise InvalidUsage('No such task', status_code=404)
    return jsonify({'task': task})

#add task
@app.route('/todo/api/v1.0/tasks', methods=['POST'])
def create_task():
    if not request.json or not 'title' in request.json:
        raise InvalidUsage('Not a valid task!', status_code=400)
    task = {
        'title': request.json['title'],
        'description': request.json['description'],
        'done': False
    }
    Todo.add_todo(task)
    return jsonify({'task': task}), 201

#update task
@app.route('/todo/api/v1.0/tasks/<int:task_id>', methods=['PUT'])
def update_task(task_id):
    task = [task for task in tasks if task['id'] == task_id]
    if len(task) == 0:
        raise InvalidUsage('No provided updated', status_code=400)
    if not request.json:
        raise InvalidUsage('request not valid json', status_code=400)
    if 'title' in request.json and type(request.json['title']) != unicode:
        raise InvalidUsage('title not unicode', status_code=400)
    if 'description' in request.json and type(request.json['description']) != unicode:
        raise InvalidUsage('description not unicode', status_code=400)
    if 'done' in request.json and type(request.json['done']) is not bool:
        raise InvalidUsage('done not boolean', status_code=400)
    task[0]['title'] = request.json.get('title', task[0]['title'])
    task[0]['description'] = request.json.get('description', task[0]['description'])
    task[0]['done'] = request.json.get('done', task[0]['done'])
    return jsonify({'task': task[0]})

#delete task
@app.route('/todo/api/v1.0/tasks/<int:task_id>', methods=['DELETE'])
def delete_task(task_id):
    task = [task for task in tasks if task['id'] == task_id]
    if len(task) == 0:
        raise InvalidUsage('No task to delete', status_code=400)
    tasks.remove(task[0])
    return jsonify({'result': True})

@app.errorhandler(InvalidUsage)
def handle_invalid_usage(error):
    response = jsonify(error.to_dict())
    response.status_code = error.status_code
    return response

if __name__ == '__main__':
    app.run(debug=True)
EDIT:
Turns out I wasn't setting the request type in Postman correctly. I've updated the Content-Type header to 'application/json'. Now I'm receiving a different error.
Bad Request Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)
I've tried all the previous steps as before but I continue to get this error.
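For reference, the tail of that message is exactly what Python's json module reports when it is asked to decode an empty string. That suggests (an assumption, since the Postman screenshots aren't available here) that the request reached Flask with an empty body. A minimal sketch; decode_body is a made-up helper name:

```python
import json


def decode_body(raw_body: str):
    """Return (parsed, error_message) for a raw request body."""
    try:
        return json.loads(raw_body), None
    except json.JSONDecodeError as exc:
        return None, str(exc)


if __name__ == "__main__":
    print(decode_body(""))                             # empty body: decoding fails
    print(decode_body('{"title": "Learn more flask"}'))  # JSON body: decoding works
```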
EDIT 2:
Per a response below, I tried putting the values into the body of the POST. But I still get back a 400 response.
From the image [second postman screenshot] it looks like you pass the data in the query string, but create_task() expects it in the request body.
Either replace all occurrences of request.json with request.args in create_task() (to make it work with query params), or leave it as it is and send the data in the request body.
curl -X POST http://localhost:5000/todo/api/v1.0/tasks \
-H "Content-Type: application/json" \
-d '{"title":"Learn more flask","description":"its supper fun"}'
Also, take a look at Get the data received in a Flask request.
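To make the body versus query string difference concrete, here is a small sketch using only the standard library. It builds the two kinds of request without sending them; the URL and payload are the ones from the curl command, and the variable names are just for illustration.

```python
import json
from urllib.parse import urlencode
from urllib.request import Request

URL = "http://localhost:5000/todo/api/v1.0/tasks"
task = {"title": "Learn more flask", "description": "its supper fun"}

# Data in the body: this is the part of the request that request.json decodes.
body_request = Request(
    URL,
    data=json.dumps(task).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

# Data in the query string: the body stays empty, so there is no JSON to decode.
query_request = Request(URL + "?" + urlencode(task), method="POST")

if __name__ == "__main__":
    print(body_request.data)
    print(query_request.full_url, query_request.data)
```

The second request is what Postman sends when the values are entered under Params instead of Body.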
EDITED
Update your add_todo to something like
@classmethod
def add_todo(cls, task):
    new_todo = cls(title=task["title"], description=task["description"], completed=0)
    db.session.add(new_todo)
    db.session.commit()
Related: generalised insert into sqlalchemy using dictionary.
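A rough sketch of the dictionary-to-object idea, without a database. The Todo dataclass here is a hypothetical stand-in for the SQLAlchemy model, and from_task is a made-up name; the point is only to show a classmethod building an instance from the dict that the route creates.

```python
from dataclasses import dataclass


@dataclass
class Todo:
    """Hypothetical stand-in for the SQLAlchemy model; no database involved."""
    title: str
    description: str
    completed: bool = False

    @classmethod
    def from_task(cls, task: dict) -> "Todo":
        # Keep only the fields the model knows about, so extra keys such as
        # 'done' in the route's dict do not raise a TypeError.
        allowed = {"title", "description", "completed"}
        return cls(**{key: value for key, value in task.items() if key in allowed})


if __name__ == "__main__":
    task = {"title": "Learn more flask", "description": "its supper fun", "done": False}
    print(Todo.from_task(task))
```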
I am trying to pass my session object from one class to another, but I am not sure what's happening.
class CreateSession:
    def __init__(self, user, pwd, url="http://url_to_hit"):
        self.url = url
        self.user = user
        self.pwd = pwd

    def get_session(self):
        sess = requests.Session()
        r = sess.get(self.url + "/", auth=(self.user, self.pwd))
        print(r.content)
        return sess

class TestGet(CreateSession):
    def get_response(self):
        s = self.get_session()
        print(s)
        data = s.get(self.url + '/some-get')
        print(data.status_code)
        print(data)

if __name__ == "__main__":
    TestGet(user='user', pwd='pwd').get_response()
I am getting 401 for get_response(). Not able to understand this.
What's a 401?
The response you're getting means that you're unauthorised to access the resource.
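With HTTP Basic auth (which is what an auth=(user, pwd) tuple means in requests), a protected server will typically answer 401 to any request that arrives without an Authorization header. As a sketch of what that header looks like, using only the standard library and never actually sending anything; the helper names are made up and the URL is the placeholder from the question:

```python
import base64
from urllib.request import Request


def basic_auth_header(user: str, pwd: str) -> str:
    """Build the value of an HTTP Basic Authorization header."""
    token = base64.b64encode(f"{user}:{pwd}".encode("utf-8")).decode("ascii")
    return f"Basic {token}"


def authorised_request(url: str, user: str, pwd: str) -> Request:
    # The request is only constructed here, never sent.
    return Request(url, headers={"Authorization": basic_auth_header(user, pwd)})


if __name__ == "__main__":
    req = authorised_request("http://url_to_hit/some-get", "user", "pwd")
    print(req.get_header("Authorization"))
```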
A session is used to persist headers and other prerequisites across requests. Why are you creating the session every time rather than storing it in a variable?
As is, the session should work; the only issue is that you're trying to call a resource that you don't have access to. You're also not passing the url parameter in the initialisation.
Example of how you can effectively use Session:
from requests import Session
from requests.exceptions import HTTPError

class TestGet:
    __session = None
    __username = None
    __password = None

    def __init__(self, username, password):
        self.__username = username
        self.__password = password

    @property
    def session(self):
        # create the session once, then reuse it on every access
        if self.__session is None:
            self.__session = Session()
            # credentials set on the session are sent with every request made through it
            self.__session.auth = (self.__username, self.__password)
        return self.__session

    @session.setter
    def session(self, value):
        raise AttributeError('Setting \'session\' attribute is prohibited.')

    def get_response(self, url):
        try:
            response = self.session.get(url)
            # raises if the status code is an error - 4xx, 5xx
            response.raise_for_status()
            return response
        except HTTPError as e:
            # you received an http error .. handle it here (e contains the request and response)
            pass

test_get = TestGet('my_user', 'my_pass')
first_response = test_get.get_response('http://your-website-with-basic-auth.com')
second_response = test_get.get_response('http://another-url.com')
my_session = test_get.session
my_session.get('http://url.com')
Howdie do,
So I have this API that does the following:
1) It receives an XML request. It first parses that request to ensure the request is in the correct format.
2) If the XML is in the correct format, it checks that the correct user is authenticating.
3) If authentication is successful, it then retrieves a URI from a database that the API will send a GET request to.
4) If the response is successful, meaning it's an XML reply, it will use XSLT to transform the response into a different format.
5) It then adds the request to the database and returns the transformed XML to the user that queried the API.
I have error handling involved at each step, but the issue is that I've had to nest 5 if/else statements to accomplish this.
I know there has to be a better way to write this error handling logic without so many nested if statements, but I'm not sure how. Each step relies on the previous one, so that if any error occurs it's returned properly to the user.
Below is the main Flask API that I've created. The second file is a module that I've created which does a lot of the error processing. Those functions return a state (True/False) and the response to the main Flask API.
Can someone give me some ideas on how to rewrite this API without the nesting? The API works 100% and does what it should for catching errors, but I just know there's a better way.
Main API:
@dbConnect.app.route('/services/tracking/getShipmentStatus', methods=['POST'])
def parsexml2():
    parseStatus, returnValues = func.parseNWRequest(request.data, 'SS')
    if parseStatus:
        authStatus, authResponse = func.checkAuthorization(returnValues['bu'], request.headers['Authorization'], 'SS')
        if authStatus:
            getURIStatus, uriResponse = func.getURI('SS')
            if getURIStatus:
                search = {'bu': returnValues['bu'], 'starttime': returnValues['start'], 'endtime': returnValues['end'],
                          'requestid': returnValues['requestid'], 'pagesize': returnValues['page']}
                responseStatus, depascoResponse = func.sendDepascoRequest(uriResponse, search, 'SS')
                if responseStatus:
                    nakedResponse = func.transformXML(depascoResponse.content, 'transformTracking.xsl')
                    # Write Request to db
                    file_name = 'WS_SS_' + returnValues['bu']
                    request_file_size = request.headers['Content-Length']
                    if func.addToDb(file_name, 'text/xml', 'SS.Request', returnValues['bu'], 'Y', request.data,
                                    request_file_size) or func.addToDb(file_name, 'text/xml', 'SS.Result',
                                                                       returnValues['bu'], 'Y', nakedResponse):
                        pass
                    return Response(nakedResponse, mimetype='text/xml')
                else:
                    return Response(depascoResponse, mimetype='text/xml')
            else:
                return Response(uriResponse, mimetype='text/xml')
        else:
            return Response(authResponse, mimetype='text/xml')
    else:
        return Response(returnValues, mimetype='text/xml')
Imported module functions:
def transformXML(response, xsl):
    xml = ET.tostring(ET.fromstring(response))
    xslt = ET.XSLT(ET.parse(xsl))
    transformedXML = xslt(ET.XML(xml))
    return ET.tostring(transformedXML, pretty_print=True)

def addToDb(filename, mime, docType, customer_code, activeState, document_blob, filesize=None):
    try:
        response = dbConnect.Documents(file_name=filename, mime_type=mime, file_size=filesize, doc_type=docType,
                                       customer_code=customer_code, is_active=activeState, document_blob=document_blob)
        dbConnect.db.session.add(response)
        dbConnect.db.session.commit()
        dbConnect.db.session.close()
    except exc.SQLAlchemyError:
        return False
    else:
        return True

def generateXMLErrorResponse(errorMessage, api):
    E = ElementMaker()
    if api == 'SS':
        GETSHIPMENTSTATUSRESPONSE = E.getShipmentStatusResponse
        GETSHIPMENTSTATUSRESULT = E.getShipmentStatusResult
        OUTCOME = E.outcome
        RESULT = E.result
        ERROR = E.error
        xml_error = GETSHIPMENTSTATUSRESPONSE(
            GETSHIPMENTSTATUSRESULT(
                OUTCOME(
                    RESULT('Failure'),
                    ERROR(errorMessage)
                )
            )
        )
        return ET.tostring(xml_error, pretty_print=True)
    elif api == 'IS':
        GETINVENTORYSTATUSRESPONSE = E.getInventoryStatusResponse
        GETINVENTORYSTATUSRESULT = E.getInventoryStatusResult
        OUTCOME = E.outcome
        RESULT = E.result
        ERROR = E.error
        xml_error = GETINVENTORYSTATUSRESPONSE(
            GETINVENTORYSTATUSRESULT(
                OUTCOME(
                    RESULT('Failure'),
                    ERROR(errorMessage)
                )
            )
        )
        return ET.tostring(xml_error, pretty_print=True)
def checkAuthorization(bu, headers, status):
    error = 'Invalid clientCode for account type'
    auth_search = re.search('username="(.*?)"', headers)
    auth_user = auth_search.group(1)
    if auth_user.upper() != "ADMIN":
        if (str(bu).upper() != auth_user.upper()) and status == 'SS':
            return False, Response(generateXMLErrorResponse(error, status), mimetype='text/xml')
        elif (str(bu).upper() != auth_user.upper()) and status == 'IS':
            return False, Response(generateXMLErrorResponse(error, status), mimetype='text/xml')
        else:
            return True, 'User authenticated'
    else:
        return True, 'User authenticated'

def getURI(api):
    try:
        if api == 'SS':
            tracking = dbConnect.db.session.query(dbConnect.AppParam).\
                filter(dbConnect.AppParam.name == 'TRACKING_WEB_SERVICE_URI').first()
            return True, tracking.value
        elif api == 'IS':
            inventory = dbConnect.db.session.query(dbConnect.AppParam).\
                filter(dbConnect.AppParam.name == 'INVENTORY_WEB_SERVICE_URI').first()
            return True, inventory.value
    except exc.OperationalError:
        sendEmail()
        return False, generateXMLErrorResponse('Service Unavailable', api)

def sendEmail():
    msg = MIMEText('Unable to connect to DB')
    msg['Subject'] = "Database server down!"
    msg['From'] = ''
    msg['To'] = ''
    s = smtplib.SMTP('localhost')
    s.sendmail(msg['From'], msg['To'], msg.as_string())
    s.quit()
    return True
def parseNWRequest(nwRequest, api):
    returnValues = {}
    if api == 'SS':
        try:
            xml = xmltodict.parse(nwRequest)
            returnValues['start'] = xml['getShipmentStatus']['getShipmentStatusRequest']['startTime']
            returnValues['end'] = xml['getShipmentStatus']['getShipmentStatusRequest']['endTime']
            if validateDate(returnValues['start'], returnValues['end']):
                returnValues['bu'] = xml['getShipmentStatus']['getShipmentStatusRequest']['clientCode']
                returnValues['page'] = xml['getShipmentStatus']['getShipmentStatusRequest']['pageSize']
                returnValues['requestid'] = xml['getShipmentStatus']['getShipmentStatusRequest']['requestId']
                return True, returnValues
            else:
                return False, generateXMLErrorResponse('Invalid startDate/endDate', api)
        except xmltodict.expat.ExpatError:
            return False, generateXMLErrorResponse('Invalid Formed XML', api)
    elif api == 'IS':
        try:
            xml = xmltodict.parse(request.data)
            returnValues['bu'] = xml['getInventoryStatus']['getInventoryStatusRequest']['clientCode']
            returnValues['facility'] = xml['getInventoryStatus']['getInventoryStatusRequest']['facility']
            return True, returnValues
        except xmltodict.expat.ExpatError:
            return False, generateXMLErrorResponse('Invalid Formed XML', api)

def validateDate(startDate, endDate):
    try:
        datetime.strptime(startDate, '%Y-%m-%dT%H:%M:%S')
        datetime.strptime(endDate, '%Y-%m-%dT%H:%M:%S')
    except ValueError:
        return False
    else:
        return True

def sendDepascoRequest(uri, search, api):
    auth=('', '')
    depascoResponse = requests.get(uri, auth=auth, params=search)
    try:
        depascoResponse.raise_for_status()
    except requests.exceptions.HTTPError:
        return False, generateXMLErrorResponse(depascoResponse.content, api)
    else:
        return True, requests.get(uri, auth=auth, params=search)
******* UPDATE **********
Thanks to the accepted answer, I removed all layers of the nested if statements. I have my functions just raise a custom exception which is handled in the main program.
__author__ = 'jw1050'
from functions import parseNWRequest, checkAuthorization, getURI, sendDepascoRequest, transformXML, addToDb, APIError
from flask import request
from flask import Response
import dbConnect
import lxml.etree as ET
from sqlalchemy import text

@dbConnect.app.route('/services/inventory/getInventoryStatus', methods=['POST'])
def parsexml():
    try:
        returnValues = parseNWRequest(request.data, 'IS')
        checkAuthorization(returnValues['bu'], request.headers['Authorization'], 'IS')
        uri = getURI('IS')
        search = {'bu': returnValues['bu'], 'facility': returnValues['facility']}
        depascoResponse = sendDepascoRequest(uri, search, 'IS')
        s = text("Select sku, allocated from fgw_allocated_sku_count where client_code = :c and fulfillment_location = :t")
        result = dbConnect.db.engine.execute(s, c=returnValues['bu'], t=returnValues['facility']).fetchall()
        root = ET.fromstring(depascoResponse.content)
        for row in result:
            for element in root.iter('Item'):
                xmlSKU = element.find('SKU').text
                if xmlSKU == row[0]:
                    newQonOrder = int(element.find('QuantityOnOrder').text) + row[1]
                    element.find('QuantityOnOrder').text = str(newQonOrder)
                    newQAvailable = int(element.find('QuantityAvailable').text) - newQonOrder
                    element.find('QuantityAvailable').text = str(newQAvailable)
        nakedResponse = transformXML(ET.tostring(root), 'transformInventory.xsl')
        # Write Request to db
        file_name = 'WS_IS_' + returnValues['bu']
        request_file_size = request.headers['Content-Length']
        addToDb('SS', file_name, 'text/xml', 'IS.Req', returnValues['bu'], 'Y', request.data, request_file_size)
        addToDb('SS', file_name, 'text/xml', 'IS.Result', returnValues['bu'], 'Y', nakedResponse)
        return Response(nakedResponse, mimetype='text/xml')
    except APIError as error:
        return Response(error.errorResponse, mimetype='text/xml')

@dbConnect.app.route('/services/tracking/getShipmentStatus', methods=['POST'])
def parsexml2():
    try:
        returnValues = parseNWRequest(request.data, 'SS')
        checkAuthorization(returnValues['bu'], request.headers['Authorization'], 'SS')
        uri = getURI('SS')
        search = {'bu': returnValues['bu'], 'starttime': returnValues['start'], 'endtime': returnValues['end'],
                  'requestid': returnValues['requestid'], 'pagesize': returnValues['page']}
        depascoResponse = sendDepascoRequest(uri, search, 'SS')
        nakedResponse = transformXML(depascoResponse.content, 'transformTracking.xsl')
        # Write Request to db
        file_name = 'WS_SS_' + returnValues['bu']
        request_file_size = request.headers['Content-Length']
        addToDb('SS', file_name, 'text/xml', 'SS.Request', returnValues['bu'], 'Y', request.data, request_file_size)
        addToDb('SS', file_name, 'text/xml', 'SS.Result', returnValues['bu'], 'Y', nakedResponse)
        return Response(nakedResponse, mimetype='text/xml')
    except APIError as error:
        return Response(error.errorResponse, mimetype='text/xml')

if __name__ == '__main__':
    dbConnect.app.run(host='localhost', port=int("5010"), debug=True)
My custom exception class:
class APIError(Exception):
    def __init__(self, errorResponse):
        self.errorResponse = errorResponse
An example of a function that raises my custom exception:
def validateDate(startDate, endDate):
    try:
        datetime.strptime(startDate, '%Y-%m-%dT%H:%M:%S')
        datetime.strptime(endDate, '%Y-%m-%dT%H:%M:%S')
    except ValueError:
        raise APIError(generateXMLErrorResponse('Invalid format for startDate/endDate', 'SS'))
    else:
        return True
Another way is to raise exceptions and handle them.
import myexceptions
try:
    # a complicated hunk of stuff that raises exceptions when
    # things don't work out. The exceptions can be raised down
    # inside functions, if breaking this lump into functions
    # makes it easier to follow.
except MyExceptionA:
    return Response( ...) # appropriate response for error A
except MyExceptionB:
    return Response( ...) # appropriate response for error B
except ...
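A runnable version of that outline might look something like the sketch below. Everything in it is made up for illustration (ApiError, the two step functions, the returned strings); it only shows the shape: each step raises on failure, and one except block replaces the chain of else branches.

```python
class ApiError(Exception):
    """Hypothetical error type that carries the message to send back."""


def parse_request(raw: str) -> dict:
    if not raw.startswith("<"):
        raise ApiError("Invalid Formed XML")
    return {"bu": "ACME"}  # made-up parsed values


def check_authorization(bu: str, user: str) -> None:
    if user.upper() != "ADMIN" and user.upper() != bu.upper():
        raise ApiError("Invalid clientCode for account type")


def handle(raw: str, user: str) -> str:
    """Run the steps in order; the first failure jumps to the except block."""
    try:
        values = parse_request(raw)
        check_authorization(values["bu"], user)
        return "OK for " + values["bu"]
    except ApiError as error:
        return "Failure: " + str(error)


if __name__ == "__main__":
    print(handle("<x/>", "acme"))
    print(handle("oops", "acme"))
    print(handle("<x/>", "bob"))
```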
I like to leave my methods ASAP, so in your case, it would look like:
if not X:
    return ...
if not Y:
    return ...
if not Z:
    return ...
though generally, I prefer
if X:
    return ...
if Y:
    return ...
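As a concrete (and entirely hypothetical) illustration of leaving early, here is a sketch where three boolean flags stand in for the status values the helper functions return. Each failed check returns straight away, so the happy path stays at the left margin instead of five levels deep.

```python
def handle(parse_ok: bool, auth_ok: bool, uri_ok: bool) -> str:
    """Return the first error encountered, or the final response if all pass."""
    if not parse_ok:
        return "parse error"
    if not auth_ok:
        return "auth error"
    if not uri_ok:
        return "uri error"
    # Only reached when every check above passed.
    return "transformed response"


if __name__ == "__main__":
    print(handle(True, False, True))
    print(handle(True, True, True))
```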
I am trying to write unit tests for a child of tornado.web.RequestHandler that runs an aggregate query against the database. I have already wasted several days trying to get the tests to work.
The tests use pytest and factory_boy. A lot of the important Tornado classes have factories for the tests.
This is the class that is being tested:
class AggregateRequestHandler(StreamlyneRequestHandler):
    '''
    '''
    SUPPORTED_METHODS = (
        "GET", "POST", "OPTIONS")

    def get(self):
        self.aggregate()

    @auth.hmac_auth
    #@tornado.web.asynchronous
    @tornado.web.removeslash
    @tornado.gen.coroutine
    def aggregate(self):
        '''
        '''
        self.logger.info('api aggregate')
        data = self.data
        print("Data: {0}".format(data))
        pipeline = data['pipeline']
        self.logger.debug('pipeline : {0}'.format(pipeline))
        self.logger.debug('utc tz : {0}'.format(tz_util.utc))
        # execute pipeline query
        print(self.collection)
        try:
            cursor_future = self.collection.aggregate(pipeline, cursor={})
            print(cursor_future)
            cursor = yield cursor_future
            print("Cursor: {0}".format(cursor))
        except Exception as e:
            print(e)
        documents = yield cursor.to_list(length=None)
        self.logger.debug('results : {0}'.format(documents))
        # process MongoDB JSON extended
        results = json.loads(json_util.dumps(documents))
        pipeline = json.loads(json_util.dumps(pipeline))
        response_data = {
            'pipeline': pipeline,
            'results': results
        }
        self.respond(response_data)
The method used to test it is here:
#@tornado.testing.gen_test
def test_time_inside(self):
    current_time = gen_time()
    past_time = gen_time() - datetime.timedelta(minutes=20)
    test_query = copy.deepcopy(QUERY)
    oid = ObjectId("53a72de12fb05c0788545ed6")
    test_query[0]['$match']['attribute'] = oid
    test_query[0]['$match']['date_created']['$gte'] = past_time
    test_query[0]['$match']['date_created']['$lte'] = current_time
    request = produce.HTTPRequest(
        method="GET",
        headers=produce.HTTPHeaders(
            kwargs = {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "X-Sl-Organization": "test",
                "Hmac": "83275edec557e2a339e0ec624201db604645e1e1",
                "X-Sl-Username": "test@test.co",
                "X-Sl-Expires": 1602011725
            }
        ),
        uri="/api/v1/attribute-data/aggregate?{0}".format(json_util.dumps({
            "pipeline": test_query
        }))
    )
    self.ARH = produce.AggregateRequestHandler(request=request)
    #io_loop = tornado.ioloop.IOLoop.instance()
    self.io_loop.run_sync(self.ARH.get)
    #def stop_test():
    #    self.stop()
    #self.ARH.test_get(stop_test)
    #self.wait()
    output = self.ARH.get_written_output()
    assert output == ""
This is the way I set up the factory for the Request Handler:
class OutputTestAggregateRequestHandler(slapi.rest.AggregateRequestHandler, tornado.testing.AsyncTestCase):
    '''
    '''
    _written_output = []

    def write(self, chunk):
        print("Previously written: {0}".format(self._written_output))
        print("Len: {0}".format(len(self._written_output)))
        if self._finished:
            raise RuntimeError("Cannot write() after finish(). May be caused "
                               "by using async operations without the "
                               "@asynchronous decorator.")
        if isinstance(chunk, dict):
            print("Going to encode a chunk")
            chunk = escape.json_encode(chunk)
            self.set_header("Content-Type", "application/json; charset=UTF-8")
        chunk = escape.utf8(chunk)
        print("Writing")
        self._written_output = []
        self._written_output.append(chunk)
        print(chunk)

    def flush(self, include_footers=False, callback=None):
        pass

    def get_written_output(self):
        for_return = self._written_output
        self._written_output = []
        return for_return

class AggregateRequestHandler(StreamlyneRequestHandler):
    '''
    '''
    class Meta:
        model = OutputTestAggregateRequestHandler

    model = slapi.model.AttributeDatum
When running the tests, the test simply stops in def aggregate(self): somewhere between print(cursor_future) and print("Cursor: {0}".format(cursor)).
Then in the stdout you see
MotorCollection(Collection(Database(MongoClient([]), u'test'), u'attribute_datum'))
<tornado.concurrent.Future object at 0x7fbc737993d0>
and nothing else comes out of the test with it failing on
> assert output == ""
E AssertionError: assert [] == ''
After a lot of time looking at documentation and examples and stack overflow I managed to get a functioning test by adding the following code to OutputTestAggregateRequestHandler:
def set_io_loop(self):
    self.io_loop = tornado.ioloop.IOLoop.instance()

def ioloop(f):
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        print(args)
        self.set_io_loop()
        return f(self, *args, **kwargs)
    return wrapper

def runTest(self):
    pass
Then copying all of the code from AggregateRequestHandler.aggregate into OutputTestAggregateRequestHandler but with different decorators:
@ioloop
@tornado.testing.gen_test
def _aggregate(self):
    ......
I then received the output:
assert output == ""
E AssertionError: assert ['{\n "pipeline": [\n {\n "$match": {\n "attribute": {\n "$oid"... "$oid": "53cec0e72dc9832c4c4185f2"\n }, \n "quality": 9001\n }\n ]\n}'] == ''
which is actually a success, but I was just triggering an assertion error on purpose to see the output.
The big problem that I have is how to achieve the desired outcome, which is the output received by adding the extra code and copying the aggregate method.
Obviously when copying the code out of the aggregate method the test is no longer useful after I make changes to the actual method. How can I get the actual aggregate method to function properly in the tests instead of stopping seemingly when it encounters asynchronous code?
Thanks for any help,
Cheers!
-Liam
In general, the intended way to test RequestHandlers is with AsyncHTTPTestCase, not AsyncTestCase. This will set up the HTTP client and server for you and everything will go through the HTTP plumbing. Using RequestHandlers outside of an Application and HTTP server is not fully supported, although in Tornado 4.0 it might be feasible to use a dummy HTTPConnection to avoid the full server stack. This might be faster, although it's kind of uncharted territory at this point.
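Separately from the AsyncHTTPTestCase advice, one possible reading of the symptom (an assumption, not something confirmed by the code shown) is that get() calls self.aggregate() without yielding or returning its future, so run_sync(self.ARH.get) may treat the call as finished before the coroutine resumes. The sketch below uses asyncio from the standard library as a stand-in for Tornado's coroutines; all names are made up.

```python
import asyncio

written = []


async def aggregate():
    await asyncio.sleep(0.01)  # stands in for the asynchronous database call
    written.append("response")


async def get_without_waiting():
    # Starts aggregate() but returns immediately, like a get() that calls
    # a coroutine without yielding or returning its future.
    asyncio.ensure_future(aggregate())


async def get_and_wait():
    # Waits for aggregate() to finish before get() itself finishes.
    await aggregate()


def run(handler):
    """Run one handler to completion and report what was written."""
    written.clear()
    asyncio.run(handler())
    return list(written)


if __name__ == "__main__":
    print(run(get_without_waiting))  # nothing written: the loop stopped too early
    print(run(get_and_wait))         # the response is written
```

If the same thing is happening in the test, having get() return or yield the result of aggregate() would be the first thing to try.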