convert synchronous websocket callbacks to async in python using asyncio - python

need you assistance to convert asynchronous websocket callbacks to async using asyncio in python 3.6.
Some brokers in India provide a websocket for tick data to their users and users can further use this tick data, one such brokers sample code(from their git repo) is pasted here
`from smartapi import SmartWebSocket
# feed_token=092017047
FEED_TOKEN="YOUR_FEED_TOKEN"
CLIENT_CODE="YOUR_CLIENT_CODE"
# token="mcx_fo|224395"
token="EXCHANGE|TOKEN_SYMBOL" #SAMPLE: nse_cm|2885&nse_cm|1594&nse_cm|11536&nse_cm|3045
# token="mcx_fo|226745&mcx_fo|220822&mcx_fo|227182&mcx_fo|221599"
task="mw" # mw|sfi|dp
ss = SmartWebSocket(FEED_TOKEN, CLIENT_CODE)
def on_message(ws, message):
print("Ticks: {}".format(message))
def on_open(ws):
print("on open")
ss.subscribe(task,token)
def on_error(ws, error):
print(error)
def on_close(ws):
print("Close")
# Assign the callbacks.
ss._on_open = on_open
ss._on_message = on_message
ss._on_error = on_error
ss._on_close = on_close
ss.connect()
so what I think the broker is doing is that it provides messages/ticks aka events that probably trigger a callback function on every received messages/tick in a synchronous way.
I however want to alter the above code so that it can work in an async manner, I have tried to write in a manner how binance provides such as async with websocketName('currencyPair')as ts: but this doesn't seem to work with my broker.
Appreciate if you can share some ideas/code/insights around this.
Thank you for your time.
Edit 1:
Thank you for your response Dirn,
Yes, I've tried this
from smartapi import SmartConnect
from smartapi import SmartWebSocket
import logging
import asyncio
CLIENT_CODE = Code_provided_by_broker
PASSWORD = My_Password
obj=SmartConnect(api_key=My_api_key)
data = obj.generateSession(CLIENT_CODE, PASSWORD)
refreshToken = data['data']['refreshToken']
print("refreshToken", refreshToken)
feedToken=obj.getfeedToken()
print("feedToken", feedToken)
FEED_TOKEN= feedToken
token="nse_cm|3499"
# token="nse_fo|53595"
task = 'mw'
async def on_message(ws, message):
print("Ticks: {}".format(message))
async def chk_msg():
# ss = SmartWebSocket(FEED_TOKEN, CLIENT_CODE)
# ss.subscribe(task,token)
# ss._on_message = on_message
# ss.connect()
# print("Ticks: {}".format(message))
async with SmartWebSocket(FEED_TOKEN, CLIENT_CODE) as ss: #throws an error: AttributeError: __aexit__
ss.subscribe(task,token)
ss.connect()
while True:
ss._on_message = on_message # not sure how to await a non-async method _on_message, also there is no use of async module in the modules(imported at start of the script) SmartWebSocket and SmartConnect
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(chk_msg())
here is the smartConnect.py
from six.moves.urllib.parse import urljoin
import sys
import csv
import json
import dateutil.parser
import hashlib
import logging
import datetime
import smartapi.smartExceptions as ex
import requests
from requests import get
import re, uuid
import socket
import platform
from smartapi.version import __version__, __title__
log = logging.getLogger(__name__)
#user_sys=platform.system()
#print("the system",user_sys)
class SmartConnect(object):
#_rootUrl = "https://openapisuat.angelbroking.com"
_rootUrl="https://apiconnect.angelbroking.com" #prod endpoint
#_login_url ="https://smartapi.angelbroking.com/login"
_login_url="https://smartapi.angelbroking.com/publisher-login" #prod endpoint
_default_timeout = 7 # In seconds
_routes = {
"api.login":"/rest/auth/angelbroking/user/v1/loginByPassword",
"api.logout":"/rest/secure/angelbroking/user/v1/logout",
"api.token": "/rest/auth/angelbroking/jwt/v1/generateTokens",
"api.refresh": "/rest/auth/angelbroking/jwt/v1/generateTokens",
"api.user.profile": "/rest/secure/angelbroking/user/v1/getProfile",
"api.order.place": "/rest/secure/angelbroking/order/v1/placeOrder",
"api.order.modify": "/rest/secure/angelbroking/order/v1/modifyOrder",
"api.order.cancel": "/rest/secure/angelbroking/order/v1/cancelOrder",
"api.order.book":"/rest/secure/angelbroking/order/v1/getOrderBook",
"api.ltp.data": "/rest/secure/angelbroking/order/v1/getLtpData",
"api.trade.book": "/rest/secure/angelbroking/order/v1/getTradeBook",
"api.rms.limit": "/rest/secure/angelbroking/user/v1/getRMS",
"api.holding": "/rest/secure/angelbroking/portfolio/v1/getHolding",
"api.position": "/rest/secure/angelbroking/order/v1/getPosition",
"api.convert.position": "/rest/secure/angelbroking/order/v1/convertPosition",
"api.gtt.create":"/gtt-service/rest/secure/angelbroking/gtt/v1/createRule",
"api.gtt.modify":"/gtt-service/rest/secure/angelbroking/gtt/v1/modifyRule",
"api.gtt.cancel":"/gtt-service/rest/secure/angelbroking/gtt/v1/cancelRule",
"api.gtt.details":"/rest/secure/angelbroking/gtt/v1/ruleDetails",
"api.gtt.list":"/rest/secure/angelbroking/gtt/v1/ruleList",
"api.candle.data":"/rest/secure/angelbroking/historical/v1/getCandleData"
}
try:
clientPublicIp= " " + get('https://api.ipify.org').text
if " " in clientPublicIp:
clientPublicIp=clientPublicIp.replace(" ","")
hostname = socket.gethostname()
clientLocalIp=socket.gethostbyname(hostname)
except Exception as e:
print("Exception while retriving IP Address,using local host IP address",e)
finally:
clientPublicIp="106.193.147.98"
clientLocalIp="127.0.0.1"
clientMacAddress=':'.join(re.findall('..', '%012x' % uuid.getnode()))
accept = "application/json"
userType = "USER"
sourceID = "WEB"
def __init__(self, api_key=None, access_token=None, refresh_token=None,feed_token=None, userId=None, root=None, debug=False, timeout=None, proxies=None, pool=None, disable_ssl=False,accept=None,userType=None,sourceID=None,Authorization=None,clientPublicIP=None,clientMacAddress=None,clientLocalIP=None,privateKey=None):
self.debug = debug
self.api_key = api_key
self.session_expiry_hook = None
self.disable_ssl = disable_ssl
self.access_token = access_token
self.refresh_token = refresh_token
self.feed_token = feed_token
self.userId = userId
self.proxies = proxies if proxies else {}
self.root = root or self._rootUrl
self.timeout = timeout or self._default_timeout
self.Authorization= None
self.clientLocalIP=self.clientLocalIp
self.clientPublicIP=self.clientPublicIp
self.clientMacAddress=self.clientMacAddress
self.privateKey=api_key
self.accept=self.accept
self.userType=self.userType
self.sourceID=self.sourceID
if pool:
self.reqsession = requests.Session()
reqadapter = requests.adapters.HTTPAdapter(**pool)
self.reqsession.mount("https://", reqadapter)
print("in pool")
else:
self.reqsession = requests
# disable requests SSL warning
requests.packages.urllib3.disable_warnings()
def requestHeaders(self):
return{
"Content-type":self.accept,
"X-ClientLocalIP": self.clientLocalIp,
"X-ClientPublicIP": self.clientPublicIp,
"X-MACAddress": self.clientMacAddress,
"Accept": self.accept,
"X-PrivateKey": self.privateKey,
"X-UserType": self.userType,
"X-SourceID": self.sourceID
}
def setSessionExpiryHook(self, method):
if not callable(method):
raise TypeError("Invalid input type. Only functions are accepted.")
self.session_expiry_hook = method
def getUserId():
return userId
def setUserId(self,id):
self.userId=id
def setAccessToken(self, access_token):
self.access_token = access_token
def setRefreshToken(self, refresh_token):
self.refresh_token = refresh_token
def setFeedToken(self,feedToken):
self.feed_token=feedToken
def getfeedToken(self):
return self.feed_token
def login_url(self):
"""Get the remote login url to which a user should be redirected to initiate the login flow."""
return "%s?api_key=%s" % (self._login_url, self.api_key)
def _request(self, route, method, parameters=None):
"""Make an HTTP request."""
params = parameters.copy() if parameters else {}
uri =self._routes[route].format(**params)
url = urljoin(self.root, uri)
# Custom headers
headers = self.requestHeaders()
if self.access_token:
# set authorization header
auth_header = self.access_token
headers["Authorization"] = "Bearer {}".format(auth_header)
if self.debug:
log.debug("Request: {method} {url} {params} {headers}".format(method=method, url=url, params=params, headers=headers))
try:
r = requests.request(method,
url,
data=json.dumps(params) if method in ["POST", "PUT"] else None,
params=json.dumps(params) if method in ["GET", "DELETE"] else None,
headers=headers,
verify=not self.disable_ssl,
allow_redirects=True,
timeout=self.timeout,
proxies=self.proxies)
except Exception as e:
raise e
if self.debug:
log.debug("Response: {code} {content}".format(code=r.status_code, content=r.content))
# Validate the content type.
if "json" in headers["Content-type"]:
try:
data = json.loads(r.content.decode("utf8"))
except ValueError:
raise ex.DataException("Couldn't parse the JSON response received from the server: {content}".format(
content=r.content))
# api error
if data.get("error_type"):
# Call session hook if its registered and TokenException is raised
if self.session_expiry_hook and r.status_code == 403 and data["error_type"] == "TokenException":
self.session_expiry_hook()
# native errors
exp = getattr(ex, data["error_type"], ex.GeneralException)
raise exp(data["message"], code=r.status_code)
return data
elif "csv" in headers["Content-type"]:
return r.content
else:
raise ex.DataException("Unknown Content-type ({content_type}) with response: ({content})".format(
content_type=headers["Content-type"],
content=r.content))
def _deleteRequest(self, route, params=None):
"""Alias for sending a DELETE request."""
return self._request(route, "DELETE", params)
def _putRequest(self, route, params=None):
"""Alias for sending a PUT request."""
return self._request(route, "PUT", params)
def _postRequest(self, route, params=None):
"""Alias for sending a POST request."""
return self._request(route, "POST", params)
def _getRequest(self, route, params=None):
"""Alias for sending a GET request."""
return self._request(route, "GET", params)
def generateSession(self,clientCode,password):
params={"clientcode":clientCode,"password":password}
loginResultObject=self._postRequest("api.login",params)
if loginResultObject['status']==True:
jwtToken=loginResultObject['data']['jwtToken']
self.setAccessToken(jwtToken)
refreshToken=loginResultObject['data']['refreshToken']
feedToken=loginResultObject['data']['feedToken']
self.setRefreshToken(refreshToken)
self.setFeedToken(feedToken)
user=self.getProfile(refreshToken)
id=user['data']['clientcode']
#id='D88311'
self.setUserId(id)
user['data']['jwtToken']="Bearer "+jwtToken
user['data']['refreshToken']=refreshToken
return user
else:
return loginResultObject
def terminateSession(self,clientCode):
logoutResponseObject=self._postRequest("api.logout",{"clientcode":clientCode})
return logoutResponseObject
def generateToken(self,refresh_token):
response=self._postRequest('api.token',{"refreshToken":refresh_token})
jwtToken=response['data']['jwtToken']
feedToken=response['data']['feedToken']
self.setFeedToken(feedToken)
self.setAccessToken(jwtToken)
return response
def renewAccessToken(self):
response =self._postRequest('api.refresh', {
"jwtToken": self.access_token,
"refreshToken": self.refresh_token,
})
tokenSet={}
if "jwtToken" in response:
tokenSet['jwtToken']=response['data']['jwtToken']
tokenSet['clientcode']=self. userId
tokenSet['refreshToken']=response['data']["refreshToken"]
return tokenSet
def getProfile(self,refreshToken):
user=self._getRequest("api.user.profile",{"refreshToken":refreshToken})
return user
def placeOrder(self,orderparams):
params=orderparams
for k in list(params.keys()):
if params[k] is None :
del(params[k])
orderResponse= self._postRequest("api.order.place", params)['data']['orderid']
return orderResponse
def modifyOrder(self,orderparams):
params = orderparams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
orderResponse= self._postRequest("api.order.modify", params)
return orderResponse
def cancelOrder(self, order_id,variety):
orderResponse= self._postRequest("api.order.cancel", {"variety": variety,"orderid": order_id})
return orderResponse
def ltpData(self,exchange,tradingsymbol,symboltoken):
params={
"exchange":exchange,
"tradingsymbol":tradingsymbol,
"symboltoken":symboltoken
}
ltpDataResponse= self._postRequest("api.ltp.data",params)
return ltpDataResponse
def orderBook(self):
orderBookResponse=self._getRequest("api.order.book")
return orderBookResponse
def tradeBook(self):
tradeBookResponse=self._getRequest("api.trade.book")
return tradeBookResponse
def rmsLimit(self):
rmsLimitResponse= self._getRequest("api.rms.limit")
return rmsLimitResponse
def position(self):
positionResponse= self._getRequest("api.position")
return positionResponse
def holding(self):
holdingResponse= self._getRequest("api.holding")
return holdingResponse
def convertPosition(self,positionParams):
params=positionParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
convertPositionResponse= self._postRequest("api.convert.position",params)
return convertPositionResponse
def gttCreateRule(self,createRuleParams):
params=createRuleParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
createGttRuleResponse=self._postRequest("api.gtt.create",params)
#print(createGttRuleResponse)
return createGttRuleResponse['data']['id']
def gttModifyRule(self,modifyRuleParams):
params=modifyRuleParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
modifyGttRuleResponse=self._postRequest("api.gtt.modify",params)
#print(modifyGttRuleResponse)
return modifyGttRuleResponse['data']['id']
def gttCancelRule(self,gttCancelParams):
params=gttCancelParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
#print(params)
cancelGttRuleResponse=self._postRequest("api.gtt.cancel",params)
#print(cancelGttRuleResponse)
return cancelGttRuleResponse
def gttDetails(self,id):
params={
"id":id
}
gttDetailsResponse=self._postRequest("api.gtt.details",params)
return gttDetailsResponse
def gttLists(self,status,page,count):
if type(status)== list:
params={
"status":status,
"page":page,
"count":count
}
gttListResponse=self._postRequest("api.gtt.list",params)
#print(gttListResponse)
return gttListResponse
else:
message="The status param is entered as" +str(type(status))+". Please enter status param as a list i.e., status=['CANCELLED']"
return message
def getCandleData(self,historicDataParams):
params=historicDataParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
getCandleDataResponse=self._postRequest("api.candle.data",historicDataParams)
return getCandleDataResponse
def _user_agent(self):
return (__title__ + "-python/").capitalize() + __version__
the smartApiWebsocket.py
import websocket
import six
import base64
import zlib
import datetime
import time
import json
import threading
import ssl
class SmartWebSocket(object):
ROOT_URI='wss://wsfeeds.angelbroking.com/NestHtml5Mobile/socket/stream'
HB_INTERVAL=30
HB_THREAD_FLAG=False
WS_RECONNECT_FLAG=False
feed_token=None
client_code=None
ws=None
task_dict = {}
def __init__(self, FEED_TOKEN, CLIENT_CODE):
self.root = self.ROOT_URI
self.feed_token = FEED_TOKEN
self.client_code = CLIENT_CODE
if self.client_code == None or self.feed_token == None:
return "client_code or feed_token or task is missing"
def _subscribe_on_open(self):
request = {"task": "cn", "channel": "NONLM", "token": self.feed_token, "user": self.client_code,
"acctid": self.client_code}
print(request)
self.ws.send(
six.b(json.dumps(request))
)
thread = threading.Thread(target=self.run, args=())
thread.daemon = True
thread.start()
def run(self):
while True:
# More statements comes here
if self.HB_THREAD_FLAG:
break
print(datetime.datetime.now().__str__() + ' : Start task in the background')
self.heartBeat()
time.sleep(self.HB_INTERVAL)
def subscribe(self, task, token):
# print(self.task_dict)
self.task_dict.update([(task,token),])
# print(self.task_dict)
if task in ("mw", "sfi", "dp"):
strwatchlistscrips = token # dynamic call
try:
request = {"task": task, "channel": strwatchlistscrips, "token": self.feed_token,
"user": self.client_code, "acctid": self.client_code}
self.ws.send(
six.b(json.dumps(request))
)
return True
except Exception as e:
self._close(reason="Error while request sending: {}".format(str(e)))
raise
else:
print("The task entered is invalid, Please enter correct task(mw,sfi,dp) ")
def resubscribe(self):
for task, marketwatch in self.task_dict.items():
print(task, '->', marketwatch)
try:
request = {"task": task, "channel": marketwatch, "token": self.feed_token,
"user": self.client_code, "acctid": self.client_code}
self.ws.send(
six.b(json.dumps(request))
)
return True
except Exception as e:
self._close(reason="Error while request sending: {}".format(str(e)))
raise
def heartBeat(self):
try:
request = {"task": "hb", "channel": "", "token": self.feed_token, "user": self.client_code,
"acctid": self.client_code}
print(request)
self.ws.send(
six.b(json.dumps(request))
)
except:
print("HeartBeat Sending Failed")
# time.sleep(60)
def _parse_text_message(self, message):
"""Parse text message."""
data = base64.b64decode(message)
try:
data = bytes((zlib.decompress(data)).decode("utf-8"), 'utf-8')
data = json.loads(data.decode('utf8').replace("'", '"'))
data = json.loads(json.dumps(data, indent=4, sort_keys=True))
except ValueError:
return
# return data
if data:
self._on_message(self.ws,data)
def connect(self):
# websocket.enableTrace(True)
self.ws = websocket.WebSocketApp(self.ROOT_URI,
on_message=self.__on_message,
on_close=self.__on_close,
on_open=self.__on_open,
on_error=self.__on_error)
self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def __on_message(self, ws, message):
self._parse_text_message(message)
# print(msg)
def __on_open(self, ws):
print("__on_open################")
self.HB_THREAD_FLAG = False
self._subscribe_on_open()
if self.WS_RECONNECT_FLAG:
self.WS_RECONNECT_FLAG = False
self.resubscribe()
else:
self._on_open(ws)
def __on_close(self, ws):
self.HB_THREAD_FLAG = True
print("__on_close################")
self._on_close(ws)
def __on_error(self, ws, error):
if ( "timed" in str(error) ) or ( "Connection is already closed" in str(error) ) or ( "Connection to remote host was lost" in str(error) ):
self.WS_RECONNECT_FLAG = True
self.HB_THREAD_FLAG = True
if (ws is not None):
ws.close()
ws.on_message = None
ws.on_open = None
ws.close = None
# print (' deleting ws')
del ws
self.connect()
else:
print ('Error info: %s' %(error))
self._on_error(ws, error)
def _on_message(self, ws, message):
pass
def _on_open(self, ws):
pass
def _on_close(self, ws):
pass
def _on_error(self, ws, error):
pass
the provider has also imported the websocket module in the above smartApiWebsocket file and has also provided a custom webSocket.py
here is the link webSocket.py
enter link description here

Related

self.scope['user'] always returns AnonymousUser in websocket

I have researched similar questions but can't find an answer that works for me.
I would like to get the username or user_id from a session when a user connects to a websocket.
This is what I have in consumers.py:
class PracticeConsumer(AsyncConsumer):
async def websocket_connect(self, event):
print('session data', self.scope['user'])
await self.send({"type": "websocket.accept", })
...
#database_sync_to_async
def get_user(self, user_id):
try:
return User.objects.get(username=user_id).pk
except User.DoesNotExist:
return AnonymousUser()
This is my asgi.py:
"""
ASGI config for restapi project.
It exposes the ASGI callable as a module-level variable named application.
For more information on this file, see
https://docs.djangoproject.com/en/4.0/howto/deployment/asgi/
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'signup.settings')
application = ProtocolTypeRouter({
"http": get_asgi_application(),
"websocket": AllowedHostsOriginValidator(
AuthMiddlewareStack(
URLRouter(websocket_urlpatterns)
)
)
})
and the user login function + token sent when user logs in:
class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
def validate(self, attrs):
authenticate_kwargs = {
self.username_field: attrs[self.username_field],
"password": attrs["password"],
}
try:
authenticate_kwargs["request"] = self.context["request"]
except KeyError:
pass
user = authenticate(**authenticate_kwargs)
if not user:
return {
'user': 'Username or password is incorrect',
}
token = RefreshToken.for_user(user)
# customizing token payload
token['username'] = user.username
token['first_name'] = user.first_name
token['last_name'] = user.last_name
token['country'] = user.profile.country
token['city'] = user.profile.city
token['bio'] = user.profile.bio
token['photo'] = json.dumps(str(user.profile.profile_pic))
user_logged_in.send(sender=user.__class__, request=self.context['request'], user=user)
if not api_settings.USER_AUTHENTICATION_RULE(user):
raise exceptions.AuthenticationFailed(
self.error_messages["no_active_account"],
"no_active_account",
)
return {
'refresh': str(token),
'access': str(token.access_token),
}
Whenever I print out self.scope['user'] upon connecting, I get AnonymousUser
UPDATE
I tried writing some custom middleware to handle simple JWT authentication:
#database_sync_to_async
def get_user(validated_token):
try:
user = get_user_model().objects.get(id=validated_token["user_id"])
print(f"{user}")
return user
except User.DoesNotExist:
return AnonymousUser()
class JwtAuthMiddleware(BaseMiddleware):
def __init__(self, inner):
self.inner = inner
async def __call__(self, scope, receive, send):
close_old_connections()
token = parse_qs(scope["query_string"].decode("utf8"))["token"][0]
try:
UntypedToken(token)
except (InvalidToken, TokenError) as e:
print(e)
return None
else:
decoded_data = jwt_decode(token, settings.SECRET_KEY, algorithms=["HS256"])
print(decoded_data)
scope["user"] = await get_user(validated_token=decoded_data)
return await super().__call__(scope, receive, send)
def JwtAuthMiddlewareStack(inner):
return JwtAuthMiddleware(AuthMiddlewareStack(inner))
However, this gives me this error:
File "C:\Users\15512\Desktop\django-project\peerplatform\signup\middleware.py", line 37, in __call__
token = parse_qs(scope["query_string"].decode("utf8"))["token"][0]
KeyError: 'token'

Writing unittest for package

I have created a package which uploads a data to some storage using Azure AD access token, now I want to write test cases for the code, as I'm not aware of writing test cases have tried few. Can anyone help here, below is the code for my package.
__init__.py file
import json
import requests
import sys
from data import Data
import datetime
from singleton import singleton
#singleton
class CertifiedLogProvider:
def __init__(self, client_id, client_secret):
self.client_id=client_id
self.client_secret= client_secret
self.grant_type="client_credentials"
self.resource="*****"
self.url="https://login.microsoftonline.com/azureford.onmicrosoft.com/oauth2/token"
self.api_url="http://example.com"
self.get_access_token()
def get_access_token(self)-> None:
data={'client_id':self.client_id,'client_secret':self.client_secret,
'grant_type':self.grant_type,'resource':self.resource}
response = requests.post(self.url, data=data)
if response.status_code == 200:
self.token_dict=response.content
self.access_token = response.json()["access_token"]
else:
print(f"An Error occurred: \n {response.text}")
def is_expired(self) -> bool:
token_dict=json.loads(self.token_dict.decode('utf-8'))
if int(datetime.datetime.now().timestamp()) > int(token_dict['expires_on']):
return True
else:
return False
def send_clp_data(self,payload:dict):
obj=Data()
data=obj.data
data['event_metric_body']=payload
if self.is_expired() is True:
self.get_access_token()
headers={"Authorization": "Bearer {}".format(self.access_token),
"Content-Type": "application/json",}
response = requests.post(self.api_url,data=json.dumps(data), headers=headers)
if response.status_code == 201:
print('Data uploaded successfully')
else:
print(f"An Error occurred: \n {response.text}")
singleton.py
def singleton(class_):
instances = {}
def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
data.py
Contains data which is static
test.py
import json
import unittest
from unittest import TestCase
from unittest.mock import patch
import requests
from unittest.mock import MagicMock
from __init__ import CertifiedLogProvider
import pytest
class MyTestCase(TestCase):
def test_can_construct_clp_instance(self):
object= CertifiedLogProvider(1,2)
#patch('requests.post')
def test_send_clp_data(self, mock_post):
info={"test1":"value1", "test2": "value2"}
response = requests.post("www.clp_api.com", data=json.dumps(info), headers={'Content-Type': 'application/json'})
mock_post.assert_called_with("www.clp_api.com", data=json.dumps(info), headers={'Content-Type': 'application/json'})
if __name__ == '__main__':
unittest.main()
How can we test boolean method and method containing requests?
Testing is_expired() becomes easier if you create a Token class:
class Token:
def __init__(self, token_dict: dict) -> None:
self.token_dict = token_dict
def __str__(self) -> str:
return self.token_dict['access_token']
def is_expired(self, now:DateTime=None) -> bool:
if now is None:
now = datetime.datetime.now()
return int(now.timestamp()) > int(self.token_dict['expires_on'])
#singleton
class CertifiedLogProvider:
def __init__(self, client_id, client_secret):
self.client_id=client_id
self.client_secret= client_secret
self.grant_type="client_credentials"
self.resource="*****"
self.url="https://login.microsoftonline.com/azureford.onmicrosoft.com/oauth2/token"
self.api_url="http://example.com"
self.access_token = None
def get_access_token(self) -> Token:
if self.access_token is None or self.access_token.is_expired():
self.access_token = self.fetch_access_token()
return self.access_token
def fetch_access_token(self) -> Token:
data={
'client_id': self.client_id,
'client_secret': self.client_secret,
'grant_type': self.grant_type,
'resource': self.resource,
}
response = requests.post(self.url, data=data)
if response.status_code != 200:
raise Exception(f"An Error occurred: \n {response.text}")
return Token(response.json())
def send_clp_data(self, payload: dict):
obj=Data()
data=obj.data
data['event_metric_body'] = payload
headers={
"Authorization": f"Bearer {self.get_access_token()}",
"Content-Type": "application/json",
}
response = requests.post(self.api_url, data=json.dumps(data), headers=headers)
if response.status_code != 201:
raise Exception(f"An Error occurred: \n {response.text}")
print('Data uploaded successfully')
test_token.py
from datetime import datetime
import unittest
from certified_log_provider import Token
class TestToken(unittest.TestCase):
def setUp(self) -> None:
self.token = Token({ 'expires_on': datetime.strptime("2022-02-18", "%Y-%m-%d").timestamp() })
def test_token_is_not_expired_day_before(self):
self.assertFalse(self.token.is_expired(datetime.strptime("2022-02-17", "%Y-%m-%d")))
def test_token_is_expired_day_after(self):
self.assertTrue(self.token.is_expired(datetime.strptime("2022-02-19", "%Y-%m-%d")))

requests.session() object not recognized in another class

I am trying to pass my session object from one class to another. But I am not sure whats happening.
class CreateSession:
def __init__(self, user, pwd, url="http://url_to_hit"):
self.url = url
self.user = user
self.pwd = pwd
def get_session(self):
sess = requests.Session()
r = sess.get(self.url + "/", auth=(self.user, self.pwd))
print(r.content)
return sess
class TestGet(CreateSession):
def get_response(self):
s = self.get_session()
print(s)
data = s.get(self.url + '/some-get')
print(data.status_code)
print(data)
if __name__ == "__main__":
TestGet(user='user', pwd='pwd').get_response()
I am getting 401 for get_response(). Not able to understand this.
What's a 401?
The response you're getting means that you're unauthorised to access the resource.
A session is used in order to persist headers and other prerequisites throughout requests, why are you creating the session every time rather than storing it in a variable?
As is, the session should work the only issue is that you're trying to call a resource that you don't have access to. - You're not passing the url parameter either in the initialisation.
Example of how you can effectively use Session:
from requests import Session
from requests.exceptions import HTTPError
class TestGet:
__session = None
__username = None
__password = None
def __init__(self, username, password):
self.__username = username
self.__password = password
#property
def session(self):
if self.__session is None:
self.__session = Session()
self.__session.auth = (self.__user, self.__pwd)
return self.__session
#session.setter
def session(self, value):
raise AttributeError('Setting \'session\' attribute is prohibited.')
def get_response(self, url):
try:
response = self.session.get(url)
# raises if the status code is an error - 4xx, 5xx
response.raise_for_status()
return response
except HTTPError as e:
# you received an http error .. handle it here (e contains the request and response)
pass
test_get = TestGet('my_user', 'my_pass')
first_response = test_get.get_response('http://your-website-with-basic-auth.com')
second_response = test_get.get_response('http://another-url.com')
my_session = test_get.session
my_session.get('http://url.com')

Error with asynchronous request in DRF

I need to fulfill a request to two services.
The code looks like this:
async def post1(data):
response = await aiohttp.request('post', 'http://', json=data)
json_response = await response.json()
response.close()
return json_response
async def get2():
response = await aiohttp.request('get', 'http://')
json_response = await response.json()
response.close()
return json_response
async def asynchronous(parameters):
task1 = post1(parameters['data'])
task2 = get2()
result_list = []
for body in await asyncio.gather(task1, task2):
result_list.append(body)
return result_list
If I run the code locally, it's OK. The code looks like this:
if __name__ == "__main__":
ioloop = asyncio.get_event_loop()
parameters = {'data': data}
result = ioloop.run_until_complete(asynchronous(parameters))
ioloop.close()
print(result)
I get the right result. But if I try to execute code from the DRF method, an error occurs:
TypeError: object _SessionRequestContextManager can't be used in
'await' expression
example code that I run:
.....
class MyViewSet(GenericAPIView):
def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
......
ioloop = asyncio.get_event_loop()
result = ioloop.run_until_complete(asynchronous(serializer.data)) # <<<<< error here
ioloop.close()
......
return Response(serializer.data, status=status.HTTP_201_CREATED)
Please, tell me what the problem may be?
The object returned by aiohttp.request cannot be awaited, it must be used as an async context manager. This code:
response = await aiohttp.request('post', 'http://', json=data)
json_response = await response.json()
response.close()
needs to changed to something like:
async with aiohttp.request('post', 'http://', json=data) as response:
json_response = await response.json()
See the documentation for more usage examples.
Perhaps you have a different aiohttp version on the server where you run DRF, which is why it works locally and fails under DRF.
Try it https://github.com/Skorpyon/aiorestframework
Create some middleware for authentication:
import re
from xml.etree import ElementTree as etree
from json.decoder import JSONDecodeError
from multidict import MultiDict, MultiDictProxy
from aiohttp import web
from aiohttp.hdrs import (
METH_POST, METH_PUT, METH_PATCH, METH_DELETE
)
from aiorestframework import exceptions
from aiorestframework omport serializers
from aiorestframework import Response
from aiorestframework.views import BaseViewSet
from aiorestframework.permissions import set_permissions, AllowAny
from aiorestframework.app import APIApplication
from my_project.settings import Settings # Generated by aiohttp-devtools
TOKEN_RE = re.compile(r'^\s*BEARER\s{,3}(\S{64})\s*$')
async def token_authentication(app, handler):
"""
Authorization middleware
Catching Authorization: BEARER <token> from request headers
Found user in Tarantool by token and bind User or AnonymousUser to request
"""
async def middleware_handler(request):
# Check that `Authorization` header exists
if 'authorization' in request.headers:
authorization = request.headers['authorization']
# Check matches in header value
match = TOKEN_RE.match(authorization)
if not match:
setattr(request, 'user', AnonymousUser())
return await handler(request)
else:
token = match[1]
elif 'authorization_token' in request.query:
token = request.query['authorization_token']
else:
setattr(request, 'user', AnonymousUser())
return await handler(request)
# Try select user auth record from Tarantool by token index
res = await app['tnt'].select('auth', [token, ])
cached = res.body
if not cached:
raise exceptions.AuthenticationFailed()
# Build basic user data and bind it to User instance
record = cached[0]
user = User()
user.bind_cached_tarantool(record)
# Add User to request
setattr(request, 'user', user)
return await handler(request)
return middleware_handler
And for data extraction from request:
DATA_METHODS = [METH_POST, METH_PUT, METH_PATCH, METH_DELETE]
JSON_CONTENT = ['application/json', ]
XML_CONTENT = ['application/xml', ]
FORM_CONTENT = ['application/x-www-form-urlencoded', 'multipart/form-data']
async def request_data_handler(app, handler):
"""
Request .data middleware
Try extract POST data or application/json from request body
"""
async def middleware_handler(request):
data = None
if request.method in DATA_METHODS:
if request.has_body:
if request.content_type in JSON_CONTENT:
# If request has body - try to decode it to JSON
try:
data = await request.json()
except JSONDecodeError:
raise exceptions.ParseError()
elif request.content_type in XML_CONTENT:
if request.charset is not None:
encoding = request.charset
else:
encoding = api_settings.DEFAULT_CHARSET
parser = etree.XMLParser(encoding=encoding)
try:
text = await request.text()
tree = etree.XML(text, parser=parser)
except (etree.ParseError, ValueError) as exc:
raise exceptions.ParseError(
detail='XML parse error - %s' % str(exc))
data = tree
elif request.content_type in FORM_CONTENT:
data = await request.post()
if data is None:
# If not catch any data create empty MultiDictProxy
data = MultiDictProxy(MultiDict())
# Attach extracted data to request
setattr(request, 'data', data)
return await handler(request)
return middleware_handler
Create few serializers:
class UserRegisterSerializer(s.Serializer):
"""Register new user"""
email = s.EmailField(max_length=256)
password = s.CharField(min_length=8, max_length=64)
first_name = s.CharField(min_length=2, max_length=64)
middle_name = s.CharField(default='', min_length=2, max_length=64,
required=False, allow_blank=True)
last_name = s.CharField(min_length=2, max_length=64)
phone = s.CharField(max_length=32, required=False,
allow_blank=True, default='')
async def register_user(self, app):
user = User()
data = self.validated_data
try:
await user.register_user(data, app)
except Error as e:
resolve_db_exception(e, self)
return user
And few ViewSets. It may be nested in bindings['custom']
class UserViewSet(BaseViewSet):
name = 'user'
lookup_url_kwarg = '{user_id:[0-9a-f]{32}}'
permission_classes = [AllowAny, ]
bindings = {
'list': {
'retrieve': 'get',
'update': 'put'
},
'custom': {
'list': {
'set_status': 'post',
'create_new_sip_password': 'post',
'get_registration_domain': 'get',
'report': UserReportViewSet
}
}
}
#staticmethod
async def resolve_sip_host(data, user, app):
sip_host = await resolve_registration_switch(user, app)
data.update({'sip_host': sip_host})
async def retrieve(self, request):
user = User()
await user.load_from_db(request.match_info['user_id'], request.app)
serializer = user_ser.UserProfileSerializer(instance=user)
data = serializer.data
await self.resolve_sip_host(data, user, request.app)
return Response(data=data)
#atomic
#set_permissions([AuthenticatedOnly, IsCompanyMember, CompanyIsEnabled])
async def update(self, request):
serializer = user_ser.UserProfileSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
user = await serializer.update_user(user_id=request.user.id,
app=request.app)
serializer = user_ser.UserProfileSerializer(instance=user)
data = serializer.data
await self.resolve_sip_host(data, user, request.app)
return Response(data=data)
Register Viewsets and run Application:
def setup_routes(app: APIApplication):
"""Add app routes here"""
# Auth API
app.router.register_viewset('/auth', auth_vws.AuthViewSet())
# User API
app.router.register_viewset('/user', user_vws.UserViewSet())
# Root redirection to Swagger
redirect = app.router.add_resource('/', name='home_redirect')
redirect.add_route('*', swagger_redirect)
def create_api_app():
sentry = get_sentry_middleware(settings.SENTRY_CONNECT_STRING, settings.SENTRY_ENVIRONMENT)
middlewares = [sentry, token_authentication, request_data_handler]
api_app = APIApplication(name='api', middlewares=middlewares,
client_max_size=10*(1024**2))
api_app.on_startup.append(startup.startup_api)
api_app.on_shutdown.append(startup.shutdown_api)
api_app.on_cleanup.append(startup.cleanup_api)
setup_routes(api_app)
if __name__ == '__main__':
app = create_api_app()
web.run_app(app)

Tornado testing async requests

I need an advice regards testing tornado app. For now I just playing with demo chat application, but it looks like real-life problem.
In the handler I have:
class MessageUpdatesHandler(BaseHandler):
#tornado.web.authenticated
#tornado.web.asynchronous
def post(self):
cursor = self.get_argument("cursor", None)
global_message_buffer.wait_for_messages(self.on_new_messages,
cursor=cursor)
def on_new_messages(self, messages):
# Closed client connection
if self.request.connection.stream.closed():
return
self.finish(dict(messages=messages))
class MessageBuffer(object):
def __init__(self):
....
def wait_for_messages(self, callback, cursor=None):
if cursor:
new_count = 0
for msg in reversed(self.cache):
if msg["id"] == cursor:
break
new_count += 1
if new_count:
callback(self.cache[-new_count:])
return
self.waiters.add(callback)
def cancel_wait(self, callback):
.....
def new_messages(self, messages):
logging.info("Sending new message to %r listeners", len(self.waiters))
for callback in self.waiters:
try:
callback(messages)
except:
logging.error("Error in waiter callback", exc_info=True)
self.waiters = set()
self.cache.extend(messages)
if len(self.cache) > self.cache_size:
self.cache = self.cache[-self.cache_size:]
As I metioned full source code is in torndado demos
In my test I have:
#wsgi_safe
class MessageUpdatesHandlerTest(LoginedUserHanldersTest):
Handler = MessageUpdatesHandler
def test_add_message(self):
from chatdemo import global_message_buffer
kwargs = dict(
method="POST",
body='',
)
future = self.http_client.fetch(self.get_url('/'), callback=self.stop, **kwargs)
message = {
"id": '123',
"from": "first_name",
"body": "hello",
"html": "html"
}
global_message_buffer.new_messages([message])
response = self.wait()
self.assertEqual(response.code, 200)
self.mox.VerifyAll()
What happens:
It creates a future object
It sends a hello message, in this moment no waiter is registered
in MessageBuffer so callback is not called
In wait starts IoLoop and makes, a post fetch and waiter becomes
registered in MessageBuffer
Callback is never called and my response remains empty, so
everything fails with
AssertionError: Async operation timed out
after 5 seconds
What I want it to do:
On post register itself as a waiter
Receive some messages
Return to me a 200 response
Thank you for your help

Categories