Writing unittest for package - python

I have created a package which uploads a data to some storage using Azure AD access token, now I want to write test cases for the code, as I'm not aware of writing test cases have tried few. Can anyone help here, below is the code for my package.
__init__.py file
import json
import requests
import sys
from data import Data
import datetime
from singleton import singleton
#singleton
class CertifiedLogProvider:
def __init__(self, client_id, client_secret):
self.client_id=client_id
self.client_secret= client_secret
self.grant_type="client_credentials"
self.resource="*****"
self.url="https://login.microsoftonline.com/azureford.onmicrosoft.com/oauth2/token"
self.api_url="http://example.com"
self.get_access_token()
def get_access_token(self)-> None:
data={'client_id':self.client_id,'client_secret':self.client_secret,
'grant_type':self.grant_type,'resource':self.resource}
response = requests.post(self.url, data=data)
if response.status_code == 200:
self.token_dict=response.content
self.access_token = response.json()["access_token"]
else:
print(f"An Error occurred: \n {response.text}")
def is_expired(self) -> bool:
token_dict=json.loads(self.token_dict.decode('utf-8'))
if int(datetime.datetime.now().timestamp()) > int(token_dict['expires_on']):
return True
else:
return False
def send_clp_data(self,payload:dict):
obj=Data()
data=obj.data
data['event_metric_body']=payload
if self.is_expired() is True:
self.get_access_token()
headers={"Authorization": "Bearer {}".format(self.access_token),
"Content-Type": "application/json",}
response = requests.post(self.api_url,data=json.dumps(data), headers=headers)
if response.status_code == 201:
print('Data uploaded successfully')
else:
print(f"An Error occurred: \n {response.text}")
singleton.py
def singleton(class_):
instances = {}
def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
data.py
Contains data which is static
test.py
import json
import unittest
from unittest import TestCase
from unittest.mock import patch
import requests
from unittest.mock import MagicMock
from __init__ import CertifiedLogProvider
import pytest
class MyTestCase(TestCase):
def test_can_construct_clp_instance(self):
object= CertifiedLogProvider(1,2)
#patch('requests.post')
def test_send_clp_data(self, mock_post):
info={"test1":"value1", "test2": "value2"}
response = requests.post("www.clp_api.com", data=json.dumps(info), headers={'Content-Type': 'application/json'})
mock_post.assert_called_with("www.clp_api.com", data=json.dumps(info), headers={'Content-Type': 'application/json'})
if __name__ == '__main__':
unittest.main()
How can we test boolean method and method containing requests?

Testing is_expired() becomes easier if you create a Token class:
class Token:
def __init__(self, token_dict: dict) -> None:
self.token_dict = token_dict
def __str__(self) -> str:
return self.token_dict['access_token']
def is_expired(self, now:DateTime=None) -> bool:
if now is None:
now = datetime.datetime.now()
return int(now.timestamp()) > int(self.token_dict['expires_on'])
#singleton
class CertifiedLogProvider:
def __init__(self, client_id, client_secret):
self.client_id=client_id
self.client_secret= client_secret
self.grant_type="client_credentials"
self.resource="*****"
self.url="https://login.microsoftonline.com/azureford.onmicrosoft.com/oauth2/token"
self.api_url="http://example.com"
self.access_token = None
def get_access_token(self) -> Token:
if self.access_token is None or self.access_token.is_expired():
self.access_token = self.fetch_access_token()
return self.access_token
def fetch_access_token(self) -> Token:
data={
'client_id': self.client_id,
'client_secret': self.client_secret,
'grant_type': self.grant_type,
'resource': self.resource,
}
response = requests.post(self.url, data=data)
if response.status_code != 200:
raise Exception(f"An Error occurred: \n {response.text}")
return Token(response.json())
def send_clp_data(self, payload: dict):
obj=Data()
data=obj.data
data['event_metric_body'] = payload
headers={
"Authorization": f"Bearer {self.get_access_token()}",
"Content-Type": "application/json",
}
response = requests.post(self.api_url, data=json.dumps(data), headers=headers)
if response.status_code != 201:
raise Exception(f"An Error occurred: \n {response.text}")
print('Data uploaded successfully')
test_token.py
from datetime import datetime
import unittest
from certified_log_provider import Token
class TestToken(unittest.TestCase):
def setUp(self) -> None:
self.token = Token({ 'expires_on': datetime.strptime("2022-02-18", "%Y-%m-%d").timestamp() })
def test_token_is_not_expired_day_before(self):
self.assertFalse(self.token.is_expired(datetime.strptime("2022-02-17", "%Y-%m-%d")))
def test_token_is_expired_day_after(self):
self.assertTrue(self.token.is_expired(datetime.strptime("2022-02-19", "%Y-%m-%d")))

Related

make fastapi middleware returning custom http status instead of AuthenticationError status 400

In the following example when you pass a username in the basic auth field it raise a basic 400 error, but i want to return 401 since it's related to the authentication system.
I did tried Fastapi exceptions classes but they do not raise (i presume since we are in a starlette middleware). Il also tried JSONResponse from starlette but it doesn't work either.
AuthenticationError work and raise a 400 but it's juste an empty class that inherit from Exception so no status code can be given.
Fully working example:
import base64
import binascii
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, HTTPBasic
from starlette.authentication import AuthenticationBackend, AuthCredentials, AuthenticationError, BaseUser
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.responses import JSONResponse
class SimpleUserTest(BaseUser):
"""
user object returned to route
"""
def __init__(self, username: str, test1: str, test2: str) -> None:
self.username = username
self.test1 = test1
self.test2 = test2
#property
def is_authenticated(self) -> bool:
return True
async def jwt_auth(auth: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False))):
if auth:
return True
async def key_auth(apikey_header=Depends(HTTPBasic(auto_error=False))):
if apikey_header:
return True
class BasicAuthBackend(AuthenticationBackend):
async def authenticate(self, conn):
if "Authorization" not in conn.headers:
return
auth = conn.headers["Authorization"]
try:
scheme, credentials = auth.split()
if scheme.lower() == 'bearer':
# check bearer content and decode it
user: dict = {"username": "bearer", "test1": "test1", "test2": "test2"}
elif scheme.lower() == 'basic':
decoded = base64.b64decode(credentials).decode("ascii")
username, _, password = decoded.partition(":")
if username:
# check redis here instead of return dict
print("try error raise")
raise AuthenticationError('Invalid basic auth credentials') # <= raise 400, we need 401
# user: dict = {"username": "basic auth", "test1": "test1", "test2": "test2"}
else:
print("error should raise")
return JSONResponse(status_code=401, content={'reason': str("You need to provide a username")})
else:
return JSONResponse(status_code=401, content={'reason': str("Authentication type is not supported")})
except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
raise AuthenticationError('Invalid basic auth credentials')
return AuthCredentials(["authenticated"]), SimpleUserTest(**user)
async def jwt_or_key_auth(jwt_result=Depends(jwt_auth), key_result=Depends(key_auth)):
if not (key_result or jwt_result):
raise HTTPException(status_code=401, detail="Not authenticated")
app = FastAPI(
dependencies=[Depends(jwt_or_key_auth)],
middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())]
)
#app.get("/")
async def read_items(request: Request) -> str:
return request.user.__dict__
if __name__ == "__main__":
uvicorn.run("main:app", host="127.0.0.1", port=5000, log_level="info")
if we set username in basic auth:
INFO: 127.0.0.1:22930 - "GET / HTTP/1.1" 400 Bad Request
so i ended up using on_error as suggested by #MatsLindh
old app:
app = FastAPI(
dependencies=[Depends(jwt_or_key_auth)],
middleware=[
Middleware(
AuthenticationMiddleware,
backend=BasicAuthBackend(),
)
],
)
new version:
app = FastAPI(
dependencies=[Depends(jwt_or_key_auth)],
middleware=[
Middleware(
AuthenticationMiddleware,
backend=BasicAuthBackend(),
on_error=lambda conn, exc: JSONResponse({"detail": str(exc)}, status_code=401),
)
],
)
I choose to use JSONResponse and return a "detail" key/value to emulate a classic 401 fastapi httperror

Coinbase Pro API Authentication - Invalid Signature

Hoping to get some help on making calls to the Coinbase Pro API.
I created a key, noted my phasephrase, key and secret, and am running the below Python script. The response I get is "invalid signature".
On the CBPro documentation site, when I try running it with my credentials on this page, I get a "Sorry, you couldn't be authenticated with those credentails" message.
I've seen some sources that encode to base64 and have success, and others that don't, but neither works for me. What am I don't wrong?
Code:
import requests
import time
import base64
import json
url = "https://api.exchange.coinbase.com/accounts/account_id/transfers?limit=100"
key = "key"
secret = "secret"
passphrase = "pass"
timestamp = str(time.time())
headers = {
"Accept": "application/json",
"cb-access-key": key,
"cb-access-passphrase": passphrase,
"cb-access-sign": encodedData,
"cb-access-timestamp": timestamp
}
response = requests.request("GET", url, headers=headers)
print(response.text)
Signing a request is probably the worst part of the coinbase api.
Here is the documentation for it. some things to note:
signature is only good for 30 seconds so you have to hurry and copy / paste the encoded data and timestamp into the docs form.
the timestamp has to be the same value you use in the encodedData.
the non-pro api signs a little different so make sure you're on the right set of docs. This one I don't think works for pro.
If you're still having trouble authenticating your requests to coinbase here's an example of what i'm using. You might have to change a few things but does the job.
What you're looking for in this example are the HMACAuth & CoinbaseSession classes for your particular need.
# -*- coding: UTF-8 -*-
from base64 import b64encode, b64decode
from collections import namedtuple
from datetime import datetime, timezone
from hashlib import sha256
from hmac import HMAC
from json import loads, JSONDecodeError
from types import SimpleNamespace
from typing import Union, Generator
from requests import Session
from requests.adapters import HTTPAdapter
from requests.auth import AuthBase
from requests.exceptions import HTTPError
from requests.models import PreparedRequest
from requests.utils import to_native_string
from urllib3.util.retry import Retry
EXCHANGE: str = r"https://api.exchange.coinbase.com"
ACCOUNT = namedtuple(
"ACCOUNT",
(
"id",
"currency",
"balance",
"hold",
"available",
"profile_id",
"trading_enabled",
)
)
PRODUCT = namedtuple(
"PRODUCT",
(
"id",
"base_currency",
"quote_currency",
"base_min_size",
"base_max_size",
"quote_increment",
"base_increment",
"display_name",
"min_market_funds",
"max_market_funds",
"margin_enabled",
"fx_stablecoin",
"max_slippage_percentage",
"post_only",
"limit_only",
"cancel_only",
"trading_disabled",
"status",
"status_message",
"auction_mode",
)
)
def encode(value: Union[str, bytes]) -> bytes:
"""Encode the string `value` with UTF-8."""
if isinstance(value, str) is True:
value = value.encode("UTF-8")
return value
def decode(value: Union[bytes, str]) -> str:
"""Decode the bytes-like object `value` with UTF-8."""
if isinstance(value, bytes) is True:
value = value.decode("UTF-8")
return value
def req_time():
"""POSIX timestamp as float. Number of seconds since Unix Epoch in UTC."""
utc = get_utc()
return utc.timestamp()
def get_utc() -> datetime:
"""Construct a datetime object with UTC time zone info."""
return datetime.now(timezone.utc)
class TimeoutHTTPAdapter(HTTPAdapter):
"""Custom HTTP adapter with timeout capability."""
def __init__(self, *args, **kwargs):
self._timeout = kwargs.pop("timeout")
super(TimeoutHTTPAdapter, self).__init__(*args, **kwargs)
def send(self, request, **kwargs):
kwargs.update({"timeout": self._timeout})
return super(TimeoutHTTPAdapter, self).send(request, **kwargs)
class CoinbaseSession(Session):
"""Coinbase Session handle."""
_headers: dict = {
"Accept": "application/json",
"Content-Type": "application/json",
"Accept-Charset": "utf-8",
}
#staticmethod
def http_adapter(retries: int = 3, backoff: int = 1, timeout: int = 30):
return TimeoutHTTPAdapter(
max_retries=Retry(total=retries, backoff_factor=backoff),
timeout=timeout
)
def __init__(self):
super(CoinbaseSession, self).__init__()
self.headers.update(self._headers)
self.auth = HMACAuth()
self.mount("https://", self.http_adapter())
self.mount("http://", self.http_adapter())
class HMACAuth(AuthBase):
"""Requests signing handle."""
#staticmethod
def __pre_hash(timestamp: Union[str, int, float], request: PreparedRequest) -> bytes:
"""
Create the pre-hash string by concatenating the timestamp with
the request method, path_url and body if exists.
"""
message = f"{timestamp}{request.method.upper()}{request.path_url}"
body = request.body
if body is not None:
message = f"{message}{decode(body)}"
return encode(message)
#staticmethod
def __sign(message: bytes) -> bytes:
"""Create a sha256 HMAC and sign the required `message`."""
key = b64decode(encode(API.SECRET)) # be careful were you keep this!
hmac = HMAC(key=key, msg=message, digestmod=sha256).digest()
return b64encode(hmac)
def __call__(self, request: PreparedRequest):
timestamp = req_time()
message = self.__pre_hash(timestamp, request)
cb_access_sign = self.__sign(message)
request.headers.update(
{
to_native_string('CB-ACCESS-KEY'): API.KEY, # be careful where you keep this!
to_native_string('CB-ACCESS-SIGN'): cb_access_sign,
to_native_string('CB-ACCESS-TIMESTAMP'): str(timestamp),
to_native_string('CB-ACCESS-PASSPHRASE'): API.PASSPHRASE, # be careful where you keep this!
}
)
return request
class CoinbaseAPI(object):
"""Coinbase API handle."""
def __init__(self):
self._session = CoinbaseSession()
def request(self, **kwargs):
"""
Send HTTP requests to the Coinbase API.
Raises HTTPError if response is not 200.
"""
print(f"DEBUG: Requesting resource (url = {kwargs.get('url')})")
try:
results = self.__request(**kwargs)
except HTTPError as http_error:
print(f"ERROR: Resource not found!", f"Cause: {http_error}")
else:
print(f"DEBUG: Resource found (url = {kwargs.get('url')})")
return results
def __request(self, **kwargs):
"""
Send HTTP requests to the Coinbase API.
Raises HTTPError if response is not 200.
"""
method = getattr(self._session, kwargs.pop("method"))
response = method(**kwargs)
if response.status_code != 200:
response.raise_for_status()
else:
try:
results = loads(response.text)
except JSONDecodeError as json_error:
print("WARNING: Decoding JSON object failed!", f"Cause: {json_error}")
kwargs.update({"method": method.__name__})
return self.__request(**kwargs)
else:
return results
class Endpoints(object):
"""Coinbase server endpoints."""
_server = None
#staticmethod
def join(*path, **params) -> str:
"""
Construct the resource url by appending all `path`
items to base url and join `params` if any.
"""
url = "/".join(path)
if len(params) > 0:
params = [f"{key}={value}" for key, value in params.items()]
url = f"{url}?{'&'.join(params)}"
return url
class ExchangeEndpoints(Endpoints):
"""Coinbase exchange server endpoints."""
_server = EXCHANGE
def __init__(self):
self.products = self.join(self._server, "products")
self.accounts = self.join(self._server, "accounts")
class Exchange(CoinbaseAPI):
"""Coinbase exchange API client handle."""
_endpoints = ExchangeEndpoints()
def __init__(self):
super(Exchange, self).__init__()
def get_accounts(self) -> Generator:
"""Get a list of trading accounts from the profile of the API key."""
response = self._accounts()
for item in response:
yield ACCOUNT(**item)
def get_account(self, account_id: str) -> ACCOUNT:
"""
Information for a single account. Use this endpoint when you know the account_id.
API key must belong to the same profile as the account.
"""
response = self._accounts(account_id)
return ACCOUNT(**response)
def get_products(self, query: str = None) -> Generator:
"""
Gets a list of available currency pairs for trading.
:param query: `type` query parameter (unknown).
"""
response = self._products(type=query)
for item in response:
yield PRODUCT(**item)
def get_product(self, product_id: str) -> PRODUCT:
"""
Get information on a single product.
:param product_id: The `id` string of the product/currency pair (ex: BTC-USD).
"""
result = self._products(product_id)
return PRODUCT(**result)
def _accounts(self, *path, **params):
"""Access resources from the `accounts` endpoint of the exchange API."""
params = self.clean_params(**params)
_url = self._endpoints.join(
self._endpoints.accounts,
*path,
**params
)
return self.request(method="get", url=_url)
def _products(self, *path, **params):
"""Access resources from the `products` endpoint of the exchange API."""
params = self.clean_params(**params)
_url = self._endpoints.join(
self._endpoints.products,
*path,
**params
)
return self.request(method="get", url=_url)
def clean_params(self, **params) -> dict:
"""Clean `params` by removing None values."""
temp = dict()
for key, value in params.items():
if value is None:
continue
if isinstance(value, dict) is True:
value = self.clean_params(**value)
temp.update({key: value})
return temp
if __name__ == '__main__':
# we're using PBKDF2HMAC (with symmetrically derived encryption key)
# not included in this example
key_vault = KeyVault() # custom class for encrypting and storing secrets to keyring
key_vault.cypher.password(value="your_password", salt="salty_password")
api_credentials = loads(
key_vault.get_password("coinbase", "pro-coinbase-api")
)
# accounts:
ADA: str = "your_crypto_account_id" # ex: 8b9806a4-7395-11ec-9b1a-f02f74d9105d
API = SimpleNamespace(
NAME="pro-coinbase-api",
VERSION="2021-08-27",
KEY=api_credentials.get("key"),
PASSPHRASE=api_credentials.get("passphrase"),
SECRET=api_credentials.get("secret"),
) # not the best example but does the job as long as you don't invite hackers in your PC :)
api = Exchange()
accounts = api.get_accounts()
for account in accounts:
print(account)
account = api.get_account(account_id=ADA)
print(account)
products = api.get_products()
for product in products:
print(product)
product = api.get_product("ATOM-EUR")
print(product)

convert synchronous websocket callbacks to async in python using asyncio

need you assistance to convert asynchronous websocket callbacks to async using asyncio in python 3.6.
Some brokers in India provide a websocket for tick data to their users and users can further use this tick data, one such brokers sample code(from their git repo) is pasted here
`from smartapi import SmartWebSocket
# feed_token=092017047
FEED_TOKEN="YOUR_FEED_TOKEN"
CLIENT_CODE="YOUR_CLIENT_CODE"
# token="mcx_fo|224395"
token="EXCHANGE|TOKEN_SYMBOL" #SAMPLE: nse_cm|2885&nse_cm|1594&nse_cm|11536&nse_cm|3045
# token="mcx_fo|226745&mcx_fo|220822&mcx_fo|227182&mcx_fo|221599"
task="mw" # mw|sfi|dp
ss = SmartWebSocket(FEED_TOKEN, CLIENT_CODE)
def on_message(ws, message):
print("Ticks: {}".format(message))
def on_open(ws):
print("on open")
ss.subscribe(task,token)
def on_error(ws, error):
print(error)
def on_close(ws):
print("Close")
# Assign the callbacks.
ss._on_open = on_open
ss._on_message = on_message
ss._on_error = on_error
ss._on_close = on_close
ss.connect()
so what I think the broker is doing is that it provides messages/ticks aka events that probably trigger a callback function on every received messages/tick in a synchronous way.
I however want to alter the above code so that it can work in an async manner, I have tried to write in a manner how binance provides such as async with websocketName('currencyPair')as ts: but this doesn't seem to work with my broker.
Appreciate if you can share some ideas/code/insights around this.
Thank you for your time.
Edit 1:
Thank you for your response Dirn,
Yes, I've tried this
from smartapi import SmartConnect
from smartapi import SmartWebSocket
import logging
import asyncio
CLIENT_CODE = Code_provided_by_broker
PASSWORD = My_Password
obj=SmartConnect(api_key=My_api_key)
data = obj.generateSession(CLIENT_CODE, PASSWORD)
refreshToken = data['data']['refreshToken']
print("refreshToken", refreshToken)
feedToken=obj.getfeedToken()
print("feedToken", feedToken)
FEED_TOKEN= feedToken
token="nse_cm|3499"
# token="nse_fo|53595"
task = 'mw'
async def on_message(ws, message):
print("Ticks: {}".format(message))
async def chk_msg():
# ss = SmartWebSocket(FEED_TOKEN, CLIENT_CODE)
# ss.subscribe(task,token)
# ss._on_message = on_message
# ss.connect()
# print("Ticks: {}".format(message))
async with SmartWebSocket(FEED_TOKEN, CLIENT_CODE) as ss: #throws an error: AttributeError: __aexit__
ss.subscribe(task,token)
ss.connect()
while True:
ss._on_message = on_message # not sure how to await a non-async method _on_message, also there is no use of async module in the modules(imported at start of the script) SmartWebSocket and SmartConnect
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(chk_msg())
here is the smartConnect.py
from six.moves.urllib.parse import urljoin
import sys
import csv
import json
import dateutil.parser
import hashlib
import logging
import datetime
import smartapi.smartExceptions as ex
import requests
from requests import get
import re, uuid
import socket
import platform
from smartapi.version import __version__, __title__
log = logging.getLogger(__name__)
#user_sys=platform.system()
#print("the system",user_sys)
class SmartConnect(object):
#_rootUrl = "https://openapisuat.angelbroking.com"
_rootUrl="https://apiconnect.angelbroking.com" #prod endpoint
#_login_url ="https://smartapi.angelbroking.com/login"
_login_url="https://smartapi.angelbroking.com/publisher-login" #prod endpoint
_default_timeout = 7 # In seconds
_routes = {
"api.login":"/rest/auth/angelbroking/user/v1/loginByPassword",
"api.logout":"/rest/secure/angelbroking/user/v1/logout",
"api.token": "/rest/auth/angelbroking/jwt/v1/generateTokens",
"api.refresh": "/rest/auth/angelbroking/jwt/v1/generateTokens",
"api.user.profile": "/rest/secure/angelbroking/user/v1/getProfile",
"api.order.place": "/rest/secure/angelbroking/order/v1/placeOrder",
"api.order.modify": "/rest/secure/angelbroking/order/v1/modifyOrder",
"api.order.cancel": "/rest/secure/angelbroking/order/v1/cancelOrder",
"api.order.book":"/rest/secure/angelbroking/order/v1/getOrderBook",
"api.ltp.data": "/rest/secure/angelbroking/order/v1/getLtpData",
"api.trade.book": "/rest/secure/angelbroking/order/v1/getTradeBook",
"api.rms.limit": "/rest/secure/angelbroking/user/v1/getRMS",
"api.holding": "/rest/secure/angelbroking/portfolio/v1/getHolding",
"api.position": "/rest/secure/angelbroking/order/v1/getPosition",
"api.convert.position": "/rest/secure/angelbroking/order/v1/convertPosition",
"api.gtt.create":"/gtt-service/rest/secure/angelbroking/gtt/v1/createRule",
"api.gtt.modify":"/gtt-service/rest/secure/angelbroking/gtt/v1/modifyRule",
"api.gtt.cancel":"/gtt-service/rest/secure/angelbroking/gtt/v1/cancelRule",
"api.gtt.details":"/rest/secure/angelbroking/gtt/v1/ruleDetails",
"api.gtt.list":"/rest/secure/angelbroking/gtt/v1/ruleList",
"api.candle.data":"/rest/secure/angelbroking/historical/v1/getCandleData"
}
try:
clientPublicIp= " " + get('https://api.ipify.org').text
if " " in clientPublicIp:
clientPublicIp=clientPublicIp.replace(" ","")
hostname = socket.gethostname()
clientLocalIp=socket.gethostbyname(hostname)
except Exception as e:
print("Exception while retriving IP Address,using local host IP address",e)
finally:
clientPublicIp="106.193.147.98"
clientLocalIp="127.0.0.1"
clientMacAddress=':'.join(re.findall('..', '%012x' % uuid.getnode()))
accept = "application/json"
userType = "USER"
sourceID = "WEB"
def __init__(self, api_key=None, access_token=None, refresh_token=None,feed_token=None, userId=None, root=None, debug=False, timeout=None, proxies=None, pool=None, disable_ssl=False,accept=None,userType=None,sourceID=None,Authorization=None,clientPublicIP=None,clientMacAddress=None,clientLocalIP=None,privateKey=None):
self.debug = debug
self.api_key = api_key
self.session_expiry_hook = None
self.disable_ssl = disable_ssl
self.access_token = access_token
self.refresh_token = refresh_token
self.feed_token = feed_token
self.userId = userId
self.proxies = proxies if proxies else {}
self.root = root or self._rootUrl
self.timeout = timeout or self._default_timeout
self.Authorization= None
self.clientLocalIP=self.clientLocalIp
self.clientPublicIP=self.clientPublicIp
self.clientMacAddress=self.clientMacAddress
self.privateKey=api_key
self.accept=self.accept
self.userType=self.userType
self.sourceID=self.sourceID
if pool:
self.reqsession = requests.Session()
reqadapter = requests.adapters.HTTPAdapter(**pool)
self.reqsession.mount("https://", reqadapter)
print("in pool")
else:
self.reqsession = requests
# disable requests SSL warning
requests.packages.urllib3.disable_warnings()
def requestHeaders(self):
return{
"Content-type":self.accept,
"X-ClientLocalIP": self.clientLocalIp,
"X-ClientPublicIP": self.clientPublicIp,
"X-MACAddress": self.clientMacAddress,
"Accept": self.accept,
"X-PrivateKey": self.privateKey,
"X-UserType": self.userType,
"X-SourceID": self.sourceID
}
def setSessionExpiryHook(self, method):
if not callable(method):
raise TypeError("Invalid input type. Only functions are accepted.")
self.session_expiry_hook = method
def getUserId():
return userId
def setUserId(self,id):
self.userId=id
def setAccessToken(self, access_token):
self.access_token = access_token
def setRefreshToken(self, refresh_token):
self.refresh_token = refresh_token
def setFeedToken(self,feedToken):
self.feed_token=feedToken
def getfeedToken(self):
return self.feed_token
def login_url(self):
"""Get the remote login url to which a user should be redirected to initiate the login flow."""
return "%s?api_key=%s" % (self._login_url, self.api_key)
def _request(self, route, method, parameters=None):
"""Make an HTTP request."""
params = parameters.copy() if parameters else {}
uri =self._routes[route].format(**params)
url = urljoin(self.root, uri)
# Custom headers
headers = self.requestHeaders()
if self.access_token:
# set authorization header
auth_header = self.access_token
headers["Authorization"] = "Bearer {}".format(auth_header)
if self.debug:
log.debug("Request: {method} {url} {params} {headers}".format(method=method, url=url, params=params, headers=headers))
try:
r = requests.request(method,
url,
data=json.dumps(params) if method in ["POST", "PUT"] else None,
params=json.dumps(params) if method in ["GET", "DELETE"] else None,
headers=headers,
verify=not self.disable_ssl,
allow_redirects=True,
timeout=self.timeout,
proxies=self.proxies)
except Exception as e:
raise e
if self.debug:
log.debug("Response: {code} {content}".format(code=r.status_code, content=r.content))
# Validate the content type.
if "json" in headers["Content-type"]:
try:
data = json.loads(r.content.decode("utf8"))
except ValueError:
raise ex.DataException("Couldn't parse the JSON response received from the server: {content}".format(
content=r.content))
# api error
if data.get("error_type"):
# Call session hook if its registered and TokenException is raised
if self.session_expiry_hook and r.status_code == 403 and data["error_type"] == "TokenException":
self.session_expiry_hook()
# native errors
exp = getattr(ex, data["error_type"], ex.GeneralException)
raise exp(data["message"], code=r.status_code)
return data
elif "csv" in headers["Content-type"]:
return r.content
else:
raise ex.DataException("Unknown Content-type ({content_type}) with response: ({content})".format(
content_type=headers["Content-type"],
content=r.content))
def _deleteRequest(self, route, params=None):
"""Alias for sending a DELETE request."""
return self._request(route, "DELETE", params)
def _putRequest(self, route, params=None):
"""Alias for sending a PUT request."""
return self._request(route, "PUT", params)
def _postRequest(self, route, params=None):
"""Alias for sending a POST request."""
return self._request(route, "POST", params)
def _getRequest(self, route, params=None):
"""Alias for sending a GET request."""
return self._request(route, "GET", params)
def generateSession(self,clientCode,password):
params={"clientcode":clientCode,"password":password}
loginResultObject=self._postRequest("api.login",params)
if loginResultObject['status']==True:
jwtToken=loginResultObject['data']['jwtToken']
self.setAccessToken(jwtToken)
refreshToken=loginResultObject['data']['refreshToken']
feedToken=loginResultObject['data']['feedToken']
self.setRefreshToken(refreshToken)
self.setFeedToken(feedToken)
user=self.getProfile(refreshToken)
id=user['data']['clientcode']
#id='D88311'
self.setUserId(id)
user['data']['jwtToken']="Bearer "+jwtToken
user['data']['refreshToken']=refreshToken
return user
else:
return loginResultObject
def terminateSession(self,clientCode):
logoutResponseObject=self._postRequest("api.logout",{"clientcode":clientCode})
return logoutResponseObject
def generateToken(self,refresh_token):
response=self._postRequest('api.token',{"refreshToken":refresh_token})
jwtToken=response['data']['jwtToken']
feedToken=response['data']['feedToken']
self.setFeedToken(feedToken)
self.setAccessToken(jwtToken)
return response
def renewAccessToken(self):
response =self._postRequest('api.refresh', {
"jwtToken": self.access_token,
"refreshToken": self.refresh_token,
})
tokenSet={}
if "jwtToken" in response:
tokenSet['jwtToken']=response['data']['jwtToken']
tokenSet['clientcode']=self. userId
tokenSet['refreshToken']=response['data']["refreshToken"]
return tokenSet
def getProfile(self,refreshToken):
user=self._getRequest("api.user.profile",{"refreshToken":refreshToken})
return user
def placeOrder(self,orderparams):
params=orderparams
for k in list(params.keys()):
if params[k] is None :
del(params[k])
orderResponse= self._postRequest("api.order.place", params)['data']['orderid']
return orderResponse
def modifyOrder(self,orderparams):
params = orderparams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
orderResponse= self._postRequest("api.order.modify", params)
return orderResponse
def cancelOrder(self, order_id,variety):
orderResponse= self._postRequest("api.order.cancel", {"variety": variety,"orderid": order_id})
return orderResponse
def ltpData(self,exchange,tradingsymbol,symboltoken):
params={
"exchange":exchange,
"tradingsymbol":tradingsymbol,
"symboltoken":symboltoken
}
ltpDataResponse= self._postRequest("api.ltp.data",params)
return ltpDataResponse
def orderBook(self):
orderBookResponse=self._getRequest("api.order.book")
return orderBookResponse
def tradeBook(self):
tradeBookResponse=self._getRequest("api.trade.book")
return tradeBookResponse
def rmsLimit(self):
rmsLimitResponse= self._getRequest("api.rms.limit")
return rmsLimitResponse
def position(self):
positionResponse= self._getRequest("api.position")
return positionResponse
def holding(self):
holdingResponse= self._getRequest("api.holding")
return holdingResponse
def convertPosition(self,positionParams):
params=positionParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
convertPositionResponse= self._postRequest("api.convert.position",params)
return convertPositionResponse
def gttCreateRule(self,createRuleParams):
params=createRuleParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
createGttRuleResponse=self._postRequest("api.gtt.create",params)
#print(createGttRuleResponse)
return createGttRuleResponse['data']['id']
def gttModifyRule(self,modifyRuleParams):
params=modifyRuleParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
modifyGttRuleResponse=self._postRequest("api.gtt.modify",params)
#print(modifyGttRuleResponse)
return modifyGttRuleResponse['data']['id']
def gttCancelRule(self,gttCancelParams):
params=gttCancelParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
#print(params)
cancelGttRuleResponse=self._postRequest("api.gtt.cancel",params)
#print(cancelGttRuleResponse)
return cancelGttRuleResponse
def gttDetails(self,id):
params={
"id":id
}
gttDetailsResponse=self._postRequest("api.gtt.details",params)
return gttDetailsResponse
def gttLists(self,status,page,count):
if type(status)== list:
params={
"status":status,
"page":page,
"count":count
}
gttListResponse=self._postRequest("api.gtt.list",params)
#print(gttListResponse)
return gttListResponse
else:
message="The status param is entered as" +str(type(status))+". Please enter status param as a list i.e., status=['CANCELLED']"
return message
def getCandleData(self,historicDataParams):
params=historicDataParams
for k in list(params.keys()):
if params[k] is None:
del(params[k])
getCandleDataResponse=self._postRequest("api.candle.data",historicDataParams)
return getCandleDataResponse
def _user_agent(self):
return (__title__ + "-python/").capitalize() + __version__
the smartApiWebsocket.py
import websocket
import six
import base64
import zlib
import datetime
import time
import json
import threading
import ssl
class SmartWebSocket(object):
ROOT_URI='wss://wsfeeds.angelbroking.com/NestHtml5Mobile/socket/stream'
HB_INTERVAL=30
HB_THREAD_FLAG=False
WS_RECONNECT_FLAG=False
feed_token=None
client_code=None
ws=None
task_dict = {}
def __init__(self, FEED_TOKEN, CLIENT_CODE):
self.root = self.ROOT_URI
self.feed_token = FEED_TOKEN
self.client_code = CLIENT_CODE
if self.client_code == None or self.feed_token == None:
return "client_code or feed_token or task is missing"
def _subscribe_on_open(self):
request = {"task": "cn", "channel": "NONLM", "token": self.feed_token, "user": self.client_code,
"acctid": self.client_code}
print(request)
self.ws.send(
six.b(json.dumps(request))
)
thread = threading.Thread(target=self.run, args=())
thread.daemon = True
thread.start()
def run(self):
while True:
# More statements comes here
if self.HB_THREAD_FLAG:
break
print(datetime.datetime.now().__str__() + ' : Start task in the background')
self.heartBeat()
time.sleep(self.HB_INTERVAL)
def subscribe(self, task, token):
# print(self.task_dict)
self.task_dict.update([(task,token),])
# print(self.task_dict)
if task in ("mw", "sfi", "dp"):
strwatchlistscrips = token # dynamic call
try:
request = {"task": task, "channel": strwatchlistscrips, "token": self.feed_token,
"user": self.client_code, "acctid": self.client_code}
self.ws.send(
six.b(json.dumps(request))
)
return True
except Exception as e:
self._close(reason="Error while request sending: {}".format(str(e)))
raise
else:
print("The task entered is invalid, Please enter correct task(mw,sfi,dp) ")
def resubscribe(self):
for task, marketwatch in self.task_dict.items():
print(task, '->', marketwatch)
try:
request = {"task": task, "channel": marketwatch, "token": self.feed_token,
"user": self.client_code, "acctid": self.client_code}
self.ws.send(
six.b(json.dumps(request))
)
return True
except Exception as e:
self._close(reason="Error while request sending: {}".format(str(e)))
raise
def heartBeat(self):
try:
request = {"task": "hb", "channel": "", "token": self.feed_token, "user": self.client_code,
"acctid": self.client_code}
print(request)
self.ws.send(
six.b(json.dumps(request))
)
except:
print("HeartBeat Sending Failed")
# time.sleep(60)
def _parse_text_message(self, message):
"""Parse text message."""
data = base64.b64decode(message)
try:
data = bytes((zlib.decompress(data)).decode("utf-8"), 'utf-8')
data = json.loads(data.decode('utf8').replace("'", '"'))
data = json.loads(json.dumps(data, indent=4, sort_keys=True))
except ValueError:
return
# return data
if data:
self._on_message(self.ws,data)
def connect(self):
# websocket.enableTrace(True)
self.ws = websocket.WebSocketApp(self.ROOT_URI,
on_message=self.__on_message,
on_close=self.__on_close,
on_open=self.__on_open,
on_error=self.__on_error)
self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def __on_message(self, ws, message):
self._parse_text_message(message)
# print(msg)
def __on_open(self, ws):
print("__on_open################")
self.HB_THREAD_FLAG = False
self._subscribe_on_open()
if self.WS_RECONNECT_FLAG:
self.WS_RECONNECT_FLAG = False
self.resubscribe()
else:
self._on_open(ws)
def __on_close(self, ws):
self.HB_THREAD_FLAG = True
print("__on_close################")
self._on_close(ws)
def __on_error(self, ws, error):
if ( "timed" in str(error) ) or ( "Connection is already closed" in str(error) ) or ( "Connection to remote host was lost" in str(error) ):
self.WS_RECONNECT_FLAG = True
self.HB_THREAD_FLAG = True
if (ws is not None):
ws.close()
ws.on_message = None
ws.on_open = None
ws.close = None
# print (' deleting ws')
del ws
self.connect()
else:
print ('Error info: %s' %(error))
self._on_error(ws, error)
def _on_message(self, ws, message):
pass
def _on_open(self, ws):
pass
def _on_close(self, ws):
pass
def _on_error(self, ws, error):
pass
the provider has also imported the websocket module in the above smartApiWebsocket file and has also provided a custom webSocket.py
here is the link webSocket.py
enter link description here

requests.session() object not recognized in another class

I am trying to pass my session object from one class to another. But I am not sure whats happening.
class CreateSession:
def __init__(self, user, pwd, url="http://url_to_hit"):
self.url = url
self.user = user
self.pwd = pwd
def get_session(self):
sess = requests.Session()
r = sess.get(self.url + "/", auth=(self.user, self.pwd))
print(r.content)
return sess
class TestGet(CreateSession):
def get_response(self):
s = self.get_session()
print(s)
data = s.get(self.url + '/some-get')
print(data.status_code)
print(data)
if __name__ == "__main__":
TestGet(user='user', pwd='pwd').get_response()
I am getting 401 for get_response(). Not able to understand this.
What's a 401?
The response you're getting means that you're unauthorised to access the resource.
A session is used in order to persist headers and other prerequisites throughout requests, why are you creating the session every time rather than storing it in a variable?
As is, the session should work the only issue is that you're trying to call a resource that you don't have access to. - You're not passing the url parameter either in the initialisation.
Example of how you can effectively use Session:
from requests import Session
from requests.exceptions import HTTPError
class TestGet:
__session = None
__username = None
__password = None
def __init__(self, username, password):
self.__username = username
self.__password = password
#property
def session(self):
if self.__session is None:
self.__session = Session()
self.__session.auth = (self.__user, self.__pwd)
return self.__session
#session.setter
def session(self, value):
raise AttributeError('Setting \'session\' attribute is prohibited.')
def get_response(self, url):
try:
response = self.session.get(url)
# raises if the status code is an error - 4xx, 5xx
response.raise_for_status()
return response
except HTTPError as e:
# you received an http error .. handle it here (e contains the request and response)
pass
test_get = TestGet('my_user', 'my_pass')
first_response = test_get.get_response('http://your-website-with-basic-auth.com')
second_response = test_get.get_response('http://another-url.com')
my_session = test_get.session
my_session.get('http://url.com')

flask http-auth and unittesting

Hi!
I have a route that I have protected using HTTP Basic authentication, which is implemented by Flask-HTTPAuth. Everything works fine (i can access the route) if i use curl, but when unit testing, the route can't be accessed, even though i provide it with the right username and password.
Here are the relevant code snippets in my testing module:
class TestClient(object):
def __init__(self, app):
self.client = app.test_client()
def send(self, url, method, data=None, headers={}):
if data:
data = json.dumps(data)
rv = method(url, data=data, headers=headers)
return rv, json.loads(rv.data.decode('utf-8'))
def delete(self, url, headers={}):
return self.send(url, self.client.delete, headers)
class TestCase(unittest.TestCase):
def setUp(self):
app.config.from_object('test_config')
self.app = app
self.app_context = self.app.app_context()
self.app_context.push()
db.create_all()
self.client = TestClient(self.app)
def test_delete_user(self):
# create new user
data = {'username': 'john', 'password': 'doe'}
self.client.post('/users', data=data)
# delete previously created user
headers = {}
headers['Authorization'] = 'Basic ' + b64encode((data['username'] + ':' + data['password'])
.encode('utf-8')).decode('utf-8')
headers['Content-Type'] = 'application/json'
headers['Accept'] = 'application/json'
rv, json = self.client.delete('/users', headers=headers)
self.assertTrue(rv.status_code == 200) # Returns 401 instead
Here are the callback methods required by Flask-HTTPAuth:
auth = HTTPBasicAuth()
#auth.verify_password
def verify_password(username, password):
# THIS METHOD NEVER GETS CALLED
user = User.query.filter_by(username=username).first()
if not user or not user.verify_password(password):
return False
g.user = user
return True
#auth.error_handler
def unauthorized():
response = jsonify({'status': 401, 'error': 'unauthorized', 'message': 'Please authenticate to access this API.'})
response.status_code = 401
return response
Any my route:
#app.route('/users', methods=['DELETE'])
#auth.login_required
def delete_user():
db.session.delete(g.user)
db.session.commit()
return jsonify({})
The unit test throws the following exception:
Traceback (most recent call last):
File "test_api.py", line 89, in test_delete_user
self.assertTrue(rv.status_code == 200) # Returns 401 instead
AssertionError: False is not true
I want to emphazise once more that everything works fine when i run curl with exactly the same arguments i provide for my test client, but when i run the test, verify_password method doesn't even get called.
Thank you very much for your help!
Here is an example how this could be done with pytest and the inbuilt monkeypatch fixture.
If I have this API function in some_flask_app:
from flask_httpauth import HTTPBasicAuth
app = Flask(__name__)
auth = HTTPBasicAuth()
#app.route('/api/v1/version')
#auth.login_required
def api_get_version():
return jsonify({'version': get_version()})
I can create a fixture that returns a flask test client and patches the authenticate function in HTTPBasicAuth to always return True:
import pytest
from some_flask_app import app, auth
#pytest.fixture(name='client')
def initialize_authorized_test_client(monkeypatch):
app.testing = True
client = app.test_client()
monkeypatch.setattr(auth, 'authenticate', lambda x, y: True)
yield client
app.testing = False
def test_settings_tracking(client):
r = client.get("/api/v1/version")
assert r.status_code == 200
You are going to love this.
Your send method:
def send(self, url, method, data=None, headers={}):
pass
Your delete method:
def delete(self, url, headers={}):
return self.send(url, self.client.delete, headers)
Note you are passing headers as third positional argument, so it's going as data into send().

Categories