convert synchronous websocket callbacks to async in python using asyncio - python
need you assistance to convert asynchronous websocket callbacks to async using asyncio in python 3.6.
Some brokers in India provide a websocket for tick data to their users and users can further use this tick data, one such brokers sample code(from their git repo) is pasted here
`from smartapi import SmartWebSocket
# feed_token=092017047
FEED_TOKEN="YOUR_FEED_TOKEN"
CLIENT_CODE="YOUR_CLIENT_CODE"
# token="mcx_fo|224395"
token="EXCHANGE|TOKEN_SYMBOL" #SAMPLE: nse_cm|2885&nse_cm|1594&nse_cm|11536&nse_cm|3045
# token="mcx_fo|226745&mcx_fo|220822&mcx_fo|227182&mcx_fo|221599"
task="mw" # mw|sfi|dp
ss = SmartWebSocket(FEED_TOKEN, CLIENT_CODE)
def on_message(ws, message):
print("Ticks: {}".format(message))
def on_open(ws):
print("on open")
ss.subscribe(task,token)
def on_error(ws, error):
print(error)
def on_close(ws):
print("Close")
# Assign the callbacks.
ss._on_open = on_open
ss._on_message = on_message
ss._on_error = on_error
ss._on_close = on_close
ss.connect()
so what I think the broker is doing is that it provides messages/ticks aka events that probably trigger a callback function on every received messages/tick in a synchronous way.
I however want to alter the above code so that it can work in an async manner, I have tried to write in a manner how binance provides such as async with websocketName('currencyPair')as ts: but this doesn't seem to work with my broker.
Appreciate if you can share some ideas/code/insights around this.
Thank you for your time.
Edit 1:
Thank you for your response Dirn,
Yes, I've tried this
from smartapi import SmartConnect
from smartapi import SmartWebSocket
import logging
import asyncio
CLIENT_CODE = Code_provided_by_broker
PASSWORD = My_Password
obj=SmartConnect(api_key=My_api_key)
data = obj.generateSession(CLIENT_CODE, PASSWORD)
refreshToken = data['data']['refreshToken']
print("refreshToken", refreshToken)
feedToken=obj.getfeedToken()
print("feedToken", feedToken)
FEED_TOKEN= feedToken
token="nse_cm|3499"
# token="nse_fo|53595"
task = 'mw'
async def on_message(ws, message):
print("Ticks: {}".format(message))
async def chk_msg():
# ss = SmartWebSocket(FEED_TOKEN, CLIENT_CODE)
# ss.subscribe(task,token)
# ss._on_message = on_message
# ss.connect()
# print("Ticks: {}".format(message))
async with SmartWebSocket(FEED_TOKEN, CLIENT_CODE) as ss: #throws an error: AttributeError: __aexit__
ss.subscribe(task,token)
ss.connect()
while True:
ss._on_message = on_message # not sure how to await a non-async method _on_message, also there is no use of async module in the modules(imported at start of the script) SmartWebSocket and SmartConnect
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(chk_msg())
here is the smartConnect.py
from six.moves.urllib.parse import urljoin
import sys
import csv
import json
import dateutil.parser
import hashlib
import logging
import datetime
import smartapi.smartExceptions as ex
import requests
from requests import get
import re, uuid
import socket
import platform
from smartapi.version import __version__, __title__
log = logging.getLogger(__name__)
#user_sys=platform.system()
#print("the system",user_sys)
class SmartConnect(object):
#_rootUrl = "https://openapisuat.angelbroking.com"
_rootUrl="https://apiconnect.angelbroking.com" #prod endpoint
#_login_url ="https://smartapi.angelbroking.com/login"
_login_url="https://smartapi.angelbroking.com/publisher-login" #prod endpoint
_default_timeout = 7 # In seconds
_routes = {
"api.login":"/rest/auth/angelbroking/user/v1/loginByPassword",
"api.logout":"/rest/secure/angelbroking/user/v1/logout",
"api.token": "/rest/auth/angelbroking/jwt/v1/generateTokens",
"api.refresh": "/rest/auth/angelbroking/jwt/v1/generateTokens",
"api.user.profile": "/rest/secure/angelbroking/user/v1/getProfile",
"api.order.place": "/rest/secure/angelbroking/order/v1/placeOrder",
"api.order.modify": "/rest/secure/angelbroking/order/v1/modifyOrder",
"api.order.cancel": "/rest/secure/angelbroking/order/v1/cancelOrder",
"api.order.book":"/rest/secure/angelbroking/order/v1/getOrderBook",
"api.ltp.data": "/rest/secure/angelbroking/order/v1/getLtpData",
"api.trade.book": "/rest/secure/angelbroking/order/v1/getTradeBook",
"api.rms.limit": "/rest/secure/angelbroking/user/v1/getRMS",
"api.holding": "/rest/secure/angelbroking/portfolio/v1/getHolding",
"api.position": "/rest/secure/angelbroking/order/v1/getPosition",
"api.convert.position": "/rest/secure/angelbroking/order/v1/convertPosition",
"api.gtt.create":"/gtt-service/rest/secure/angelbroking/gtt/v1/createRule",
"api.gtt.modify":"/gtt-service/rest/secure/angelbroking/gtt/v1/modifyRule",
"api.gtt.cancel":"/gtt-service/rest/secure/angelbroking/gtt/v1/cancelRule",
"api.gtt.details":"/rest/secure/angelbroking/gtt/v1/ruleDetails",
"api.gtt.list":"/rest/secure/angelbroking/gtt/v1/ruleList",
"api.candle.data":"/rest/secure/angelbroking/historical/v1/getCandleData"
}
try:
clientPublicIp= " " + get('https://api.ipify.org').text
if " " in clientPublicIp:
clientPublicIp=clientPublicIp.replace(" ","")
hostname = socket.gethostname()
clientLocalIp=socket.gethostbyname(hostname)
except Exception as e:
print("Exception while retriving IP Address,using local host IP address",e)
finally:
clientPublicIp="106.193.147.98"
clientLocalIp="127.0.0.1"
clientMacAddress=':'.join(re.findall('..', '%012x' % uuid.getnode()))
accept = "application/json"
userType = "USER"
sourceID = "WEB"
def __init__(self, api_key=None, access_token=None, refresh_token=None,feed_token=None, userId=None, root=None, debug=False, timeout=None, proxies=None, pool=None, disable_ssl=False,accept=None,userType=None,sourceID=None,Authorization=None,clientPublicIP=None,clientMacAddress=None,clientLocalIP=None,privateKey=None):
self.debug = debug
self.api_key = api_key
self.session_expiry_hook = None
self.disable_ssl = disable_ssl
self.access_token = access_token
self.refresh_token = refresh_token
self.feed_token = feed_token
self.userId = userId
self.proxies = proxies if proxies else {}
self.root = root or self._rootUrl
self.timeout = timeout or self._default_timeout
self.Authorization= None
self.clientLocalIP=self.clientLocalIp
self.clientPublicIP=self.clientPublicIp
self.clientMacAddress=self.clientMacAddress
self.privateKey=api_key
self.accept=self.accept
self.userType=self.userType
self.sourceID=self.sourceID
if pool:
self.reqsession = requests.Session()
reqadapter = requests.adapters.HTTPAdapter(**pool)
self.reqsession.mount("https://", reqadapter)
print("in pool")
else:
self.reqsession = requests
# disable requests SSL warning
requests.packages.urllib3.disable_warnings()
def requestHeaders(self):
return{
"Content-type":self.accept,
"X-ClientLocalIP": self.clientLocalIp,
"X-ClientPublicIP": self.clientPublicIp,
"X-MACAddress": self.clientMacAddress,
"Accept": self.accept,
"X-PrivateKey": self.privateKey,
"X-UserType": self.userType,
"X-SourceID": self.sourceID
}
def setSessionExpiryHook(self, method):
if not callable(method):
raise TypeError("Invalid input type. Only functions are accepted.")
self.session_expiry_hook = method
def getUserId():
return userId
def setUserId(self,id):
self.userId=id
def setAccessToken(self, access_token):
self.access_token = access_token
def setRefreshToken(self, refresh_token):
self.refresh_token = refresh_token
def setFeedToken(self,feedToken):
self.feed_token=feedToken
def getfeedToken(self):
return self.feed_token
def login_url(self):
"""Get the remote login url to which a user should be redirected to initiate the login flow."""
return "%s?api_key=%s" % (self._login_url, self.api_key)
def _request(self, route, method, parameters=None):
"""Make an HTTP request."""
params = parameters.copy() if parameters else {}
uri =self._routes[route].format(**params)
url = urljoin(self.root, uri)
# Custom headers
headers = self.requestHeaders()
if self.access_token:
# set authorization header
auth_header = self.access_token
headers["Authorization"] = "Bearer {}".format(auth_header)
if self.debug:
log.debug("Request: {method} {url} {params} {headers}".format(method=method, url=url, params=params, headers=headers))
try:
r = requests.request(method,
url,
data=json.dumps(params) if method in ["POST", "PUT"] else None,
params=json.dumps(params) if method in ["GET", "DELETE"] else None,
headers=headers,
verify=not self.disable_ssl,
allow_redirects=True,
timeout=self.timeout,
proxies=self.proxies)
except Exception as e:
raise e
if self.debug:
log.debug("Response: {code} {content}".format(code=r.status_code, content=r.content))
# Validate the content type.
if "json" in headers["Content-type"]:
try:
data = json.loads(r.content.decode("utf8"))
except ValueError:
raise ex.DataException("Couldn't parse the JSON response received from the server: {content}".format(
content=r.content))
# api error
if data.get("error_type"):
# Call session hook if its registered and TokenException is raised
if self.session_expiry_hook and r.status_code == 403 and data["error_type"] == "TokenException":
self.session_expiry_hook()
# native errors
exp = getattr(ex, data["error_type"], ex.GeneralException)
raise exp(data["message"], code=r.status_code)
return data
elif "csv" in headers["Content-type"]:
return r.content
else:
raise ex.DataException("Unknown Content-type ({content_type}) with response: ({content})".format(
content_type=headers["Content-type"],
content=r.content))
def _deleteRequest(self, route, params=None):
"""Alias for sending a DELETE request."""
return self._request(route, "DELETE", params)
def _putRequest(self, route, params=None):
"""Alias for sending a PUT request."""
return self._request(route, "PUT", params)
def _postRequest(self, route, params=None):
"""Alias for sending a POST request."""
return self._request(route, "POST", params)
def _getRequest(self, route, params=None):
"""Alias for sending a GET request."""
return self._request(route, "GET", params)
def generateSession(self,clientCode,password):
params={"clientcode":clientCode,"password":password}
loginResultObject=self._postRequest("api.login",params)
if loginResultObject['status']==True:
jwtToken=loginResultObject['data']['jwtToken']
self.setAccessToken(jwtToken)
refreshToken=loginResultObject['data']['refreshToken']
feedToken=loginResultObject['data']['feedToken']
self.setRefreshToken(refreshToken)
self.setFeedToken(feedToken)
user=self.getProfile(refreshToken)
id=user['data']['clientcode']
#id='D88311'
self.setUserId(id)
user['data']['jwtToken']="Bearer "+jwtToken
user['data']['refreshToken']=refreshToken
return user
else:
return loginResultObject
def terminateSession(self,clientCode):
logoutResponseObject=self._postRequest("api.logout",{"clientcode":clientCode})
return logoutResponseObject
def generateToken(self,refresh_token):
response=self._postRequest('api.token',{"refreshToken":refresh_token})
jwtToken=response['data']['jwtToken']
feedToken=response['data']['feedToken']
self.setFeedToken(feedToken)
self.setAccessToken(jwtToken)
return response
def renewAccessToken(self):
response =self._postRequest('api.refresh', {
"jwtToken": self.access_token,
"refreshToken": self.refresh_token,
})
tokenSet={}
if "jwtToken" in response:
tokenSet['jwtToken']=response['data']['jwtToken']
tokenSet['clientcode']=self. userId
tokenSet['refreshToken']=response['data']["refreshToken"]
return tokenSet
def getProfile(self,refreshToken):
user=self._getRequest("api.user.profile",{"refreshToken":refreshToken})
return user
def placeOrder(self,orderparams):
params=orderparams
for k in list(params.keys()):
if params[k] is None :
del(params[k])
orderResponse= self._postRequest("api.order.place", params)['data']['orderid']
return orderResponse
def modifyOrder(self,orderparams):
params = orderparams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
orderResponse= self._postRequest("api.order.modify", params)
return orderResponse
def cancelOrder(self, order_id,variety):
orderResponse= self._postRequest("api.order.cancel", {"variety": variety,"orderid": order_id})
return orderResponse
def ltpData(self,exchange,tradingsymbol,symboltoken):
params={
"exchange":exchange,
"tradingsymbol":tradingsymbol,
"symboltoken":symboltoken
}
ltpDataResponse= self._postRequest("api.ltp.data",params)
return ltpDataResponse
def orderBook(self):
orderBookResponse=self._getRequest("api.order.book")
return orderBookResponse
def tradeBook(self):
tradeBookResponse=self._getRequest("api.trade.book")
return tradeBookResponse
def rmsLimit(self):
rmsLimitResponse= self._getRequest("api.rms.limit")
return rmsLimitResponse
def position(self):
positionResponse= self._getRequest("api.position")
return positionResponse
def holding(self):
holdingResponse= self._getRequest("api.holding")
return holdingResponse
def convertPosition(self,positionParams):
params=positionParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
convertPositionResponse= self._postRequest("api.convert.position",params)
return convertPositionResponse
def gttCreateRule(self,createRuleParams):
params=createRuleParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
createGttRuleResponse=self._postRequest("api.gtt.create",params)
#print(createGttRuleResponse)
return createGttRuleResponse['data']['id']
def gttModifyRule(self,modifyRuleParams):
params=modifyRuleParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
modifyGttRuleResponse=self._postRequest("api.gtt.modify",params)
#print(modifyGttRuleResponse)
return modifyGttRuleResponse['data']['id']
def gttCancelRule(self,gttCancelParams):
params=gttCancelParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
#print(params)
cancelGttRuleResponse=self._postRequest("api.gtt.cancel",params)
#print(cancelGttRuleResponse)
return cancelGttRuleResponse
def gttDetails(self,id):
params={
"id":id
}
gttDetailsResponse=self._postRequest("api.gtt.details",params)
return gttDetailsResponse
def gttLists(self,status,page,count):
if type(status)== list:
params={
"status":status,
"page":page,
"count":count
}
gttListResponse=self._postRequest("api.gtt.list",params)
#print(gttListResponse)
return gttListResponse
else:
message="The status param is entered as" +str(type(status))+". Please enter status param as a list i.e., status=['CANCELLED']"
return message
def getCandleData(self,historicDataParams):
params=historicDataParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
getCandleDataResponse=self._postRequest("api.candle.data",historicDataParams)
return getCandleDataResponse
def _user_agent(self):
return (__title__ + "-python/").capitalize() + __version__
the smartApiWebsocket.py
import websocket
import six
import base64
import zlib
import datetime
import time
import json
import threading
import ssl
class SmartWebSocket(object):
ROOT_URI='wss://wsfeeds.angelbroking.com/NestHtml5Mobile/socket/stream'
HB_INTERVAL=30
HB_THREAD_FLAG=False
WS_RECONNECT_FLAG=False
feed_token=None
client_code=None
ws=None
task_dict = {}
def __init__(self, FEED_TOKEN, CLIENT_CODE):
self.root = self.ROOT_URI
self.feed_token = FEED_TOKEN
self.client_code = CLIENT_CODE
if self.client_code == None or self.feed_token == None:
return "client_code or feed_token or task is missing"
def _subscribe_on_open(self):
request = {"task": "cn", "channel": "NONLM", "token": self.feed_token, "user": self.client_code,
"acctid": self.client_code}
print(request)
self.ws.send(
six.b(json.dumps(request))
)
thread = threading.Thread(target=self.run, args=())
thread.daemon = True
thread.start()
def run(self):
while True:
# More statements comes here
if self.HB_THREAD_FLAG:
break
print(datetime.datetime.now().__str__() + ' : Start task in the background')
self.heartBeat()
time.sleep(self.HB_INTERVAL)
def subscribe(self, task, token):
# print(self.task_dict)
self.task_dict.update([(task,token),])
# print(self.task_dict)
if task in ("mw", "sfi", "dp"):
strwatchlistscrips = token # dynamic call
try:
request = {"task": task, "channel": strwatchlistscrips, "token": self.feed_token,
"user": self.client_code, "acctid": self.client_code}
self.ws.send(
six.b(json.dumps(request))
)
return True
except Exception as e:
self._close(reason="Error while request sending: {}".format(str(e)))
raise
else:
print("The task entered is invalid, Please enter correct task(mw,sfi,dp) ")
def resubscribe(self):
for task, marketwatch in self.task_dict.items():
print(task, '->', marketwatch)
try:
request = {"task": task, "channel": marketwatch, "token": self.feed_token,
"user": self.client_code, "acctid": self.client_code}
self.ws.send(
six.b(json.dumps(request))
)
return True
except Exception as e:
self._close(reason="Error while request sending: {}".format(str(e)))
raise
def heartBeat(self):
try:
request = {"task": "hb", "channel": "", "token": self.feed_token, "user": self.client_code,
"acctid": self.client_code}
print(request)
self.ws.send(
six.b(json.dumps(request))
)
except:
print("HeartBeat Sending Failed")
# time.sleep(60)
def _parse_text_message(self, message):
"""Parse text message."""
data = base64.b64decode(message)
try:
data = bytes((zlib.decompress(data)).decode("utf-8"), 'utf-8')
data = json.loads(data.decode('utf8').replace("'", '"'))
data = json.loads(json.dumps(data, indent=4, sort_keys=True))
except ValueError:
return
# return data
if data:
self._on_message(self.ws,data)
def connect(self):
# websocket.enableTrace(True)
self.ws = websocket.WebSocketApp(self.ROOT_URI,
on_message=self.__on_message,
on_close=self.__on_close,
on_open=self.__on_open,
on_error=self.__on_error)
self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def __on_message(self, ws, message):
self._parse_text_message(message)
# print(msg)
def __on_open(self, ws):
print("__on_open################")
self.HB_THREAD_FLAG = False
self._subscribe_on_open()
if self.WS_RECONNECT_FLAG:
self.WS_RECONNECT_FLAG = False
self.resubscribe()
else:
self._on_open(ws)
def __on_close(self, ws):
self.HB_THREAD_FLAG = True
print("__on_close################")
self._on_close(ws)
def __on_error(self, ws, error):
if ( "timed" in str(error) ) or ( "Connection is already closed" in str(error) ) or ( "Connection to remote host was lost" in str(error) ):
self.WS_RECONNECT_FLAG = True
self.HB_THREAD_FLAG = True
if (ws is not None):
ws.close()
ws.on_message = None
ws.on_open = None
ws.close = None
# print (' deleting ws')
del ws
self.connect()
else:
print ('Error info: %s' %(error))
self._on_error(ws, error)
def _on_message(self, ws, message):
pass
def _on_open(self, ws):
pass
def _on_close(self, ws):
pass
def _on_error(self, ws, error):
pass
the provider has also imported the websocket module in the above smartApiWebsocket file and has also provided a custom webSocket.py
here is the link webSocket.py
enter link description here
Related
self.scope['user'] always returns AnonymousUser in websocket
I have researched similar questions but can't find an answer that works for me. I would like to get the username or user_id from a session when a user connects to a websocket. This is what I have in consumers.py: class PracticeConsumer(AsyncConsumer): async def websocket_connect(self, event): print('session data', self.scope['user']) await self.send({"type": "websocket.accept", }) ... #database_sync_to_async def get_user(self, user_id): try: return User.objects.get(username=user_id).pk except User.DoesNotExist: return AnonymousUser() This is my asgi.py: """ ASGI config for restapi project. It exposes the ASGI callable as a module-level variable named application. For more information on this file, see https://docs.djangoproject.com/en/4.0/howto/deployment/asgi/ os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'signup.settings') application = ProtocolTypeRouter({ "http": get_asgi_application(), "websocket": AllowedHostsOriginValidator( AuthMiddlewareStack( URLRouter(websocket_urlpatterns) ) ) }) and the user login function + token sent when user logs in: class CustomTokenObtainPairSerializer(TokenObtainPairSerializer): def validate(self, attrs): authenticate_kwargs = { self.username_field: attrs[self.username_field], "password": attrs["password"], } try: authenticate_kwargs["request"] = self.context["request"] except KeyError: pass user = authenticate(**authenticate_kwargs) if not user: return { 'user': 'Username or password is incorrect', } token = RefreshToken.for_user(user) # customizing token payload token['username'] = user.username token['first_name'] = user.first_name token['last_name'] = user.last_name token['country'] = user.profile.country token['city'] = user.profile.city token['bio'] = user.profile.bio token['photo'] = json.dumps(str(user.profile.profile_pic)) user_logged_in.send(sender=user.__class__, request=self.context['request'], user=user) if not api_settings.USER_AUTHENTICATION_RULE(user): raise exceptions.AuthenticationFailed( self.error_messages["no_active_account"], "no_active_account", ) return { 'refresh': str(token), 'access': str(token.access_token), } Whenever I print out self.scope['user'] upon connecting, I get AnonymousUser UPDATE I tried writing some custom middleware to handle simple JWT authentication: #database_sync_to_async def get_user(validated_token): try: user = get_user_model().objects.get(id=validated_token["user_id"]) print(f"{user}") return user except User.DoesNotExist: return AnonymousUser() class JwtAuthMiddleware(BaseMiddleware): def __init__(self, inner): self.inner = inner async def __call__(self, scope, receive, send): close_old_connections() token = parse_qs(scope["query_string"].decode("utf8"))["token"][0] try: UntypedToken(token) except (InvalidToken, TokenError) as e: print(e) return None else: decoded_data = jwt_decode(token, settings.SECRET_KEY, algorithms=["HS256"]) print(decoded_data) scope["user"] = await get_user(validated_token=decoded_data) return await super().__call__(scope, receive, send) def JwtAuthMiddlewareStack(inner): return JwtAuthMiddleware(AuthMiddlewareStack(inner)) However, this gives me this error: File "C:\Users\15512\Desktop\django-project\peerplatform\signup\middleware.py", line 37, in __call__ token = parse_qs(scope["query_string"].decode("utf8"))["token"][0] KeyError: 'token'
Writing unittest for package
I have created a package which uploads a data to some storage using Azure AD access token, now I want to write test cases for the code, as I'm not aware of writing test cases have tried few. Can anyone help here, below is the code for my package. __init__.py file import json import requests import sys from data import Data import datetime from singleton import singleton #singleton class CertifiedLogProvider: def __init__(self, client_id, client_secret): self.client_id=client_id self.client_secret= client_secret self.grant_type="client_credentials" self.resource="*****" self.url="https://login.microsoftonline.com/azureford.onmicrosoft.com/oauth2/token" self.api_url="http://example.com" self.get_access_token() def get_access_token(self)-> None: data={'client_id':self.client_id,'client_secret':self.client_secret, 'grant_type':self.grant_type,'resource':self.resource} response = requests.post(self.url, data=data) if response.status_code == 200: self.token_dict=response.content self.access_token = response.json()["access_token"] else: print(f"An Error occurred: \n {response.text}") def is_expired(self) -> bool: token_dict=json.loads(self.token_dict.decode('utf-8')) if int(datetime.datetime.now().timestamp()) > int(token_dict['expires_on']): return True else: return False def send_clp_data(self,payload:dict): obj=Data() data=obj.data data['event_metric_body']=payload if self.is_expired() is True: self.get_access_token() headers={"Authorization": "Bearer {}".format(self.access_token), "Content-Type": "application/json",} response = requests.post(self.api_url,data=json.dumps(data), headers=headers) if response.status_code == 201: print('Data uploaded successfully') else: print(f"An Error occurred: \n {response.text}") singleton.py def singleton(class_): instances = {} def getinstance(*args, **kwargs): if class_ not in instances: instances[class_] = class_(*args, **kwargs) return instances[class_] return getinstance data.py Contains data which is static test.py import json import unittest from unittest import TestCase from unittest.mock import patch import requests from unittest.mock import MagicMock from __init__ import CertifiedLogProvider import pytest class MyTestCase(TestCase): def test_can_construct_clp_instance(self): object= CertifiedLogProvider(1,2) #patch('requests.post') def test_send_clp_data(self, mock_post): info={"test1":"value1", "test2": "value2"} response = requests.post("www.clp_api.com", data=json.dumps(info), headers={'Content-Type': 'application/json'}) mock_post.assert_called_with("www.clp_api.com", data=json.dumps(info), headers={'Content-Type': 'application/json'}) if __name__ == '__main__': unittest.main() How can we test boolean method and method containing requests?
Testing is_expired() becomes easier if you create a Token class: class Token: def __init__(self, token_dict: dict) -> None: self.token_dict = token_dict def __str__(self) -> str: return self.token_dict['access_token'] def is_expired(self, now:DateTime=None) -> bool: if now is None: now = datetime.datetime.now() return int(now.timestamp()) > int(self.token_dict['expires_on']) #singleton class CertifiedLogProvider: def __init__(self, client_id, client_secret): self.client_id=client_id self.client_secret= client_secret self.grant_type="client_credentials" self.resource="*****" self.url="https://login.microsoftonline.com/azureford.onmicrosoft.com/oauth2/token" self.api_url="http://example.com" self.access_token = None def get_access_token(self) -> Token: if self.access_token is None or self.access_token.is_expired(): self.access_token = self.fetch_access_token() return self.access_token def fetch_access_token(self) -> Token: data={ 'client_id': self.client_id, 'client_secret': self.client_secret, 'grant_type': self.grant_type, 'resource': self.resource, } response = requests.post(self.url, data=data) if response.status_code != 200: raise Exception(f"An Error occurred: \n {response.text}") return Token(response.json()) def send_clp_data(self, payload: dict): obj=Data() data=obj.data data['event_metric_body'] = payload headers={ "Authorization": f"Bearer {self.get_access_token()}", "Content-Type": "application/json", } response = requests.post(self.api_url, data=json.dumps(data), headers=headers) if response.status_code != 201: raise Exception(f"An Error occurred: \n {response.text}") print('Data uploaded successfully') test_token.py from datetime import datetime import unittest from certified_log_provider import Token class TestToken(unittest.TestCase): def setUp(self) -> None: self.token = Token({ 'expires_on': datetime.strptime("2022-02-18", "%Y-%m-%d").timestamp() }) def test_token_is_not_expired_day_before(self): self.assertFalse(self.token.is_expired(datetime.strptime("2022-02-17", "%Y-%m-%d"))) def test_token_is_expired_day_after(self): self.assertTrue(self.token.is_expired(datetime.strptime("2022-02-19", "%Y-%m-%d")))
requests.session() object not recognized in another class
I am trying to pass my session object from one class to another. But I am not sure whats happening. class CreateSession: def __init__(self, user, pwd, url="http://url_to_hit"): self.url = url self.user = user self.pwd = pwd def get_session(self): sess = requests.Session() r = sess.get(self.url + "/", auth=(self.user, self.pwd)) print(r.content) return sess class TestGet(CreateSession): def get_response(self): s = self.get_session() print(s) data = s.get(self.url + '/some-get') print(data.status_code) print(data) if __name__ == "__main__": TestGet(user='user', pwd='pwd').get_response() I am getting 401 for get_response(). Not able to understand this.
What's a 401? The response you're getting means that you're unauthorised to access the resource. A session is used in order to persist headers and other prerequisites throughout requests, why are you creating the session every time rather than storing it in a variable? As is, the session should work the only issue is that you're trying to call a resource that you don't have access to. - You're not passing the url parameter either in the initialisation. Example of how you can effectively use Session: from requests import Session from requests.exceptions import HTTPError class TestGet: __session = None __username = None __password = None def __init__(self, username, password): self.__username = username self.__password = password #property def session(self): if self.__session is None: self.__session = Session() self.__session.auth = (self.__user, self.__pwd) return self.__session #session.setter def session(self, value): raise AttributeError('Setting \'session\' attribute is prohibited.') def get_response(self, url): try: response = self.session.get(url) # raises if the status code is an error - 4xx, 5xx response.raise_for_status() return response except HTTPError as e: # you received an http error .. handle it here (e contains the request and response) pass test_get = TestGet('my_user', 'my_pass') first_response = test_get.get_response('http://your-website-with-basic-auth.com') second_response = test_get.get_response('http://another-url.com') my_session = test_get.session my_session.get('http://url.com')
Error with asynchronous request in DRF
I need to fulfill a request to two services. The code looks like this: async def post1(data): response = await aiohttp.request('post', 'http://', json=data) json_response = await response.json() response.close() return json_response async def get2(): response = await aiohttp.request('get', 'http://') json_response = await response.json() response.close() return json_response async def asynchronous(parameters): task1 = post1(parameters['data']) task2 = get2() result_list = [] for body in await asyncio.gather(task1, task2): result_list.append(body) return result_list If I run the code locally, it's OK. The code looks like this: if __name__ == "__main__": ioloop = asyncio.get_event_loop() parameters = {'data': data} result = ioloop.run_until_complete(asynchronous(parameters)) ioloop.close() print(result) I get the right result. But if I try to execute code from the DRF method, an error occurs: TypeError: object _SessionRequestContextManager can't be used in 'await' expression example code that I run: ..... class MyViewSet(GenericAPIView): def post(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) ...... ioloop = asyncio.get_event_loop() result = ioloop.run_until_complete(asynchronous(serializer.data)) # <<<<< error here ioloop.close() ...... return Response(serializer.data, status=status.HTTP_201_CREATED) Please, tell me what the problem may be?
The object returned by aiohttp.request cannot be awaited, it must be used as an async context manager. This code: response = await aiohttp.request('post', 'http://', json=data) json_response = await response.json() response.close() needs to changed to something like: async with aiohttp.request('post', 'http://', json=data) as response: json_response = await response.json() See the documentation for more usage examples. Perhaps you have a different aiohttp version on the server where you run DRF, which is why it works locally and fails under DRF.
Try it https://github.com/Skorpyon/aiorestframework Create some middleware for authentication: import re from xml.etree import ElementTree as etree from json.decoder import JSONDecodeError from multidict import MultiDict, MultiDictProxy from aiohttp import web from aiohttp.hdrs import ( METH_POST, METH_PUT, METH_PATCH, METH_DELETE ) from aiorestframework import exceptions from aiorestframework omport serializers from aiorestframework import Response from aiorestframework.views import BaseViewSet from aiorestframework.permissions import set_permissions, AllowAny from aiorestframework.app import APIApplication from my_project.settings import Settings # Generated by aiohttp-devtools TOKEN_RE = re.compile(r'^\s*BEARER\s{,3}(\S{64})\s*$') async def token_authentication(app, handler): """ Authorization middleware Catching Authorization: BEARER <token> from request headers Found user in Tarantool by token and bind User or AnonymousUser to request """ async def middleware_handler(request): # Check that `Authorization` header exists if 'authorization' in request.headers: authorization = request.headers['authorization'] # Check matches in header value match = TOKEN_RE.match(authorization) if not match: setattr(request, 'user', AnonymousUser()) return await handler(request) else: token = match[1] elif 'authorization_token' in request.query: token = request.query['authorization_token'] else: setattr(request, 'user', AnonymousUser()) return await handler(request) # Try select user auth record from Tarantool by token index res = await app['tnt'].select('auth', [token, ]) cached = res.body if not cached: raise exceptions.AuthenticationFailed() # Build basic user data and bind it to User instance record = cached[0] user = User() user.bind_cached_tarantool(record) # Add User to request setattr(request, 'user', user) return await handler(request) return middleware_handler And for data extraction from request: DATA_METHODS = [METH_POST, METH_PUT, METH_PATCH, METH_DELETE] JSON_CONTENT = ['application/json', ] XML_CONTENT = ['application/xml', ] FORM_CONTENT = ['application/x-www-form-urlencoded', 'multipart/form-data'] async def request_data_handler(app, handler): """ Request .data middleware Try extract POST data or application/json from request body """ async def middleware_handler(request): data = None if request.method in DATA_METHODS: if request.has_body: if request.content_type in JSON_CONTENT: # If request has body - try to decode it to JSON try: data = await request.json() except JSONDecodeError: raise exceptions.ParseError() elif request.content_type in XML_CONTENT: if request.charset is not None: encoding = request.charset else: encoding = api_settings.DEFAULT_CHARSET parser = etree.XMLParser(encoding=encoding) try: text = await request.text() tree = etree.XML(text, parser=parser) except (etree.ParseError, ValueError) as exc: raise exceptions.ParseError( detail='XML parse error - %s' % str(exc)) data = tree elif request.content_type in FORM_CONTENT: data = await request.post() if data is None: # If not catch any data create empty MultiDictProxy data = MultiDictProxy(MultiDict()) # Attach extracted data to request setattr(request, 'data', data) return await handler(request) return middleware_handler Create few serializers: class UserRegisterSerializer(s.Serializer): """Register new user""" email = s.EmailField(max_length=256) password = s.CharField(min_length=8, max_length=64) first_name = s.CharField(min_length=2, max_length=64) middle_name = s.CharField(default='', min_length=2, max_length=64, required=False, allow_blank=True) last_name = s.CharField(min_length=2, max_length=64) phone = s.CharField(max_length=32, required=False, allow_blank=True, default='') async def register_user(self, app): user = User() data = self.validated_data try: await user.register_user(data, app) except Error as e: resolve_db_exception(e, self) return user And few ViewSets. It may be nested in bindings['custom'] class UserViewSet(BaseViewSet): name = 'user' lookup_url_kwarg = '{user_id:[0-9a-f]{32}}' permission_classes = [AllowAny, ] bindings = { 'list': { 'retrieve': 'get', 'update': 'put' }, 'custom': { 'list': { 'set_status': 'post', 'create_new_sip_password': 'post', 'get_registration_domain': 'get', 'report': UserReportViewSet } } } #staticmethod async def resolve_sip_host(data, user, app): sip_host = await resolve_registration_switch(user, app) data.update({'sip_host': sip_host}) async def retrieve(self, request): user = User() await user.load_from_db(request.match_info['user_id'], request.app) serializer = user_ser.UserProfileSerializer(instance=user) data = serializer.data await self.resolve_sip_host(data, user, request.app) return Response(data=data) #atomic #set_permissions([AuthenticatedOnly, IsCompanyMember, CompanyIsEnabled]) async def update(self, request): serializer = user_ser.UserProfileSerializer(data=request.data) serializer.is_valid(raise_exception=True) user = await serializer.update_user(user_id=request.user.id, app=request.app) serializer = user_ser.UserProfileSerializer(instance=user) data = serializer.data await self.resolve_sip_host(data, user, request.app) return Response(data=data) Register Viewsets and run Application: def setup_routes(app: APIApplication): """Add app routes here""" # Auth API app.router.register_viewset('/auth', auth_vws.AuthViewSet()) # User API app.router.register_viewset('/user', user_vws.UserViewSet()) # Root redirection to Swagger redirect = app.router.add_resource('/', name='home_redirect') redirect.add_route('*', swagger_redirect) def create_api_app(): sentry = get_sentry_middleware(settings.SENTRY_CONNECT_STRING, settings.SENTRY_ENVIRONMENT) middlewares = [sentry, token_authentication, request_data_handler] api_app = APIApplication(name='api', middlewares=middlewares, client_max_size=10*(1024**2)) api_app.on_startup.append(startup.startup_api) api_app.on_shutdown.append(startup.shutdown_api) api_app.on_cleanup.append(startup.cleanup_api) setup_routes(api_app) if __name__ == '__main__': app = create_api_app() web.run_app(app)
Tornado testing async requests
I need an advice regards testing tornado app. For now I just playing with demo chat application, but it looks like real-life problem. In the handler I have: class MessageUpdatesHandler(BaseHandler): #tornado.web.authenticated #tornado.web.asynchronous def post(self): cursor = self.get_argument("cursor", None) global_message_buffer.wait_for_messages(self.on_new_messages, cursor=cursor) def on_new_messages(self, messages): # Closed client connection if self.request.connection.stream.closed(): return self.finish(dict(messages=messages)) class MessageBuffer(object): def __init__(self): .... def wait_for_messages(self, callback, cursor=None): if cursor: new_count = 0 for msg in reversed(self.cache): if msg["id"] == cursor: break new_count += 1 if new_count: callback(self.cache[-new_count:]) return self.waiters.add(callback) def cancel_wait(self, callback): ..... def new_messages(self, messages): logging.info("Sending new message to %r listeners", len(self.waiters)) for callback in self.waiters: try: callback(messages) except: logging.error("Error in waiter callback", exc_info=True) self.waiters = set() self.cache.extend(messages) if len(self.cache) > self.cache_size: self.cache = self.cache[-self.cache_size:] As I metioned full source code is in torndado demos In my test I have: #wsgi_safe class MessageUpdatesHandlerTest(LoginedUserHanldersTest): Handler = MessageUpdatesHandler def test_add_message(self): from chatdemo import global_message_buffer kwargs = dict( method="POST", body='', ) future = self.http_client.fetch(self.get_url('/'), callback=self.stop, **kwargs) message = { "id": '123', "from": "first_name", "body": "hello", "html": "html" } global_message_buffer.new_messages([message]) response = self.wait() self.assertEqual(response.code, 200) self.mox.VerifyAll() What happens: It creates a future object It sends a hello message, in this moment no waiter is registered in MessageBuffer so callback is not called In wait starts IoLoop and makes, a post fetch and waiter becomes registered in MessageBuffer Callback is never called and my response remains empty, so everything fails with AssertionError: Async operation timed out after 5 seconds What I want it to do: On post register itself as a waiter Receive some messages Return to me a 200 response Thank you for your help