So, I have created a Custom Middleware for my big FastAPI Application, which alters responses from all of my endpoints this way:
The response model is different for every API. However, my middleware adds metadata to all of these responses in a uniform manner. This is what the final response object looks like:
{
"data": <ANY RESPONSE MODEL THAT ALL THOSE ENDPOINTS ARE SENDING>,
"meta_data":
{
"meta_data_1": "meta_value_1",
"meta_data_2": "meta_value_2",
"meta_data_3": "meta_value_3",
}
}
So essentially, all original responses, are wrapped inside a data field, a new field of meta_data is added with all meta_data. This meta_data model is uniform, it will always be of this type:
"meta_data":
{
"meta_data_1": "meta_value_1",
"meta_data_2": "meta_value_2",
"meta_data_3": "meta_value_3",
}
Now the problem is, when the swagger loads up, it shows the original response model in schema and not the final response model which has been prepared. How to alter swagger to reflect this correctly?
I have tried this:
# This model is common to all endpoints!
# Since we are going to add this for all responses
class MetaDataModel(BaseModel):
    """Metadata section shared by every wrapped response (three string fields)."""

    meta_data_1: str
    meta_data_2: str
    meta_data_3: str
class FinalResponseForEndPoint1(BaseModel):
    """Envelope for endpoint 1: the original payload plus the shared metadata."""

    data: OriginalResponseForEndpoint1
    meta_data: MetaDataModel
class FinalResponseForEndPoint2(BaseModel):
    """Envelope for endpoint 2: the original payload plus the shared metadata."""

    data: OriginalResponseForEndpoint2
    meta_data: MetaDataModel
and so on ...
This approach does render the Swagger perfectly, but there are 2 major problems associated with it:
All my FastAPI endpoints break and give me an error when they are returning response. For example: my endpoint1 is still returning the original response but the endpoint1 expects it to send response adhering to FinalResponseForEndPoint1 model
Doing this approach for all models for all my endpoints, does not seem like the right way
Here is a minimal reproducible example with my custom middleware:
from starlette.types import ASGIApp, Receive, Scope, Send, Message
from starlette.requests import Request
import json
from starlette.datastructures import MutableHeaders
from fastapi import FastAPI
class MetaDataAdderMiddleware:
    """Pure-ASGI middleware that hands HTTP requests to a metadata-adding responder.

    Non-HTTP scopes, and HTTP requests whose path starts with one of
    ``application_generic_urls``, are forwarded to the wrapped app untouched.
    """

    # Path prefixes (OpenAPI schema and docs pages) that must not be wrapped.
    application_generic_urls = ['/openapi.json', '/docs', '/docs/oauth2-redirect', '/redoc']

    def __init__(
            self,
            app: ASGIApp
    ) -> None:
        """Remember the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Dispatch one ASGI connection, wrapping it only when it is a regular HTTP call."""
        if scope["type"] == "http" and not any(
            scope["path"].startswith(endpoint)
            for endpoint in MetaDataAdderMiddleware.application_generic_urls
        ):
            # The responder's constructor takes only the app; passing extra,
            # never-defined attributes here raised on the first request.
            responder = MetaDataAdderMiddlewareResponder(self.app)
            await responder(scope, receive, send)
            return
        await self.app(scope, receive, send)
class MetaDataAdderMiddlewareResponder:
    """Per-request helper that wraps the JSON body in a ``data`` / ``meta_data`` envelope."""

    def __init__(
            self,
            app: ASGIApp,
    ) -> None:
        """Store the wrapped app and a slot for the deferred response-start message."""
        self.app = app
        self.initial_message: Message = {}

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app with ``send`` replaced by the wrapping sender."""
        self.send = send
        await self.app(scope, receive, self.send_with_meta_response)

    async def send_with_meta_response(self, message: Message):
        """Intercept outgoing ASGI messages and rewrite the response body.

        NOTE(review): every body message is parsed as a complete JSON document,
        so this presumably expects the response to arrive in a single
        ``http.response.body`` message - confirm for streaming endpoints.
        """
        message_type = message["type"]
        if message_type == "http.response.start":
            # Don't send the initial message until we've determined how to
            # modify the outgoing headers correctly.
            self.initial_message = message
        elif message_type == "http.response.body":
            response_body = json.loads(message["body"].decode())
            # The envelope key is "meta_data", matching the documented response shape.
            data = {
                "data": response_body,
                "meta_data": {
                    'field_1': 'value_1',
                    'field_2': 'value_2'
                },
            }
            data_to_be_sent_to_user = json.dumps(data, default=str).encode("utf-8")
            # The body grew, so the deferred headers need the new length.
            headers = MutableHeaders(raw=self.initial_message["headers"])
            headers["Content-Length"] = str(len(data_to_be_sent_to_user))
            message["body"] = data_to_be_sent_to_user
            await self.send(self.initial_message)
            await self.send(message)
app = FastAPI(
title="MY DUMMY APP",
)
app.add_middleware(MetaDataAdderMiddleware)
@app.get("/")
async def root():
    """Example endpoint; its payload is what the middleware wraps."""
    return {"message": "Hello World"}
If you add default values to the additional fields you can have the middleware update those fields as opposed to creating them.
SO:
from ast import Str
from starlette.types import ASGIApp, Receive, Scope, Send, Message
from starlette.requests import Request
import json
from starlette.datastructures import MutableHeaders
from fastapi import FastAPI
from pydantic import BaseModel, Field
# This model is common to all endpoints!
# Since we are going to add this for all responses
class MetaDataModel(BaseModel):
    """Metadata section shared by every response (three string fields)."""

    meta_data_1: str
    meta_data_2: str
    meta_data_3: str
class ResponseForEndPoint1(BaseModel):
    """Response model whose ``meta_data`` defaults to None so the endpoint may omit it."""

    data: str
    meta_data: MetaDataModel | None = Field(None, nullable=True)
class MetaDataAdderMiddleware:
    """ASGI middleware that delegates regular HTTP requests to the metadata responder."""

    # Path prefixes that are passed through without wrapping.
    application_generic_urls = ['/openapi.json',
                                '/docs', '/docs/oauth2-redirect', '/redoc']

    def __init__(self, app: ASGIApp) -> None:
        """Keep a reference to the wrapped ASGI app."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Forward the connection, wrapping it unless it is non-HTTP or a docs/schema path."""
        wrap = False
        if scope["type"] == "http":
            skipped = [
                scope["path"].startswith(prefix)
                for prefix in MetaDataAdderMiddleware.application_generic_urls
            ]
            wrap = True not in skipped
        if not wrap:
            await self.app(scope, receive, send)
            return
        responder = MetaDataAdderMiddlewareResponder(self.app)
        await responder(scope, receive, send)
class MetaDataAdderMiddlewareResponder:
    """Per-request helper that fills the ``meta_data`` key of the JSON response body."""

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped app and a slot for the deferred response-start message."""
        self.app = app
        self.initial_message: Message = {}

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app with ``send`` swapped for the rewriting sender."""
        self.send = send
        await self.app(scope, receive, self.send_with_meta_response)

    async def send_with_meta_response(self, message: Message):
        """Hold back the start message, then emit it together with the rewritten body."""
        kind = message["type"]
        if kind == "http.response.start":
            # Held back until the new body length is known.
            self.initial_message = message
            return
        if kind != "http.response.body":
            return
        payload = json.loads(message["body"].decode())
        payload['meta_data'] = {
            'field_1': 'value_1',
            'field_2': 'value_2'
        }
        encoded = json.dumps(payload, default=str).encode("utf-8")
        # Writes through to the raw header list of the deferred start message.
        deferred_headers = MutableHeaders(raw=self.initial_message["headers"])
        deferred_headers["Content-Length"] = str(len(encoded))
        message["body"] = encoded
        await self.send(self.initial_message)
        await self.send(message)
app = FastAPI(
title="MY DUMMY APP",
)
app.add_middleware(MetaDataAdderMiddleware)
@app.get("/", response_model=ResponseForEndPoint1)
async def root():
    """Example endpoint; ``meta_data`` is left at its default for the middleware to fill."""
    return ResponseForEndPoint1(data='hello world')
I don't think this is a good solution - but it doesn't throw errors and it does show the correct output in swagger.
In general I'm struggling to find a good way to document the changes/additional responses that middleware can introduce in OpenAPI/Swagger. If you've found anything else I'd be keen to hear it!
I have a huge list of URLs that I need to send requests to in order to retrieve JSON data. The problem is that the list of URLs is too big to load at once, so I would like to read the URLs one by one and start a request each time a URL is loaded. My code works for a small list (~20k) with no problem, but I got stuck with a huge list.
It would be great if you could tell me how to change my code, to get it to send asynchronous requests for each url of the urls list. Thank you in advance.
Here is my code:
import json
import urllib
from urllib.parse import quote
import time
import asyncio
import aiohttp
import json
from json.decoder import JSONDecodeError
urls = ["url_1", "url_2". "url_3"........"url_3,000,000"]
START = time.monotonic()
class RateLimiter:
    """Token-bucket wrapper around an HTTP client.

    Each ``get`` spends one token; tokens refill at ``RATE`` per second of
    elapsed monotonic time, capped at ``MAX_TOKENS``.
    """

    RATE = 20
    MAX_TOKENS = 10

    def __init__(self, client):
        """Wrap *client* and start with a full bucket."""
        self.client = client
        self.tokens = self.MAX_TOKENS
        self.updated_at = time.monotonic()

    async def get(self, *args, **kwargs):
        """Wait for a token, log the request, and return ``client.get(...)`` un-awaited."""
        await self.wait_for_token()
        now = time.monotonic() - START
        print(f'{now:.0f}s: ask {args[0]}')
        return self.client.get(*args, **kwargs)

    async def wait_for_token(self):
        """Block until at least one token is available, then spend it."""
        while True:
            if self.tokens >= 1:
                break
            self.add_new_tokens()
            await asyncio.sleep(0.1)
        self.tokens -= 1

    def add_new_tokens(self):
        """Credit tokens for the time elapsed since the last refill."""
        now = time.monotonic()
        refilled = self.tokens + (now - self.updated_at) * self.RATE
        if refilled < 1:
            # Not a whole token yet: keep the old timestamp so time keeps accruing.
            return
        self.tokens = min(refilled, self.MAX_TOKENS)
        self.updated_at = now
async def fetch_one(client, url):
    """Fetch *url*, extract the first geocoding hit and append it to ``json_results``.

    One record is appended per call. When the payload has no usable first
    result, a record of ``'null'`` strings is stored; when the body cannot be
    read or has an unexpected shape, the provided location is ``None`` instead.

    NOTE(review): ``json_results`` is not defined in the code shown here - it
    is presumably a module-level list that must exist before this runs.
    """
    # Watch out for the extra 'await' here: RateLimiter.get is a coroutine that
    # returns the response context manager, so it must be awaited first.
    async with await client.get(url) as resp:
        try:
            # The response is a single object; parse it directly, don't iterate it.
            results = await resp.json()
            try:
                first_result = results['results'][0]
                answer = first_result['locations']
                output = {
                    "Provided location": first_result['providedLocation'].get('location'),
                    "City": answer[0].get('adminArea5'),
                    "State": answer[0].get('adminArea3'),
                    "Country": answer[0].get('adminArea1')
                }
            except IndexError:
                output = {
                    "Provided location": 'null',
                    "City": 'null',
                    "State": 'null',
                    "Country": 'null'
                }
        except Exception:
            # Exception rather than a bare except, so task cancellation and
            # KeyboardInterrupt are no longer swallowed.
            output = {
                "Provided location": None,
                "City": 'null',
                "State": 'null',
                "Country": 'null'
            }
    json_results.append(output)
async def main():
    """Open one shared session and run every URL through ``fetch_one`` concurrently."""
    async with aiohttp.ClientSession() as session:
        limited = RateLimiter(session)
        pending = []
        for url in urls:
            pending.append(asyncio.ensure_future(fetch_one(limited, url)))
        await asyncio.gather(*pending)
if __name__ == '__main__':
asyncio.run(main())
I need your assistance to convert callback-based websocket code to async using asyncio in Python 3.6.
Some brokers in India provide a websocket for tick data to their users, who can then use this tick data further. One such broker's sample code (from their Git repo) is pasted here:
`from smartapi import SmartWebSocket
# feed_token=092017047
FEED_TOKEN="YOUR_FEED_TOKEN"
CLIENT_CODE="YOUR_CLIENT_CODE"
# token="mcx_fo|224395"
token="EXCHANGE|TOKEN_SYMBOL" #SAMPLE: nse_cm|2885&nse_cm|1594&nse_cm|11536&nse_cm|3045
# token="mcx_fo|226745&mcx_fo|220822&mcx_fo|227182&mcx_fo|221599"
task="mw" # mw|sfi|dp
ss = SmartWebSocket(FEED_TOKEN, CLIENT_CODE)
def on_message(ws, message):
    """Print one incoming tick message."""
    line = "Ticks: {}".format(message)
    print(line)
def on_open(ws):
    """Announce the open event, then subscribe to the configured task/token."""
    opened_text = "on open"
    print(opened_text)
    ss.subscribe(task, token)
def on_error(ws, error):
    """Print the error reported by the socket."""
    print(str(error))
def on_close(ws):
    """Print a marker when the socket closes."""
    closed_text = "Close"
    print(closed_text)
# Assign the callbacks.
ss._on_open = on_open
ss._on_message = on_message
ss._on_error = on_error
ss._on_close = on_close
ss.connect()
So what I think the broker is doing is providing messages/ticks (i.e. events) that trigger a callback function on every received message/tick in a synchronous way.
However, I want to alter the above code so that it works in an async manner. I have tried to write it the way Binance's client allows, e.g. async with websocketName('currencyPair') as ts:, but this doesn't seem to work with my broker.
Appreciate if you can share some ideas/code/insights around this.
Thank you for your time.
Edit 1:
Thank you for your response Dirn,
Yes, I've tried this
from smartapi import SmartConnect
from smartapi import SmartWebSocket
import logging
import asyncio
CLIENT_CODE = Code_provided_by_broker
PASSWORD = My_Password
obj=SmartConnect(api_key=My_api_key)
data = obj.generateSession(CLIENT_CODE, PASSWORD)
refreshToken = data['data']['refreshToken']
print("refreshToken", refreshToken)
feedToken=obj.getfeedToken()
print("feedToken", feedToken)
FEED_TOKEN= feedToken
token="nse_cm|3499"
# token="nse_fo|53595"
task = 'mw'
async def on_message(ws, message):
    """Coroutine that prints one incoming tick message."""
    line = "Ticks: {}".format(message)
    print(line)
async def chk_msg():
    """Stream ticks from the blocking SmartWebSocket client into asyncio.

    ``SmartWebSocket`` is not an async context manager, and ``connect()``
    blocks inside ``run_forever``. The socket therefore runs in a worker
    thread, and its callbacks hand every tick to the event loop through a
    queue, where the ``on_message`` coroutine is awaited for each one.
    """
    loop = asyncio.get_event_loop()
    ticks = asyncio.Queue()
    ss = SmartWebSocket(FEED_TOKEN, CLIENT_CODE)

    def _subscribe(ws):
        # subscribe() sends on the live socket, which only exists once it is open.
        ss.subscribe(task, token)

    def _enqueue(ws, message):
        # Runs on the websocket thread; asyncio.Queue is not thread-safe, so
        # schedule the put on the event loop's own thread.
        loop.call_soon_threadsafe(ticks.put_nowait, message)

    ss._on_open = _subscribe
    ss._on_message = _enqueue
    # connect() blocks until the socket is closed, so keep it off the loop.
    connection = loop.run_in_executor(None, ss.connect)

    while True:
        next_tick = asyncio.ensure_future(ticks.get())
        done, _pending = await asyncio.wait(
            [next_tick, connection], return_when=asyncio.FIRST_COMPLETED
        )
        if next_tick in done:
            await on_message(ss.ws, next_tick.result())
            continue
        # connect() returned or raised: stop waiting and surface the outcome.
        next_tick.cancel()
        connection.result()
        return
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(chk_msg())
here is the smartConnect.py
from six.moves.urllib.parse import urljoin
import sys
import csv
import json
import dateutil.parser
import hashlib
import logging
import datetime
import smartapi.smartExceptions as ex
import requests
from requests import get
import re, uuid
import socket
import platform
from smartapi.version import __version__, __title__
log = logging.getLogger(__name__)
#user_sys=platform.system()
#print("the system",user_sys)
class SmartConnect(object):
    """REST client for the broker endpoints listed in ``_routes``.

    Every public method maps to one route and returns the decoded JSON
    response (or a value picked out of it). Tokens obtained by
    ``generateSession`` / ``generateToken`` are stored on the instance and
    sent as a Bearer ``Authorization`` header on later requests.
    """

    _rootUrl = "https://apiconnect.angelbroking.com"  # prod endpoint
    _login_url = "https://smartapi.angelbroking.com/publisher-login"  # prod endpoint
    _default_timeout = 7  # In seconds

    _routes = {
        "api.login": "/rest/auth/angelbroking/user/v1/loginByPassword",
        "api.logout": "/rest/secure/angelbroking/user/v1/logout",
        "api.token": "/rest/auth/angelbroking/jwt/v1/generateTokens",
        "api.refresh": "/rest/auth/angelbroking/jwt/v1/generateTokens",
        "api.user.profile": "/rest/secure/angelbroking/user/v1/getProfile",

        "api.order.place": "/rest/secure/angelbroking/order/v1/placeOrder",
        "api.order.modify": "/rest/secure/angelbroking/order/v1/modifyOrder",
        "api.order.cancel": "/rest/secure/angelbroking/order/v1/cancelOrder",
        "api.order.book": "/rest/secure/angelbroking/order/v1/getOrderBook",

        "api.ltp.data": "/rest/secure/angelbroking/order/v1/getLtpData",
        "api.trade.book": "/rest/secure/angelbroking/order/v1/getTradeBook",
        "api.rms.limit": "/rest/secure/angelbroking/user/v1/getRMS",
        "api.holding": "/rest/secure/angelbroking/portfolio/v1/getHolding",
        "api.position": "/rest/secure/angelbroking/order/v1/getPosition",
        "api.convert.position": "/rest/secure/angelbroking/order/v1/convertPosition",

        "api.gtt.create": "/gtt-service/rest/secure/angelbroking/gtt/v1/createRule",
        "api.gtt.modify": "/gtt-service/rest/secure/angelbroking/gtt/v1/modifyRule",
        "api.gtt.cancel": "/gtt-service/rest/secure/angelbroking/gtt/v1/cancelRule",
        "api.gtt.details": "/rest/secure/angelbroking/gtt/v1/ruleDetails",
        "api.gtt.list": "/rest/secure/angelbroking/gtt/v1/ruleList",

        "api.candle.data": "/rest/secure/angelbroking/historical/v1/getCandleData"
    }

    # Fallback addresses; they are kept only when the lookup below fails.
    # (Previously a ``finally`` clause overwrote the detected values every time.)
    clientPublicIp = "106.193.147.98"
    clientLocalIp = "127.0.0.1"
    try:
        # Runs once, when the class is defined; bounded so an unreachable
        # service cannot hang the import.
        clientPublicIp = get('https://api.ipify.org', timeout=_default_timeout).text.replace(" ", "")
        hostname = socket.gethostname()
        clientLocalIp = socket.gethostbyname(hostname)
    except Exception as e:
        print("Exception while retriving IP Address,using local host IP address", e)

    clientMacAddress = ':'.join(re.findall('..', '%012x' % uuid.getnode()))
    accept = "application/json"
    userType = "USER"
    sourceID = "WEB"

    def __init__(self, api_key=None, access_token=None, refresh_token=None,feed_token=None, userId=None, root=None, debug=False, timeout=None, proxies=None, pool=None, disable_ssl=False,accept=None,userType=None,sourceID=None,Authorization=None,clientPublicIP=None,clientMacAddress=None,clientLocalIP=None,privateKey=None):
        """Create a client.

        ``root`` and ``timeout`` fall back to the class defaults; ``pool`` is
        passed as keyword arguments to ``requests.adapters.HTTPAdapter``.

        NOTE(review): ``accept``, ``userType``, ``sourceID``, ``Authorization``,
        ``clientPublicIP``, ``clientMacAddress``, ``clientLocalIP`` and
        ``privateKey`` are accepted but not used; the class-level values (and
        ``api_key`` for the private key) are stored instead.
        """
        self.debug = debug
        self.api_key = api_key
        self.session_expiry_hook = None
        self.disable_ssl = disable_ssl
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.feed_token = feed_token
        self.userId = userId
        self.proxies = proxies if proxies else {}
        self.root = root or self._rootUrl
        self.timeout = timeout or self._default_timeout
        self.Authorization = None
        self.clientLocalIP = self.clientLocalIp
        self.clientPublicIP = self.clientPublicIp
        self.clientMacAddress = self.clientMacAddress
        self.privateKey = api_key
        self.accept = self.accept
        self.userType = self.userType
        self.sourceID = self.sourceID

        if pool:
            self.reqsession = requests.Session()
            reqadapter = requests.adapters.HTTPAdapter(**pool)
            self.reqsession.mount("https://", reqadapter)
            print("in pool")
        else:
            self.reqsession = requests

        # disable requests SSL warning
        requests.packages.urllib3.disable_warnings()

    @staticmethod
    def _drop_none(params):
        """Return a copy of *params* without ``None`` values; the caller's dict is untouched."""
        return {key: value for key, value in params.items() if value is not None}

    def requestHeaders(self):
        """Build the custom headers sent with every request."""
        return {
            "Content-type": self.accept,
            "X-ClientLocalIP": self.clientLocalIp,
            "X-ClientPublicIP": self.clientPublicIp,
            "X-MACAddress": self.clientMacAddress,
            "Accept": self.accept,
            "X-PrivateKey": self.privateKey,
            "X-UserType": self.userType,
            "X-SourceID": self.sourceID
        }

    def setSessionExpiryHook(self, method):
        """Register a callable invoked when a 403 ``TokenException`` is returned.

        Raises:
            TypeError: if *method* is not callable.
        """
        if not callable(method):
            raise TypeError("Invalid input type. Only functions are accepted.")
        self.session_expiry_hook = method

    def getUserId(self):
        """Return the stored user id."""
        return self.userId

    def setUserId(self, id):
        """Store the user id."""
        self.userId = id

    def setAccessToken(self, access_token):
        """Store the JWT sent as the Bearer token."""
        self.access_token = access_token

    def setRefreshToken(self, refresh_token):
        """Store the refresh token."""
        self.refresh_token = refresh_token

    def setFeedToken(self, feedToken):
        """Store the feed token."""
        self.feed_token = feedToken

    def getfeedToken(self):
        """Return the stored feed token."""
        return self.feed_token

    def login_url(self):
        """Get the remote login url to which a user should be redirected to initiate the login flow."""
        return "%s?api_key=%s" % (self._login_url, self.api_key)

    def _request(self, route, method, parameters=None):
        """Make an HTTP request and return the decoded response.

        Raises:
            ex.DataException: the body is not valid JSON, or the content type
                is neither JSON nor CSV.
            An exception class looked up in ``ex`` by the response's
            ``error_type`` (``ex.GeneralException`` when unknown).
        """
        params = parameters.copy() if parameters else {}

        uri = self._routes[route].format(**params)
        url = urljoin(self.root, uri)

        # Custom headers
        headers = self.requestHeaders()

        if self.access_token:
            # set authorization header
            headers["Authorization"] = "Bearer {}".format(self.access_token)

        if self.debug:
            log.debug("Request: {method} {url} {params} {headers}".format(method=method, url=url, params=params, headers=headers))

        # Go through the session built in __init__ so a configured pool is used;
        # without a pool this is the plain ``requests`` module, as before.
        r = self.reqsession.request(method,
                                    url,
                                    data=json.dumps(params) if method in ["POST", "PUT"] else None,
                                    params=json.dumps(params) if method in ["GET", "DELETE"] else None,
                                    headers=headers,
                                    verify=not self.disable_ssl,
                                    allow_redirects=True,
                                    timeout=self.timeout,
                                    proxies=self.proxies)

        if self.debug:
            log.debug("Response: {code} {content}".format(code=r.status_code, content=r.content))

        # Validate the content type (taken from the request headers built above).
        if "json" in headers["Content-type"]:
            try:
                data = json.loads(r.content.decode("utf8"))
            except ValueError:
                raise ex.DataException("Couldn't parse the JSON response received from the server: {content}".format(
                    content=r.content))

            # api error
            if data.get("error_type"):
                # Call session hook if its registered and TokenException is raised
                if self.session_expiry_hook and r.status_code == 403 and data["error_type"] == "TokenException":
                    self.session_expiry_hook()

                # native errors
                exp = getattr(ex, data["error_type"], ex.GeneralException)
                raise exp(data["message"], code=r.status_code)

            return data
        elif "csv" in headers["Content-type"]:
            return r.content
        else:
            raise ex.DataException("Unknown Content-type ({content_type}) with response: ({content})".format(
                content_type=headers["Content-type"],
                content=r.content))

    def _deleteRequest(self, route, params=None):
        """Alias for sending a DELETE request."""
        return self._request(route, "DELETE", params)

    def _putRequest(self, route, params=None):
        """Alias for sending a PUT request."""
        return self._request(route, "PUT", params)

    def _postRequest(self, route, params=None):
        """Alias for sending a POST request."""
        return self._request(route, "POST", params)

    def _getRequest(self, route, params=None):
        """Alias for sending a GET request."""
        return self._request(route, "GET", params)

    def generateSession(self, clientCode, password):
        """Log in, store the returned tokens and user id, and return the profile.

        When the login response's ``status`` is not true, the login response
        itself is returned instead.
        """
        params = {"clientcode": clientCode, "password": password}
        loginResultObject = self._postRequest("api.login", params)

        if loginResultObject['status'] == True:
            jwtToken = loginResultObject['data']['jwtToken']
            self.setAccessToken(jwtToken)
            refreshToken = loginResultObject['data']['refreshToken']
            feedToken = loginResultObject['data']['feedToken']
            self.setRefreshToken(refreshToken)
            self.setFeedToken(feedToken)
            user = self.getProfile(refreshToken)

            client_id = user['data']['clientcode']
            self.setUserId(client_id)
            user['data']['jwtToken'] = "Bearer " + jwtToken
            user['data']['refreshToken'] = refreshToken

            return user
        else:
            return loginResultObject

    def terminateSession(self, clientCode):
        """Log the given client out."""
        logoutResponseObject = self._postRequest("api.logout", {"clientcode": clientCode})
        return logoutResponseObject

    def generateToken(self, refresh_token):
        """Exchange a refresh token for new tokens and store them."""
        response = self._postRequest('api.token', {"refreshToken": refresh_token})
        jwtToken = response['data']['jwtToken']
        feedToken = response['data']['feedToken']
        self.setFeedToken(feedToken)
        self.setAccessToken(jwtToken)

        return response

    def renewAccessToken(self):
        """Request fresh tokens with the stored ones and return them as a dict."""
        response = self._postRequest('api.refresh', {
            "jwtToken": self.access_token,
            "refreshToken": self.refresh_token,
        })

        tokenSet = {}
        # The tokens live under response['data'] (as generateToken reads them),
        # so the presence check has to look there, not at the top level.
        data = response['data']
        if data and "jwtToken" in data:
            tokenSet['jwtToken'] = data['jwtToken']
        tokenSet['clientcode'] = self.userId
        tokenSet['refreshToken'] = data["refreshToken"]

        return tokenSet

    def getProfile(self, refreshToken):
        """Fetch the user profile."""
        user = self._getRequest("api.user.profile", {"refreshToken": refreshToken})
        return user

    def placeOrder(self, orderparams):
        """Place an order (``None`` values are dropped) and return its order id."""
        params = self._drop_none(orderparams)
        orderResponse = self._postRequest("api.order.place", params)['data']['orderid']
        return orderResponse

    def modifyOrder(self, orderparams):
        """Modify an order; ``None`` values are dropped before sending."""
        params = self._drop_none(orderparams)
        orderResponse = self._postRequest("api.order.modify", params)
        return orderResponse

    def cancelOrder(self, order_id, variety):
        """Cancel the order with the given id and variety."""
        orderResponse = self._postRequest("api.order.cancel", {"variety": variety, "orderid": order_id})
        return orderResponse

    def ltpData(self, exchange, tradingsymbol, symboltoken):
        """Fetch LTP data for one symbol."""
        params = {
            "exchange": exchange,
            "tradingsymbol": tradingsymbol,
            "symboltoken": symboltoken
        }
        ltpDataResponse = self._postRequest("api.ltp.data", params)
        return ltpDataResponse

    def orderBook(self):
        """Fetch the order book."""
        orderBookResponse = self._getRequest("api.order.book")
        return orderBookResponse

    def tradeBook(self):
        """Fetch the trade book."""
        tradeBookResponse = self._getRequest("api.trade.book")
        return tradeBookResponse

    def rmsLimit(self):
        """Fetch the RMS limits."""
        rmsLimitResponse = self._getRequest("api.rms.limit")
        return rmsLimitResponse

    def position(self):
        """Fetch positions."""
        positionResponse = self._getRequest("api.position")
        return positionResponse

    def holding(self):
        """Fetch holdings."""
        holdingResponse = self._getRequest("api.holding")
        return holdingResponse

    def convertPosition(self, positionParams):
        """Convert a position; ``None`` values are dropped before sending."""
        params = self._drop_none(positionParams)
        convertPositionResponse = self._postRequest("api.convert.position", params)
        return convertPositionResponse

    def gttCreateRule(self, createRuleParams):
        """Create a GTT rule and return its id."""
        params = self._drop_none(createRuleParams)
        createGttRuleResponse = self._postRequest("api.gtt.create", params)
        return createGttRuleResponse['data']['id']

    def gttModifyRule(self, modifyRuleParams):
        """Modify a GTT rule and return its id."""
        params = self._drop_none(modifyRuleParams)
        modifyGttRuleResponse = self._postRequest("api.gtt.modify", params)
        return modifyGttRuleResponse['data']['id']

    def gttCancelRule(self, gttCancelParams):
        """Cancel a GTT rule."""
        params = self._drop_none(gttCancelParams)
        cancelGttRuleResponse = self._postRequest("api.gtt.cancel", params)
        return cancelGttRuleResponse

    def gttDetails(self, id):
        """Fetch the details of one GTT rule."""
        params = {
            "id": id
        }
        gttDetailsResponse = self._postRequest("api.gtt.details", params)
        return gttDetailsResponse

    def gttLists(self, status, page, count):
        """List GTT rules; returns an explanatory string when *status* is not a list."""
        if isinstance(status, list):
            params = {
                "status": status,
                "page": page,
                "count": count
            }
            gttListResponse = self._postRequest("api.gtt.list", params)
            return gttListResponse
        else:
            message = "The status param is entered as" + str(type(status)) + ". Please enter status param as a list i.e., status=['CANCELLED']"
            return message

    def getCandleData(self, historicDataParams):
        """Fetch historical candle data; ``None`` values are dropped before sending."""
        params = self._drop_none(historicDataParams)
        getCandleDataResponse = self._postRequest("api.candle.data", params)
        return getCandleDataResponse

    def _user_agent(self):
        """Return the user-agent string built from the package title and version."""
        return (__title__ + "-python/").capitalize() + __version__
the smartApiWebsocket.py
import websocket
import six
import base64
import zlib
import datetime
import time
import json
import threading
import ssl
class SmartWebSocket(object):
    """Callback-driven client for the websocket feed at ``ROOT_URI``.

    Assign ``_on_open``, ``_on_message``, ``_on_error`` and ``_on_close`` on an
    instance, then call ``connect()``, which blocks while the socket runs. A
    daemon thread sends a heartbeat every ``HB_INTERVAL`` seconds while the
    connection is open.
    """

    ROOT_URI = 'wss://wsfeeds.angelbroking.com/NestHtml5Mobile/socket/stream'
    HB_INTERVAL = 30
    HB_THREAD_FLAG = False      # True tells the heartbeat thread to stop
    WS_RECONNECT_FLAG = False   # True while reconnecting after a dropped socket
    feed_token = None
    client_code = None
    ws = None
    task_dict = {}

    def __init__(self, FEED_TOKEN, CLIENT_CODE):
        """Store the credentials.

        Raises:
            TypeError: if either argument is None. The original ``return`` of a
                message string from ``__init__`` also surfaced as TypeError,
                just with an unrelated message.
        """
        if CLIENT_CODE is None or FEED_TOKEN is None:
            raise TypeError("client_code or feed_token or task is missing")
        self.root = self.ROOT_URI
        self.feed_token = FEED_TOKEN
        self.client_code = CLIENT_CODE
        # Per-instance map of task -> channel, so separate connections do not
        # share (and resubscribe to) each other's subscriptions.
        self.task_dict = {}

    def _close(self, reason=None):
        """Report *reason* and close the socket if one exists.

        Called from the send error paths, which referenced this method before
        it was defined.
        """
        if reason:
            print(reason)
        if self.ws is not None:
            self.ws.close()

    def _subscribe_on_open(self):
        """Send the connection request and start the heartbeat thread."""
        request = {"task": "cn", "channel": "NONLM", "token": self.feed_token, "user": self.client_code,
                   "acctid": self.client_code}
        print(request)
        self.ws.send(
            six.b(json.dumps(request))
        )

        thread = threading.Thread(target=self.run, args=())
        thread.daemon = True
        thread.start()

    def run(self):
        """Heartbeat loop: send a heartbeat every HB_INTERVAL seconds until flagged to stop."""
        while True:
            if self.HB_THREAD_FLAG:
                break
            print(datetime.datetime.now().__str__() + ' : Start task in the background')

            self.heartBeat()

            time.sleep(self.HB_INTERVAL)

    def subscribe(self, task, token):
        """Record and send a subscription; returns True when the request was sent.

        The task/token pair is remembered for ``resubscribe`` even when the
        task is not one of ``mw``, ``sfi`` or ``dp``.
        """
        self.task_dict[task] = token
        if task in ("mw", "sfi", "dp"):
            strwatchlistscrips = token  # dynamic call

            try:
                request = {"task": task, "channel": strwatchlistscrips, "token": self.feed_token,
                           "user": self.client_code, "acctid": self.client_code}

                self.ws.send(
                    six.b(json.dumps(request))
                )
                return True
            except Exception as e:
                self._close(reason="Error while request sending: {}".format(str(e)))
                raise
        else:
            print("The task entered is invalid, Please enter correct task(mw,sfi,dp) ")

    def resubscribe(self):
        """Re-send every recorded subscription; returns True once all were sent."""
        for task, marketwatch in self.task_dict.items():
            print(task, '->', marketwatch)
            try:
                request = {"task": task, "channel": marketwatch, "token": self.feed_token,
                           "user": self.client_code, "acctid": self.client_code}

                self.ws.send(
                    six.b(json.dumps(request))
                )
            except Exception as e:
                self._close(reason="Error while request sending: {}".format(str(e)))
                raise
        # Returning inside the loop resubscribed only the first task.
        return True

    def heartBeat(self):
        """Send one heartbeat request; failures are reported, not raised."""
        try:
            request = {"task": "hb", "channel": "", "token": self.feed_token, "user": self.client_code,
                       "acctid": self.client_code}
            print(request)
            self.ws.send(
                six.b(json.dumps(request))
            )
        except Exception:
            print("HeartBeat Sending Failed")

    def _parse_text_message(self, message):
        """Decode a base64 + zlib + JSON message and pass non-empty data to ``_on_message``."""
        data = base64.b64decode(message)

        try:
            text = zlib.decompress(data).decode("utf-8")
            data = json.loads(text.replace("'", '"'))
            # Round trip kept from the original: it re-creates the data with sorted keys.
            data = json.loads(json.dumps(data, indent=4, sort_keys=True))
        except ValueError:
            return

        if data:
            self._on_message(self.ws, data)

    def connect(self):
        """Open the socket and block while it runs.

        NOTE(review): certificate verification is disabled (``ssl.CERT_NONE``);
        left unchanged, but it should be confirmed as intentional.
        """
        self.ws = websocket.WebSocketApp(self.ROOT_URI,
                                         on_message=self.__on_message,
                                         on_close=self.__on_close,
                                         on_open=self.__on_open,
                                         on_error=self.__on_error)

        self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def __on_message(self, ws, message):
        self._parse_text_message(message)

    def __on_open(self, ws):
        print("__on_open################")
        self.HB_THREAD_FLAG = False
        self._subscribe_on_open()
        if self.WS_RECONNECT_FLAG:
            self.WS_RECONNECT_FLAG = False
            self.resubscribe()
        else:
            self._on_open(ws)

    def __on_close(self, ws, *_close_details):
        # Extra positional arguments are accepted and ignored; presumably newer
        # websocket client versions pass a close status and message - confirm.
        self.HB_THREAD_FLAG = True
        print("__on_close################")
        self._on_close(ws)

    def __on_error(self, ws, error):
        # Connection-loss errors trigger a reconnect; everything else goes to the user callback.
        if ("timed" in str(error)) or ("Connection is already closed" in str(error)) or ("Connection to remote host was lost" in str(error)):

            self.WS_RECONNECT_FLAG = True
            self.HB_THREAD_FLAG = True

            if (ws is not None):
                ws.close()
                ws.on_message = None
                ws.on_open = None
                ws.close = None
                del ws

            self.connect()
        else:
            print('Error info: %s' % (error))
            self._on_error(ws, error)

    def _on_message(self, ws, message):
        """User hook: called with each decoded message."""
        pass

    def _on_open(self, ws):
        """User hook: called when the socket opens (not on reconnects)."""
        pass

    def _on_close(self, ws):
        """User hook: called when the socket closes."""
        pass

    def _on_error(self, ws, error):
        """User hook: called for errors that do not trigger a reconnect."""
        pass
the provider has also imported the websocket module in the above smartApiWebsocket file and has also provided a custom webSocket.py
here is the link webSocket.py
enter link description here
I'm trying to write some async tests in FastAPI using Tortoise ORM under Python 3.8 but I keep getting the same errors (seen at the end). I've been trying to figure this out for the past few days but somehow all my recent efforts in creating tests have been unsuccessful.
I'm following the fastapi docs and tortoise docs on this one.
main.py
# UserPy is a pydantic model
@app.post('/testpost')
async def world(user: UserPy) -> UserPy:
    """Create a User row from the posted data and echo the input back."""
    await User.create(**user.dict())
    # Just returns the user model
    return user
simple_test.py
from fastapi.testclient import TestClient
from httpx import AsyncClient
@pytest.fixture
def client1():
    """Yield a TestClient for the app, closed again after the test."""
    with TestClient(app) as tc:
        yield tc
@pytest.fixture
def client2():
    """Yield a TestClient between the database ``initializer`` and ``finalizer`` calls."""
    initializer(DATABASE_MODELS, DATABASE_URL)
    with TestClient(app) as tc:
        yield tc
    finalizer()
@pytest.fixture
def event_loop(client2):  # Been using client1 and client2 on this
    """Yield the event loop that the test client's task runs on."""
    yield client2.task.get_loop()
# The test
@pytest.mark.asyncio
def test_testpost(client2, event_loop):
    """Post a user through the API, then look one up through the ORM."""
    name, age = ['sam', 99]

    data = json.dumps(dict(username=name, age=age))
    res = client2.post('/testpost', data=data)
    assert res.status_code == 200

    # Sample query
    async def getx(id):
        return await User.get(pk=id)

    x = event_loop.run_until_complete(getx(123))
    assert x.id == 123
# end of code
My errors vary depending on whether I'm using client1 or client2.
Using client1 error
RuntimeError: Task <Task pending name='Task-9' coro=<TestClient.wait_shutdown() running at <my virtualenv path>/site-packages/starlette/testclient.py:487> cb=[_run_until_complete_cb() at /usr/lib/python3.8/asyncio/base_events.py:184]> got Future <Future pending> attached to a different loop
Using client2 error
asyncpg.exceptions.ObjectInUseError: cannot drop the currently open database
Oh, I've also tried using httpx.AsyncClient but still no success (and more errors). Any ideas? I'm out of ideas of my own.
It took me about an hour to get the async test working. Here is the example:
(Python3.8+ is required)
conftest.py
import pytest
from httpx import AsyncClient
from tortoise import Tortoise
from main import app
DB_URL = "sqlite://:memory:"
async def init_db(db_url, create_db: bool = False, schemas: bool = False) -> None:
    """Connect Tortoise to *db_url*, optionally creating the database and its schemas."""
    init_options = {
        "db_url": db_url,
        "modules": {"models": ["models"]},
        "_create_db": create_db,
    }
    await Tortoise.init(**init_options)
    if create_db:
        print(f"Database created! {db_url = }")
    if not schemas:
        return
    await Tortoise.generate_schemas()
    print("Success to generate schemas")
async def init(db_url: str = DB_URL):
    """Initialise the test database: create it and generate its schemas."""
    await init_db(db_url, create_db=True, schemas=True)
@pytest.fixture(scope="session")
def anyio_backend():
    """Name the backend used for tests marked with ``anyio``."""
    return "asyncio"
@pytest.fixture(scope="session")
async def client():
    """Yield one AsyncClient bound to the app for the whole test session."""
    async with AsyncClient(app=app, base_url="http://test") as client:
        print("Client is ready")
        yield client
@pytest.fixture(scope="session", autouse=True)
async def initialize_tests():
    """Set up the test database before the session and drop it afterwards."""
    await init()
    yield
    await Tortoise._drop_databases()
settings.py
import os
from dotenv import load_dotenv
load_dotenv()
DB_NAME = "async_test"
# Default connection URL, overridable through APP_DB_URL. The credentials are
# separated from the host by "@"; a "#" there would start a URL fragment.
DB_URL = os.getenv(
    "APP_DB_URL", f"postgres://postgres:postgres@127.0.0.1:5432/{DB_NAME}"
)
ALLOW_ORIGINS = [
"http://localhost",
"http://localhost:8080",
"http://localhost:8000",
"https://example.com",
]
main.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from models.users import User, User_Pydantic, User_Pydantic_List, UserIn_Pydantic
from settings import ALLOW_ORIGINS, DB_URL
from tortoise.contrib.fastapi import register_tortoise
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOW_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/testpost", response_model=User_Pydantic)
async def world(user: UserIn_Pydantic):
    """Create a User from the posted data and return the created row."""
    return await User.create(**user.dict())
@app.get("/users", response_model=User_Pydantic_List)
async def user_list():
    """Return all users."""
    return await User.all()
register_tortoise(
app,
config={
"connections": {"default": DB_URL},
"apps": {"models": {"models": ["models"]}},
"use_tz": True,
"timezone": "Asia/Shanghai",
"generate_schemas": True,
},
)
models/base.py
from typing import List, Set, Tuple, Union
from tortoise import fields, models
from tortoise.queryset import Q, QuerySet
def reduce_query_filters(args: Tuple[Q, ...]) -> Set:
    """Collect the filter keys of the given Q objects and of all their nested children."""
    found = set()
    for q in args:
        found.update(set(q.filters))
        level: Union[List[Q], Tuple[Q, ...]] = q.children
        # Walk the tree one level at a time until a level has no children.
        while level:
            next_level: List[Q] = []
            for child in level:
                found.update(set(child.filters))
                next_level.extend(list(child.children))
            level = next_level
    return found
class AbsModel(models.Model):
    """Abstract base model with an integer pk, timestamps and a soft-delete flag."""

    id = fields.IntField(pk=True)
    created_at = fields.DatetimeField(auto_now_add=True, description="Created At")
    updated_at = fields.DatetimeField(auto_now=True, description="Updated At")
    is_deleted = fields.BooleanField(default=False, description="Mark as Deleted")

    class Meta:
        abstract = True
        ordering = ("-id",)

    @classmethod
    def filter(cls, *args, **kwargs) -> QuerySet:
        """Filter as usual, defaulting to ``is_deleted=False``.

        The default is skipped when ``is_deleted`` is already given as a
        keyword or appears in any of the positional Q expressions.
        """
        field = "is_deleted"
        if not args or (field not in reduce_query_filters(args)):
            kwargs.setdefault(field, False)
        return super().filter(*args, **kwargs)

    class PydanticMeta:
        exclude = ("created_at", "updated_at", "is_deleted")

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.id}>"
models/users.py
from tortoise.contrib.pydantic import pydantic_model_creator, pydantic_queryset_creator
from .base import AbsModel, fields
class User(AbsModel):
    """User row stored in the ``users`` table."""

    username = fields.CharField(60)
    age = fields.IntField()

    class Meta:
        table = "users"

    def __str__(self):
        # The model defines ``username``; there is no ``name`` field to return.
        return self.username
User_Pydantic = pydantic_model_creator(User)
UserIn_Pydantic = pydantic_model_creator(User, name="UserIn", exclude_readonly=True)
User_Pydantic_List = pydantic_queryset_creator(User)
models/__init__.py
from .users import User # NOQA: F401
tests/test_users.py
import pytest
from httpx import AsyncClient
from models.users import User
@pytest.mark.anyio
async def test_testpost(client: AsyncClient):
    """Create a user through the API and check it via both the API and the ORM."""
    name, age = ["sam", 99]
    assert await User.filter(username=name).count() == 0

    data = {"username": name, "age": age}
    response = await client.post("/testpost", json=data)
    assert response.json() == dict(data, id=1)
    assert response.status_code == 200

    response = await client.get("/users")
    assert response.status_code == 200
    assert response.json() == [dict(data, id=1)]

    assert await User.filter(username=name).count() == 1
The source code of the demo has been posted to GitHub:
https://github.com/waketzheng/fastapi-tortoise-pytest-demo.git