I'm trying to write some async tests in FastAPI using Tortoise ORM under Python 3.8, but I keep getting the same errors (shown at the end). I've been trying to figure this out for the past few days, but all my attempts at writing these tests have been unsuccessful.
I'm following the FastAPI docs and the Tortoise docs on this one.
main.py
# UserPy is a pydantic model
@app.post('/testpost')
async def world(user: UserPy) -> UserPy:
    await User.create(**user.dict())
    # Just returns the user model
    return user
simple_test.py
import json

import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
@pytest.fixture
def client1():
    with TestClient(app) as tc:
        yield tc

@pytest.fixture
def client2():
    initializer(DATABASE_MODELS, DATABASE_URL)
    with TestClient(app) as tc:
        yield tc
    finalizer()

@pytest.fixture
def event_loop(client2):  # Been using client1 and client2 on this
    yield client2.task.get_loop()

# The test
@pytest.mark.asyncio
def test_testpost(client2, event_loop):
    name, age = ['sam', 99]
    data = json.dumps(dict(username=name, age=age))

    res = client2.post('/testpost', data=data)
    assert res.status_code == 200

    # Sample query
    async def getx(id):
        return await User.get(pk=id)

    x = event_loop.run_until_complete(getx(123))
    assert x.id == 123
# end of code
The error varies depending on whether I'm using client1 or client2.
Using client1, the error is:
RuntimeError: Task <Task pending name='Task-9' coro=<TestClient.wait_shutdown() running at <my virtualenv path>/site-packages/starlette/testclient.py:487> cb=[_run_until_complete_cb() at /usr/lib/python3.8/asyncio/base_events.py:184]> got Future <Future pending> attached to a different loop
Using client2, the error is:
asyncpg.exceptions.ObjectInUseError: cannot drop the currently open database
Oh, I've also tried using httpx.AsyncClient, but still no success (and more errors). Any ideas? I'm out of my own.
It took me about an hour to get the async tests working. Here is the example:
(Python 3.8+ is required)
conftest.py
import pytest
from httpx import AsyncClient
from tortoise import Tortoise

from main import app

DB_URL = "sqlite://:memory:"


async def init_db(db_url, create_db: bool = False, schemas: bool = False) -> None:
    """Initial database connection"""
    await Tortoise.init(
        db_url=db_url, modules={"models": ["models"]}, _create_db=create_db
    )
    if create_db:
        print(f"Database created! {db_url = }")
    if schemas:
        await Tortoise.generate_schemas()
        print("Schemas generated successfully")


async def init(db_url: str = DB_URL):
    await init_db(db_url, True, True)


@pytest.fixture(scope="session")
def anyio_backend():
    return "asyncio"


@pytest.fixture(scope="session")
async def client():
    async with AsyncClient(app=app, base_url="http://test") as client:
        print("Client is ready")
        yield client


@pytest.fixture(scope="session", autouse=True)
async def initialize_tests():
    await init()
    yield
    await Tortoise._drop_databases()
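(A side note that is not part of the original answer: the client fixture above relies on httpx's AsyncClient(app=...) shortcut. Newer httpx releases deprecate that shortcut in favour of an explicit ASGITransport, so on those versions the fixture may need to look roughly like this sketch.)

import pytest
from httpx import ASGITransport, AsyncClient

from main import app


@pytest.fixture(scope="session")
async def client():
    # Same session-scoped client, with the ASGI transport spelled out explicitly
    # for httpx versions where AsyncClient(app=...) is no longer accepted.
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        yield client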
settings.py
import os

from dotenv import load_dotenv

load_dotenv()

DB_NAME = "async_test"
DB_URL = os.getenv(
    "APP_DB_URL", f"postgres://postgres:postgres@127.0.0.1:5432/{DB_NAME}"
)
ALLOW_ORIGINS = [
    "http://localhost",
    "http://localhost:8080",
    "http://localhost:8000",
    "https://example.com",
]
main.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from models.users import User, User_Pydantic, User_Pydantic_List, UserIn_Pydantic
from settings import ALLOW_ORIGINS, DB_URL
from tortoise.contrib.fastapi import register_tortoise

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/testpost", response_model=User_Pydantic)
async def world(user: UserIn_Pydantic):
    return await User.create(**user.dict())


@app.get("/users", response_model=User_Pydantic_List)
async def user_list():
    return await User.all()


register_tortoise(
    app,
    config={
        "connections": {"default": DB_URL},
        "apps": {"models": {"models": ["models"]}},
        "use_tz": True,
        "timezone": "Asia/Shanghai",
        "generate_schemas": True,
    },
)
models/base.py
from typing import List, Set, Tuple, Union

from tortoise import fields, models
from tortoise.queryset import Q, QuerySet


def reduce_query_filters(args: Tuple[Q, ...]) -> Set:
    # Collect the names of every field referenced by the given Q objects,
    # walking nested children level by level.
    fields = set()
    for q in args:
        fields |= set(q.filters)
        c: Union[List[Q], Tuple[Q, ...]] = q.children
        while c:
            _c: List[Q] = []
            for i in c:
                fields |= set(i.filters)
                _c += list(i.children)
            c = _c
    return fields


class AbsModel(models.Model):
    id = fields.IntField(pk=True)
    created_at = fields.DatetimeField(auto_now_add=True, description="Created At")
    updated_at = fields.DatetimeField(auto_now=True, description="Updated At")
    is_deleted = fields.BooleanField(default=False, description="Mark as Deleted")

    class Meta:
        abstract = True
        ordering = ("-id",)

    @classmethod
    def filter(cls, *args, **kwargs) -> QuerySet:
        field = "is_deleted"
        if not args or (field not in reduce_query_filters(args)):
            kwargs.setdefault(field, False)
        return super().filter(*args, **kwargs)

    class PydanticMeta:
        exclude = ("created_at", "updated_at", "is_deleted")

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.id}>"
models/users.py
from tortoise.contrib.pydantic import pydantic_model_creator, pydantic_queryset_creator

from .base import AbsModel, fields


class User(AbsModel):
    username = fields.CharField(60)
    age = fields.IntField()

    class Meta:
        table = "users"

    def __str__(self):
        return self.username


User_Pydantic = pydantic_model_creator(User)
UserIn_Pydantic = pydantic_model_creator(User, name="UserIn", exclude_readonly=True)
User_Pydantic_List = pydantic_queryset_creator(User)
models/__init__.py
from .users import User # NOQA: F401
tests/test_users.py
import pytest
from httpx import AsyncClient

from models.users import User


@pytest.mark.anyio
async def test_testpost(client: AsyncClient):
    name, age = ["sam", 99]
    assert await User.filter(username=name).count() == 0

    data = {"username": name, "age": age}
    response = await client.post("/testpost", json=data)
    assert response.json() == dict(data, id=1)
    assert response.status_code == 200

    response = await client.get("/users")
    assert response.status_code == 200
    assert response.json() == [dict(data, id=1)]

    assert await User.filter(username=name).count() == 1
The source code of the demo has been posted to GitHub:
https://github.com/waketzheng/fastapi-tortoise-pytest-demo.git
Related
I have updated the Django version for my project from Django 2.2.16 to Django 3.2.14.
With this update, some of my test cases are failing, and I cannot understand the reason for the failure.
My test-case file:
import json
from os import access
from unittest import mock
from unittest.mock import patch
import asyncio

from app.models import UserProfile
from django.test import TestCase, TransactionTestCase
from requests.models import Response
from services.check_status import check_status

loop = asyncio.get_event_loop()


@mock.patch("services.check_status.save_status")
@mock.patch("services.check_status.send_wss_message")
class TestUserStatus(TransactionTestCase):
    def setUp(self):
        super().setUp()
        self.account_id = 4
        self.sim_profile = UserProfile.objects.create()

    def test_check_status_completed(
        self,
        mock_send_wss_message,
        mock_save_status,
    ):
        mock_save_status.return_value = {}
        send_wss_message_future = asyncio.Future()
        send_wss_message_future.set_result(True)
        mock_send_wss_message.return_value = send_wss_message_future

        loop.run_until_complete(
            check_status(
                self.sim_profile.id,
            )
        )
        self.assertTrue(mock_save_status.called)
        self.assertTrue(mock_send_wss_message.called)
My pseudo check_status file is:
import logging

from app.models import UserProfile, UserStatus
from services.constants import WebsocketGroups
from services.user.constants import USER
from app.api.serializers import UserStatusSerializer
from services.utils import send_wss_message, Client

logger = logging.getLogger(__name__)


def save_status(**kwargs):
    status = UserStatus.objects.filter(
        status_id=kwargs.get("status_id")
    ).first()
    data = kwargs
    user_status_serializer = UserStatusSerializer(status, data, partial=True)
    if user_status_serializer.is_valid():
        user_status_serializer.save()


async def check_status(
    profile_id
):
    user_profile = UserProfile.objects.get(id=profile_id)
    login_token = get_login_token(user_profile)
    user_creds = env["user_api"]
    headers = USER["headers"]
    subscription_details = Client.get(
        USER["url"], headers
    )
    transaction_status = subscription_details.json()["Status"]
    subscription_data = subscription_details.json()["Data"][0]
    transaction_status_details = subscription_data["TransactionStatusDetails"]
    error_message = ""
    status = ""
    if transaction_status == "Success":
        # perform some actions and save status...
        message = {
            "type": "user_profile",
            "data": [user_profile.id, transaction_status, {"results": {}}],
        }
        await send_wss_message(
            user_profile.id, message=message, group_name=WebsocketGroups.USER_PROFILE,
        )
    else:
        # perform some actions ...
        pass
When I run my test-case file, it creates the UserProfile object, but when control reaches the check_status function, UserProfile.objects.all() returns <QuerySet []>.
I made a temporary sync function to return a list of all user profiles and called it inside test_check_status_completed, and it returned the list. But async functions called through loop.run_until_complete all returned <QuerySet []>.
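(Not part of the original question, but for context: from Django 3.x onwards, ORM access reached from a coroutine is normally routed through sync_to_async. Below is a minimal sketch of that pattern, reusing the names from the question; it is not presented as the confirmed cause of the empty queryset.)

from asgiref.sync import sync_to_async

from app.models import UserProfile


async def check_status(profile_id):
    # Run the blocking ORM call in a worker thread instead of calling it
    # directly inside the coroutine; the rest of the original function follows.
    user_profile = await sync_to_async(UserProfile.objects.get)(id=profile_id)
    return user_profile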
So, I have created a custom middleware for my big FastAPI application, which alters the responses from all of my endpoints in the following way.
The response model is different for each endpoint. However, my middleware adds metadata to all of these responses in a uniform manner. This is what the final response object looks like:
{
    "data": <ANY RESPONSE MODEL THAT ALL THOSE ENDPOINTS ARE SENDING>,
    "meta_data":
    {
        "meta_data_1": "meta_value_1",
        "meta_data_2": "meta_value_2",
        "meta_data_3": "meta_value_3",
    }
}
So essentially, every original response is wrapped inside a data field, and a new meta_data field is added. The meta_data model is uniform; it will always be of this type:
"meta_data":
{
"meta_data_1": "meta_value_1",
"meta_data_2": "meta_value_2",
"meta_data_3": "meta_value_3",
}
Now the problem is that when Swagger loads, it shows the original response model in the schema and not the final response model that the middleware prepares. How do I alter the Swagger docs to reflect this correctly?
I have tried this:
from pydantic import BaseModel


# This model is common to all endpoints,
# since we are going to add this to all responses
class MetaDataModel(BaseModel):
    meta_data_1: str
    meta_data_2: str
    meta_data_3: str


class FinalResponseForEndPoint1(BaseModel):
    data: OriginalResponseForEndpoint1
    meta_data: MetaDataModel


class FinalResponseForEndPoint2(BaseModel):
    data: OriginalResponseForEndpoint2
    meta_data: MetaDataModel


# and so on ...
This approach does render the Swagger docs correctly, but there are two major problems with it:
1. All my FastAPI endpoints break and raise an error when returning a response. For example, endpoint1 still returns the original response, but FastAPI now expects it to return a response adhering to the FinalResponseForEndPoint1 model.
2. Repeating this for every model on every endpoint does not seem like the right way.
Here is a minimal reproducible example with my custom middleware:
from starlette.types import ASGIApp, Receive, Scope, Send, Message
from starlette.requests import Request
import json
from starlette.datastructures import MutableHeaders
from fastapi import FastAPI


class MetaDataAdderMiddleware:
    application_generic_urls = ['/openapi.json', '/docs', '/docs/oauth2-redirect', '/redoc']

    def __init__(
        self,
        app: ASGIApp
    ) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http" and not any([scope["path"].startswith(endpoint) for endpoint in MetaDataAdderMiddleware.application_generic_urls]):
            responder = MetaDataAdderMiddlewareResponder(self.app)
            await responder(scope, receive, send)
            return
        await self.app(scope, receive, send)


class MetaDataAdderMiddlewareResponder:
    def __init__(
        self,
        app: ASGIApp,
    ) -> None:
        self.app = app
        self.initial_message: Message = {}

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        self.send = send
        await self.app(scope, receive, self.send_with_meta_response)

    async def send_with_meta_response(self, message: Message):
        message_type = message["type"]
        if message_type == "http.response.start":
            # Don't send the initial message until we've determined how to
            # modify the outgoing headers correctly.
            self.initial_message = message
        elif message_type == "http.response.body":
            response_body = json.loads(message["body"].decode())
            data = {}
            data["data"] = response_body
            data['metadata'] = {
                'field_1': 'value_1',
                'field_2': 'value_2'
            }
            data_to_be_sent_to_user = json.dumps(data, default=str).encode("utf-8")
            headers = MutableHeaders(raw=self.initial_message["headers"])
            headers["Content-Length"] = str(len(data_to_be_sent_to_user))
            message["body"] = data_to_be_sent_to_user
            await self.send(self.initial_message)
            await self.send(message)


app = FastAPI(
    title="MY DUMMY APP",
)
app.add_middleware(MetaDataAdderMiddleware)


@app.get("/")
async def root():
    return {"message": "Hello World"}
If you add default values to the additional fields, you can have the middleware update those fields as opposed to creating them.
So:
from starlette.types import ASGIApp, Receive, Scope, Send, Message
from starlette.requests import Request
import json
from starlette.datastructures import MutableHeaders
from fastapi import FastAPI
from pydantic import BaseModel, Field


# This model is common to all endpoints,
# since we are going to add this to all responses
class MetaDataModel(BaseModel):
    meta_data_1: str
    meta_data_2: str
    meta_data_3: str


class ResponseForEndPoint1(BaseModel):
    data: str
    meta_data: MetaDataModel | None = Field(None, nullable=True)


class MetaDataAdderMiddleware:
    application_generic_urls = ['/openapi.json',
                                '/docs', '/docs/oauth2-redirect', '/redoc']

    def __init__(
        self,
        app: ASGIApp
    ) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http" and not any([scope["path"].startswith(endpoint) for endpoint in MetaDataAdderMiddleware.application_generic_urls]):
            responder = MetaDataAdderMiddlewareResponder(self.app)
            await responder(scope, receive, send)
            return
        await self.app(scope, receive, send)


class MetaDataAdderMiddlewareResponder:
    def __init__(
        self,
        app: ASGIApp,
    ) -> None:
        self.app = app
        self.initial_message: Message = {}

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        self.send = send
        await self.app(scope, receive, self.send_with_meta_response)

    async def send_with_meta_response(self, message: Message):
        message_type = message["type"]
        if message_type == "http.response.start":
            # Don't send the initial message until we've determined how to
            # modify the outgoing headers correctly.
            self.initial_message = message
        elif message_type == "http.response.body":
            response_body = json.loads(message["body"].decode())
            response_body['meta_data'] = {
                'field_1': 'value_1',
                'field_2': 'value_2'
            }
            data_to_be_sent_to_user = json.dumps(
                response_body, default=str).encode("utf-8")
            headers = MutableHeaders(raw=self.initial_message["headers"])
            headers["Content-Length"] = str(len(data_to_be_sent_to_user))
            message["body"] = data_to_be_sent_to_user
            await self.send(self.initial_message)
            await self.send(message)


app = FastAPI(
    title="MY DUMMY APP",
)
app.add_middleware(MetaDataAdderMiddleware)


@app.get("/", response_model=ResponseForEndPoint1)
async def root():
    return ResponseForEndPoint1(data='hello world')
I don't think this is a good solution, but it doesn't throw errors and it does show the correct output in Swagger.
In general I'm struggling to find a good way to document the changes/additional responses that middleware can introduce in OpenAPI/Swagger. If you've found anything else I'd be keen to hear it!
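(One way to cut down the per-endpoint duplication, offered as a sketch rather than a confirmed recipe and assuming pydantic v1's generic models; WrappedResponse and ItemOut are made-up names:)

from typing import Generic, Optional, TypeVar

from pydantic import BaseModel
from pydantic.generics import GenericModel

DataT = TypeVar("DataT")


class MetaDataModel(BaseModel):
    meta_data_1: str
    meta_data_2: str
    meta_data_3: str


class WrappedResponse(GenericModel, Generic[DataT]):
    # The middleware fills meta_data in later, so it stays optional here.
    data: DataT
    meta_data: Optional[MetaDataModel] = None


# Hypothetical usage, where ItemOut is one of the original per-endpoint models:
# @app.get("/items/{item_id}", response_model=WrappedResponse[ItemOut])
# async def get_item(item_id: int) -> WrappedResponse[ItemOut]:
#     ...

Under pydantic v2 the same idea works with a plain BaseModel that also subclasses Generic[DataT].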
I have been trying to connect my PSN account to Galaxy 2.0 for a while now, and it keeps telling me that it's offline. I have tried the solutions that come up first in a Google search, and they didn't work for me.
All the solutions I found refer to code that is different (and that I don't have) from what the app installs by default or what I can find on GitHub.
Below is the code I have, which is what the app installed by default.
If you have any solutions or know how to solve this, please help me. Thanks in advance.
This is the log, in case it is needed:
https://www.mediafire.com/file/3b3921wgyq9357m/plugin-psn-38087aea-3c30-439f-867d-ddf9fae8fe6f.log/file
import logging
import sys
from typing import List, Any, AsyncGenerator

from galaxy.api.consts import Platform, LicenseType
from galaxy.api.errors import InvalidCredentials
from galaxy.api.plugin import Plugin, create_and_run_plugin
from galaxy.api.types import Authentication, Game, NextStep, SubscriptionGame, \
    Subscription, LicenseInfo

from http_client import HttpClient
from http_client import OAUTH_LOGIN_URL, OAUTH_LOGIN_REDIRECT_URL
from psn_client import PSNClient
from version import __version__

AUTH_PARAMS = {
    "window_title": "Login to My PlayStation\u2122",
    "window_width": 536,
    "window_height": 675,
    "start_uri": OAUTH_LOGIN_URL,
    "end_uri_regex": "^" + OAUTH_LOGIN_REDIRECT_URL + ".*"
}

logger = logging.getLogger(__name__)


class PSNPlugin(Plugin):
    def __init__(self, reader, writer, token):
        super().__init__(Platform.Psn, __version__, reader, writer, token)
        self._http_client = HttpClient()
        self._psn_client = PSNClient(self._http_client)
        logging.getLogger("urllib3").setLevel(logging.FATAL)

    async def _do_auth(self, cookies):
        if not cookies:
            raise InvalidCredentials()
        self._http_client.set_cookies_updated_callback(self._update_stored_cookies)
        self._http_client.update_cookies(cookies)
        await self._http_client.refresh_cookies()
        user_id, user_name = await self._psn_client.async_get_own_user_info()
        if user_id == "":
            raise InvalidCredentials()
        return Authentication(user_id=user_id, user_name=user_name)

    async def authenticate(self, stored_credentials=None):
        stored_cookies = stored_credentials.get("cookies") if stored_credentials else None
        if not stored_cookies:
            return NextStep("web_session", AUTH_PARAMS)
        auth_info = await self._do_auth(stored_cookies)
        return auth_info

    async def pass_login_credentials(self, step, credentials, cookies):
        cookies = {cookie["name"]: cookie["value"] for cookie in cookies}
        self._store_cookies(cookies)
        return await self._do_auth(cookies)

    def _store_cookies(self, cookies):
        credentials = {
            "cookies": cookies
        }
        self.store_credentials(credentials)

    def _update_stored_cookies(self, morsels):
        cookies = {}
        for morsel in morsels:
            cookies[morsel.key] = morsel.value
        self._store_cookies(cookies)

    async def get_subscriptions(self) -> List[Subscription]:
        is_plus_active = await self._psn_client.get_psplus_status()
        return [Subscription(subscription_name="PlayStation PLUS", end_time=None, owned=is_plus_active)]

    async def get_subscription_games(self, subscription_name: str, context: Any) -> AsyncGenerator[List[SubscriptionGame], None]:
        yield await self._psn_client.get_subscription_games()

    async def get_owned_games(self):
        def game_parser(title):
            return Game(
                game_id=title["titleId"],
                game_title=title["name"],
                dlcs=[],
                license_info=LicenseInfo(LicenseType.SinglePurchase, None)
            )

        def parse_played_games(titles):
            return [{"titleId": title["titleId"], "name": title["name"]} for title in titles]

        purchased_games = await self._psn_client.async_get_purchased_games()
        played_games = parse_played_games(await self._psn_client.async_get_played_games())
        unique_all_games = {game['titleId']: game for game in played_games + purchased_games}.values()
        return [game_parser(game) for game in unique_all_games]

    async def shutdown(self):
        await self._http_client.close()


def main():
    create_and_run_plugin(PSNPlugin, sys.argv)


if __name__ == "__main__":
    main()
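(A side note on the code above, purely illustrative and not from the plugin itself: the dict comprehension in get_owned_games deduplicates by titleId, and later entries win on key collisions, so purchased games override played ones.)

played = [{"titleId": "A", "name": "Game A (played)"}]
purchased = [{"titleId": "A", "name": "Game A (purchased)"}, {"titleId": "B", "name": "Game B"}]

# Later items win on key collisions, so the purchased entry for "A" replaces the played one.
unique = {g["titleId"]: g for g in played + purchased}.values()
print(list(unique))  # [{'titleId': 'A', 'name': 'Game A (purchased)'}, {'titleId': 'B', 'name': 'Game B'}]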
I'm deploying a chat application built with Django Channels. It works on localhost, but in production the socket connects at first, and as soon as I send a message the following error shows up: WebSocket is already in CLOSING or CLOSED state. What could be wrong? Thanks in advance.
My server logs are as follows:
Error logs
File "/home/ubuntu/django/virtualenv2/lib/python3.7/site-packages/django/db/models/fields/__init__.py", line 1774, in get_prep_value
return int(value)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'UserLazyObject'
File "/home/ubuntu/django/virtualenv2/lib/python3.7/site-packages/channels/generic/websocket.py", line 175, in websocket_connect
await self.connect()
File "./chat/consumers.py", line 179, in connect
File "./chat/managers.py", line 24, in by_user
threads = self.get_queryset().filter(thread_type="personal")
My consumers.py file and managers.py file are as follows.
Consumers.py file
import asyncio
import json
import time
from datetime import datetime

import pytz
from asgiref.sync import async_to_sync, sync_to_async
from channels.consumer import SyncConsumer
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from channels.layers import get_channel_layer
from django.contrib.auth import get_user_model
from django.db.models import signals
from django.dispatch import receiver
from django.utils import timezone

from chat.models import Message, Thread, Notification

User = get_user_model()
# ================== Chat consumer starts ==================
class ChatConsumer(SyncConsumer):
    def websocket_connect(self, event):
        my_slug = self.scope['path'][-8:]
        me = User.objects.get(slug=my_slug)
        other_user_slug = self.scope['url_route']['kwargs']['other_slug']
        other_user = User.objects.get(slug=other_user_slug)
        self.thread_obj = Thread.objects.get_or_create_personal_thread(
            me, other_user)
        self.room_name = 'personal_thread_{}'.format(self.thread_obj.id)
        async_to_sync(self.channel_layer.group_add)(
            self.room_name, self.channel_name)
        self.send({
            'type': 'websocket.accept',
        })

    @staticmethod
    @receiver(signals.post_save, sender=Notification)
    def order_offer_observer(sender, instance, **kwargs):
        layer = get_channel_layer()
        thread_id = instance.thread.id
        return thread_id

    def websocket_receive(self, event):
        my_slug = self.scope['path'][-8:]
        me = User.objects.get(slug=my_slug)
        other_user_slug = self.scope['url_route']['kwargs']['other_slug']
        other_user = User.objects.get(slug=other_user_slug)
        text = json.loads(event.get('text'))
        thread_obj = Thread.objects.get_or_create_personal_thread(
            me, other_user)
        obj = Thread.objects.get_or_create_personal_thread(me, other_user)
        recent_threads_id = []
        other_user_threads = Thread.objects.by_user(other_user.id).all()
        for one_thread in other_user_threads:
            if one_thread.message_set.all():
                if (one_thread.users).first().email == other_user.email:
                    y = Notification.objects.get_notification_object(
                        thread=obj, user1=other_user, user2=(one_thread.users).last())
                    y = y.read
                    recent_threads_id.append(one_thread.id)
                else:
                    y = Notification.objects.get_notification_object(
                        thread=obj, user1=other_user, user2=(one_thread.users).first())
                    y = y.read
                    recent_threads_id.append(one_thread.id)
            else:
                pass

        if "message" in text:
            notification = Notification.objects.get_notification_object(
                thread=thread_obj, user1=me, user2=other_user)
            notification.sender = me
            print("notification.sender ", notification.sender)
            current_thread_id = self.order_offer_observer(
                sender=Notification, instance=Notification.objects.get_notification_object(thread=thread_obj, user1=me, user2=other_user))
            obj.updated_at = int(round(time.time() * 1000))
            obj.save()
            fmt = "%I:%M %p"
            utctime = datetime.now()
            utc = utctime.replace(tzinfo=pytz.UTC)
            localtz = utc.astimezone(timezone.get_current_timezone())
            sent_time = localtz.strftime(fmt)
            message_text = text['message']
            msg = json.dumps({
                'text': message_text,
                'user_slug': my_slug,
                'thread_id': current_thread_id,
                'sent_time': sent_time
            })

            # Sending live notifications
            threads_count = len(recent_threads_id)
            if threads_count < 1:
                async_to_sync(self.channel_layer.group_send)(
                    self.room_name,
                    {'type': 'websocket.message',
                     'text': msg})
            else:
                for one_id in range(threads_count):
                    if recent_threads_id[one_id] == current_thread_id:
                        async_to_sync(self.channel_layer.group_send)(
                            self.room_name,
                            {'type': 'websocket.message',
                             'text': msg})
                    else:
                        async_to_sync(self.channel_layer.group_send)(
                            'personal_thread_{}'.format(recent_threads_id[one_id]),
                            {'type': 'websocket.message',
                             'text': repr(current_thread_id)})
            if notification:
                notification.read = False
                notification.save()
            self.store_message(message_text, sent_time)

        elif "chat_read" in text:
            notification = Notification.objects.get_notification_object(
                thread=thread_obj, user1=me, user2=other_user)
            if notification.sender:
                if 'user_slug' in text:
                    chat_opener_slug = text['user_slug']
                    if chat_opener_slug != notification.sender.slug:
                        notification.read = True
                        notification.save()
                        async_to_sync(self.channel_layer.group_send)(
                            'personal_thread_{}'.format(obj.id),
                            {'type': 'websocket.message',
                             'text': json.dumps({'message_read': "yes"})})

    def websocket_message(self, event):
        self.send({
            'type': 'websocket.send',
            'text': event.get('text')
        })

    def websocket_disconnect(self, event):
        my_slug = self.scope['path'][-8:]
        me = User.objects.get(slug=my_slug)
        other_user_slug = self.scope['url_route']['kwargs']['other_slug']
        other_user = User.objects.get(slug=other_user_slug)
        thread_obj = Thread.objects.get_or_create_personal_thread(me, other_user)
        # Deleting thread if it has no messages
        if len(thread_obj.message_set.all()) == 0:
            thread_obj.delete()

    def store_message(self, text, sent_time=None):  # sent_time accepted to match the call in websocket_receive
        my_slug = self.scope['path'][-8:]
        me = User.objects.get(slug=my_slug)
        Message.objects.create(
            thread=self.thread_obj,
            sender=me,
            text=text,
        )
# ================== Chat consumer ends ==================
class NotifyConsumer(AsyncJsonWebsocketConsumer):
    async def connect(self):
        await self.accept()
        await self.channel_layer.group_add("gossip", self.channel_name)
        if self.scope['path'][-8:]:
            current_user_slug = self.scope['path'][-8:]
            current_user = User.objects.get(slug=current_user_slug)
            current_users_chats = sync_to_async(Thread.objects.by_user(current_user).order_by('updated_at'))

    @staticmethod
    @receiver(signals.post_save, sender=Message)
    def announce_new_user(sender, instance, **kwargs):
        sender = instance.sender
        channel_layer = get_channel_layer()
        if instance.thread.users.first() == sender:
            me = instance.thread.users.last()
        else:
            me = instance.thread.users.first()
        async_to_sync(channel_layer.group_send)(
            "gossip", {
                "type": "user.gossip",
                "event": "New user",
                "sender": me.slug,
                "receiver": sender.slug
            }
        )

    async def user_gossip(self, event):
        await self.send_json(event)

    async def disconnect(self, close_code):
        await self.channel_layer.group_discard("gossip", self.channel_name)
Managers.py file
from datetime import datetime

from django.db import models
from django.db.models import Count
from django.shortcuts import Http404


class ThreadManager(models.Manager):
    def get_or_create_personal_thread(self, user1, user2):
        threads = self.get_queryset().filter(thread_type="personal")
        threads = threads.filter(users__in=[user1, user2])
        threads = threads.annotate(u_count=Count('users')).filter(u_count=2)
        if threads.exists():
            return threads.first()
        else:
            if user1 == user2:
                raise Http404
            else:
                thread = self.create(thread_type="personal")
                thread.users.add(user1)
                thread.users.add(user2)
                return thread

    def by_user(self, user):
        return self.get_queryset().filter(users__in=[user])


class MessageManager(models.Manager):
    def message_by_user(self, user):
        return self.get_queryset().filter(sender__in=[user])
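(Purely a guess on my part, not a confirmed fix: the int()/UserLazyObject traceback suggests a lazy user wrapper reaching the users__in lookup, which prepares integer primary keys. One way to make by_user tolerant of that would be a sketch like this:)

def by_user(self, user):
    # Hypothetical variant: resolve whatever was passed (a User instance,
    # a lazy wrapper such as scope["user"], or a raw id) to a primary key
    # before it reaches the __in lookup.
    user_id = getattr(user, "pk", user)
    return self.get_queryset().filter(users__in=[user_id])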
I'm trying to run the tests on a Flask-SQLAlchemy example in https://github.com/pallets/flask-sqlalchemy/tree/master/examples/flaskr using pytest.
The conftest.py file looks like:
from datetime import datetime

import pytest
from werkzeug.security import generate_password_hash

from flaskr import create_app
from flaskr import db
from flaskr import init_db
from flaskr.auth.models import User
from flaskr.blog.models import Post

_user1_pass = generate_password_hash("test")
_user2_pass = generate_password_hash("other")


@pytest.fixture
def app():
    """Create and configure a new app instance for each test."""
    # create the app with common test config
    app = create_app({"TESTING": True, "SQLALCHEMY_DATABASE_URI": "sqlite:///:memory:"})

    # create the database and load test data
    # set _password to pre-generated hashes, since hashing for each test is slow
    with app.app_context():
        init_db()
        user = User(username="test", _password=_user1_pass)
        db.session.add_all(
            (
                user,
                User(username="other", _password=_user2_pass),
                Post(
                    title="test title",
                    body="test\nbody",
                    author=user,
                    created=datetime(2018, 1, 1),
                ),
            )
        )
        db.session.commit()

    yield app


@pytest.fixture
def client(app):
    """A test client for the app."""
    return app.test_client()


@pytest.fixture
def runner(app):
    """A test runner for the app's Click commands."""
    return app.test_cli_runner()


class AuthActions(object):
    def __init__(self, client):
        self._client = client

    def login(self, username="test", password="test"):
        return self._client.post(
            "/auth/login", data={"username": username, "password": password}
        )

    def logout(self):
        return self._client.get("/auth/logout")


@pytest.fixture
def auth(client):
    return AuthActions(client)
One of the tests looks like:
import pytest
from flask import g
from flask import session

from flaskr.auth.models import User


def test_register(client, app):
    # test that viewing the page renders without template errors
    assert client.get("/auth/register").status_code == 200

    # test that successful registration redirects to the login page
    response = client.post("/auth/register", data={"username": "a", "password": "a"})
    assert "http://localhost/auth/login" == response.headers["Location"]

    # test that the user was inserted into the database
    with app.app_context():
        assert User.query.filter_by(username="a").first() is not None


def test_user_password(app):
    user = User(username="a", password="a")
    assert user.password != "a"
    assert user.check_password("a")


@pytest.mark.parametrize(
    ("username", "password", "message"),
    (
        ("", "", b"Username is required."),
        ("a", "", b"Password is required."),
        ("test", "test", b"already registered"),
    ),
)
def test_register_validate_input(client, username, password, message):
    response = client.post(
        "/auth/register", data={"username": username, "password": password}
    )
    assert message in response.data


def test_login(client, auth):
    # test that viewing the page renders without template errors
    assert client.get("/auth/login").status_code == 200

    # test that successful login redirects to the index page
    response = auth.login()
    assert response.headers["Location"] == "http://localhost/"

    # login request set the user_id in the session
    # check that the user is loaded from the session
    with client:
        client.get("/")
        assert session["user_id"] == 1
        assert g.user.username == "test"


@pytest.mark.parametrize(
    ("username", "password", "message"),
    (("a", "test", b"Incorrect username."), ("test", "a", b"Incorrect password.")),
)
def test_login_validate_input(auth, username, password, message):
    response = auth.login(username, password)
    assert message in response.data


def test_logout(client, auth):
    auth.login()

    with client:
        auth.logout()
        assert "user_id" not in session
However, when I run pytest, I get the following error for all the tests.
AttributeError: 'Function' object has no attribute 'get_marker'
Here is a more detailed error:
________________ ERROR at setup of test_login_required[/create] ________________
item = <Function test_login_required[/create]>
def pytest_runtest_setup(item):
> remote_data = item.get_marker('remote_data')
E AttributeError: 'Function' object has no attribute 'get_marker'
../../../../../../anaconda3/lib/python3.6/site-packages/pytest_remotedata/plugin.py:59: AttributeError
It seems that I get this error for any pytest/Flask setup. What's going on here?
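(For context, as my own note rather than part of the question: Node.get_marker was removed from pytest's API in pytest 4.x, and the traceback points at an old pytest_remotedata release that still calls it. The modern lookup is get_closest_marker; a minimal sketch of what such a hook looks like against the newer API:)

# Hypothetical update of the failing plugin hook to the pytest 4+ marker API.
def pytest_runtest_setup(item):
    remote_data = item.get_closest_marker("remote_data")
    if remote_data is not None:
        ...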