I'll try to be as complete as possible in this issue.
I'm using Sanic, an ASGI Python framework, and I built a database manager on top of it.
This database manager uses a ContextVar to make the current db instance accessible everywhere in the code.
Here's the code related to the database:
database.py
# -*- coding:utf-8 -*-
from sqlalchemy import exc, event
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession as SQLAlchemyAsyncSession
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import Pool, QueuePool, NullPool
from sqlalchemy.exc import OperationalError
from contextvars import ContextVar
from sentry_sdk import push_scope, capture_exception
from sanic import Sanic
class EngineNotInitialisedError(Exception):
pass
class DBSessionContext:
def __init__(self, read_session: Session, write_session: Session, commit_on_exit: bool = True) -> None:
self.read_session = read_session
self.write_session = write_session
self.commit_on_exit = commit_on_exit
self.token = None
self._read = None
self._write = None
def _disable_flush(self, *args, **kwargs):
raise NotImplementedError('Unable to flush a read-only session.')
async def close(self, exc_type=None, exc_value=None, traceback=None):
if self._write:
try:
if exc_value and getattr(exc_value, 'status_code', 500) > 300:
await self._write.rollback()
else:
await self._write.commit()
except Exception as e:
pass
try:
await self._write.close()
except OperationalError as e:
if e.orig.args[0] != 2013: # Lost connection to MySQL server during query
raise e
if self._read:
try:
await self._read.close()
except OperationalError as e:
if e.orig.args[0] != 2013: # Lost connection to MySQL server during query
raise e
def set_token(self, token):
self.token = token
    @property
def read(self) -> Session:
if not self._read:
self._read = self.read_session()
self._read.flush = self._disable_flush
return self._read
    @property
def write(self) -> Session:
if not self._write:
self._write = self.write_session()
return self._write
class AsyncSession(SQLAlchemyAsyncSession):
async def execute(self, statement, **parameters):
return await super().execute(statement, parameters)
async def first(self, statement, **parameters):
executed = await self.execute(statement, **parameters)
return executed.first()
async def all(self, statement, **parameters):
executed = await self.execute(statement, **parameters)
return executed.all()
class DBSession:
def __init__(self):
self.app = None
self.read_engine = None
self.read_session = None
self.write_engine = None
self.write_session = None
self._session = None
self.context = ContextVar("context", default=None)
self.commit_on_exit = True
    def init_app(self, app: Sanic) -> None:
self.app = app
self.commit_on_exit = self.app.config.get('DATABASE_COMMIT_ON_EXIT', cast=bool, default=True)
self.read_engine = create_async_engine(
self.app.config.get('DATABASE_READ_URL'),
connect_args={
'connect_timeout': self.app.config.get('DATABASE_CONNECT_TIMEOUT', cast=int, default=3)
},
**{
'echo': self.app.config.get('DATABASE_ECHO', cast=bool, default=False),
'echo_pool': self.app.config.get('DATABASE_ECHO_POOL', cast=bool, default=False),
'poolclass': QueuePool, # will be used to create a connection pool instance using the connection parameters given in the URL
# if pool_class is not NullPool:
# if True will enable the connection pool “pre-ping” feature that tests connections for liveness upon each checkout
'pool_pre_ping': self.app.config.get('DATABASE_POOL_PRE_PING', cast=bool, default=True),
# the number of connections to allow in connection pool “overflow”
'max_overflow': self.app.config.get('DATABASE_MAX_OVERFLOW', cast=int, default=10),
# the number of connections to keep open inside the connection pool
'pool_size': self.app.config.get('DATABASE_POOL_SIZE', cast=int, default=100),
# this setting causes the pool to recycle connections after the given number of seconds has passed
'pool_recycle': self.app.config.get('DATABASE_POOL_RECYCLE', cast=int, default=3600),
# number of seconds to wait before giving up on getting a connection from the pool
'pool_timeout': self.app.config.get('DATABASE_POOL_TIMEOUT', cast=int, default=5),
}
)
        # see https://writeonly.wordpress.com/2009/07/16/simple-read-only-sqlalchemy-sessions/
self.read_session = sessionmaker(
bind=self.read_engine,
expire_on_commit=False,
class_=AsyncSession,
autoflush=False,
autocommit=False
)
self.write_engine = create_async_engine(
self.app.config.get('DATABASE_WRITE_URL'),
connect_args={
'connect_timeout': self.app.config.get('DATABASE_CONNECT_TIMEOUT', cast=int, default=3)
},
**{
'echo': self.app.config.get('DATABASE_ECHO', cast=bool, default=False),
'echo_pool': self.app.config.get('DATABASE_ECHO_POOL', cast=bool, default=False),
'poolclass': NullPool, # will be used to create a connection pool instance using the connection parameters given in the URL
}
)
self.write_session = sessionmaker(
bind=self.write_engine,
expire_on_commit=False,
class_=AsyncSession,
autoflush=True
)
async def __aenter__(self):
session_ctx = DBSessionContext(self.read_session, self.write_session, self.commit_on_exit)
session_ctx.set_token(self.context.set(session_ctx))
return session_ctx
async def __aexit__(self, exc_type, exc_value, traceback):
session_ctx = self.context.get()
try:
await session_ctx.close(exc_type, exc_value, traceback)
except Exception:
pass
self.context.reset(session_ctx.token)
    @property
def read(self) -> Session:
return self.context.get().read
    @property
def write(self) -> Session:
return self.context.get().write
@event.listens_for(Pool, "checkout")
def check_connection(dbapi_con, con_record, con_proxy):
'''Listener for Pool checkout events that pings every connection before using.
Implements pessimistic disconnect handling strategy. See also:
http://docs.sqlalchemy.org/en/rel_0_8/core/pooling.html#disconnect-handling-pessimistic'''
cursor = dbapi_con.cursor()
try:
cursor.execute("SELECT 1")
except exc.OperationalError as ex:
if ex.args[0] in (2006, # MySQL server has gone away
2013, # Lost connection to MySQL server during query
2055): # Lost connection to MySQL server at '%s', system error: %d
raise exc.DisconnectionError() # caught by pool, which will retry with a new connection
else:
raise
cursor.close()
db = DBSession()
This configuration allows me to run something like:
from sanic.response import json

from models import User
from database import db

@app.get('/user')
async def get_user(request):
    async with db:
        users = await User.find_all()  # Special function in the Model that returns all users
        return json({'items': [{'id': x.id} for x in users]})
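find_all itself isn't shown in the question; here is a minimal sketch of what such a helper could look like, reusing the AsyncSession.all helper from above (the Base, table name, and columns are hypothetical, only the db.read access pattern is from the post):

from sqlalchemy import Column, Integer, select
from sqlalchemy.orm import declarative_base

from database import db

Base = declarative_base()

class User(Base):
    __tablename__ = 'users'  # hypothetical table name
    id = Column(Integer, primary_key=True)

    @classmethod
    async def find_all(cls):
        # db.read resolves the ContextVar to this request's read-only session;
        # AsyncSession.all() executes the statement and returns every row
        return await db.read.all(select(cls))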
The __aenter__ and, mostly, the __aexit__ of the DBSession class (and the underlying DBSessionContext) handle everything when the code leaves the async with, including any exceptions that occurred.
The issue I'm having, is that from time to time, I have the following error reported at Sentry:
The garbage collector is trying to clean up connection <AdaptedConnection <asyncmy.connection.Connection object at 0x7f290c50dd30>>. This feature is unsupported on unsupported on asyncio dbapis that lack a "terminate" feature, since no IO can be performed at this stage to reset the connection. Please close out all connections when they are no longer used, calling close() or using a context manager to manage their lifetime.
I don't understand why this is happening. Even more odd is that I often get this error on a function that doesn't use the database at all (the async with db is still present, but its body doesn't touch the database).
The content of that function is a network call:
import requests
from sanic.response import text

@app.get('/notify')
async def get_user(request):
    async with db:
        requests.post('https://service.com/notify', data={'some': 'data'})
        return text('ok')
Here are my assumptions, but I'm hoping to get a clearer view of the issue:
Assumption 1: Since the read engine uses a QueuePool, maybe the __aexit__ call to close doesn't really close the connection; the connection remains open, causing the "The garbage collector is trying to clean up connection" issue later on.
Assumption 2: The connection is made in check_connection and remains open, causing the "garbage collector" issue.
Any idea why I'm having that "garbage collector" issue?
I'm using:
sanic==22.9.0
sqlalchemy[asyncio]==1.4.41
asyncmy==0.2.5
This line might be causing your problem: await session_ctx.close(exc_type, exc_value, traceback).
Try changing it to: await asyncio.shield(session_ctx.close(exc_type, exc_value, traceback)).
This safeguard was added to the SQLAlchemy code base in July; the change was implemented in /asyncio/engine.py and /asyncio/session.py.
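Applied to the DBSession class from the question, the adjusted __aexit__ would look roughly like this (a sketch of the same idea, not the exact upstream patch):

import asyncio

class DBSession:
    # ... everything else unchanged ...

    async def __aexit__(self, exc_type, exc_value, traceback):
        session_ctx = self.context.get()
        try:
            # Shield the cleanup so that task cancellation (e.g. a client
            # disconnecting mid-request) cannot abort the commit/rollback/close
            # sequence halfway, leaving the connection for the GC to reap.
            await asyncio.shield(session_ctx.close(exc_type, exc_value, traceback))
        except Exception:
            pass
        self.context.reset(session_ctx.token)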
Additional references:
SQLAlchemy issue 8145
The change was added to version 1.4.40, with a release date of August 8, 2022.
A naive and quick way to check this is to wrap the call in a try/except block that handles the specific error and prints the output.
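For example, inside __aexit__ (a diagnostic sketch, not a fix):

try:
    await asyncio.shield(session_ctx.close(exc_type, exc_value, traceback))
except Exception as e:
    # Naive check: surface the error instead of swallowing it silently
    print(f'close() raised during __aexit__: {e!r}')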
You don't manage the lifetime of the requests.post call; couldn't that be keeping close() from being called?
Although I do think __aexit__ should close the session, I don't really understand why you do this at all: async with db:. What is the purpose of the session?
Nice implementation overall.
I need to simulate a DB connection without an actual connection. All the answers I found either mock methods in different ways, connect to a DB in Docker, or connect to an actual PostgreSQL instance running locally. I believe I need the mocking variant, but I cannot formulate in my head how I should mock. Am I missing something? Am I moving in the wrong direction?
I use PostgreSQL and psycopg2 (the psycopg2-binary package).
Database connection:
import os
import psycopg2
from loguru import logger
from psycopg2.extensions import parse_dsn
def init_currency_history_table(cursor):
create_users_table_query = """
CREATE TABLE IF NOT EXISTS history(
id BIGINT PRIMARY KEY NOT NULL,
event TEXT,
creation_date TIMESTAMPTZ DEFAULT NOW()
);
"""
cursor.execute(create_users_table_query)
def load_db(db_url):
db = psycopg2.connect(**db_url)
db.autocommit = True
return db
class PostgresqlApi(object):
def __init__(self, load=load_db):
logger.info(os.environ.get('DATABASE_URL'))
db_url = parse_dsn(os.environ.get('DATABASE_URL'))
db_url['sslmode'] = 'require'
logger.info('HOST: {0}'.format(db_url.get('host')))
self.db = load_db(db_url)
self.cursor = self.db.cursor()
init_currency_history_table(self.cursor)
self.db.commit()
def add_event(self, *, event):
insert_event_table = """
INSERT INTO history (event) VALUES (%s);
"""
        self.cursor.execute(insert_event_table, (event,))
def events(self):
        select_event_table = """SELECT * FROM history;"""
self.cursor.execute(select_event_table)
return self.cursor.fetchall()
def close(self):
self.cursor.close()
self.db.close()
I use the DB for a FastAPI API.
import os
import secrets

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from decimal import Decimal, getcontext
from db import PostgresqlApi
app = FastAPI()
security = HTTPBasic()
database = None
def db_connection():
global database
if not database:
database = PostgresqlApi()
return database
def check_basic_auth_creds(credentials: HTTPBasicCredentials = Depends(security)):
correct_username = secrets.compare_digest(credentials.username, os.environ.get('APP_USERNAME'))
correct_password = secrets.compare_digest(credentials.password, os.environ.get('APP_PASSWORD'))
if not (correct_username and correct_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username and password",
headers={'WWW-Authenticate': 'Basic'}
)
return credentials
#app.get("/currencies")
def read_currencies(credentials: HTTPBasicCredentials = Depends(check_basic_auth_creds)):
db = db_connection()
return {'get events': 'ok'}
I have tried different methods and plugins, among others pytest-pgsql and pytest-postgresql.
The solution I landed on is below:
Created a fake class that has exactly the structure of PostgresqlApi (see implementation below).
Created a fixture for the db_connection method (see implementation below).
Fake class implementation
class FakePostgresqlApi(PostgresqlApi):
event_list = []
def __init__(self):
pass
def add_event(self, *, event):
self.event_list.append([1, 'magic trick', 1653630607])
def events(self):
return self.event_list
def close(self):
self.event_list.clear()
Fixture
import pytest
from unittest.mock import MagicMock

@pytest.fixture
def mock_db_connection(mocker):
    # 'mocker' is provided by the pytest-mock plugin
    mocker.patch('src.main.db_connection', MagicMock(return_value=FakePostgresqlApi()))
The test itself was:
def test_read_events(mock_db_connection):
    # Do whatever I need here; in my case, call the FastAPI test client
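For example (the src.main module path matches the patch target in the fixture; the credentials are placeholders for whatever APP_USERNAME/APP_PASSWORD hold in the test environment):

from fastapi.testclient import TestClient

from src.main import app  # assumed module layout, matching 'src.main' in the fixture

def test_read_events(mock_db_connection):
    client = TestClient(app)
    response = client.get('/currencies', auth=('user', 'pass'))
    assert response.status_code == 200
    assert response.json() == {'get events': 'ok'}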
I am building some REST APIs using Python FastAPI and getting some data from an MSSQL server using SQLAlchemy.
I am trying to insert data through an existing stored procedure in MSSQL server.
The stored procedure AcctsCostCentersAddV001 takes ArDescription, EngDescription, and ErrMsg as parameters and returns ErrMsg as an OUTPUT parameter.
MSSQL Server Code:
CREATE Proc [dbo].[AcctsCostCentersAddV001] @ArDescription nvarchar(100), @EngDescription varchar(100), @ErrMsg nvarchar(100) Output
As
Insert Into AcctsCostCenters (ArDescription, EngDescription) Values(@ArDescription, @EngDescription)
Set @ErrMsg = 'Test'
GO
My Python code:
from fastapi import APIRouter, Request, Depends, HTTPException
from fastapi.responses import JSONResponse
# # # # # # # # # # SQL # # # # # # # # # #
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError
# # # # # # # # # # Files # # # # # # # # # #
from dependencies import get_db
from sqlalchemy.engine import create_engine
from internal.config import username, password, SQL_SERVER, database_name
@router.get("/create/CostCenter/")
async def forms(request: Request, db=Depends(get_db)):
try:
        connection_string = f'mssql://{username}:{password}@{SQL_SERVER}/{database_name}?driver=ODBC+Driver+17+for+SQL+Server'
engine = create_engine(connection_string, pool_pre_ping=True)
connection = engine.raw_connection()
try:
cursor_obj = connection.cursor()
query = "Declare #ErrMsg nvarchar(100) Exec AcctsCostCentersAddV001 'Test', 'Test', #ErrMsg Output Print #ErrMsg"
cursor_obj.execute(text(query))
results = list(cursor_obj.fetchall())
cursor_obj.close()
connection.commit()
print(results)
finally:
connection.close()
except IndexError:
raise HTTPException(status_code=404, detail="Not found")
except ProgrammingError as e:
print(e)
raise HTTPException(status_code=400, detail="Invalid Entry")
except Exception as e:
print(e)
return UnknownError(
"unknown error caused by CostCenter API request handler", error=e)
return JSONResponse(results)
For some reason, this code doesn't raise any exceptions, yet I keep getting:
('The SQL contains 0 parameter markers, but 3 parameters were supplied', 'HY000')
I have even tried this:
cursor_obj.execute("Declare @ErrMsg nvarchar(100) Exec AcctsCostCentersAddV001 'Test', 'Test', @ErrMsg")
but I get back:
No results. Previous SQL was not a query.
I tried wrapping the query in text(), but I got: The first argument to execute must be a string or unicode query.
But when I go into MSSQL Server and run:
Declare @ErrMsg nvarchar(100)
Exec AcctsCostCentersAddV001 'Test', 'Test', @ErrMsg Output
Print @ErrMsg
It runs without any problem.
My Env:
Ubuntu 21.04 VPS by OVH
I hope I am providing everything you need; let me know if I missed anything. Thanks!
BTW, I know I am connecting to the DB twice :3
(Edited):
I am actually connecting to the DB from a database.py file, and I am connecting again in the function just for testing.
My database.py:
import time
from colorama import Fore
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.engine import create_engine
from internal.config import username, password, SQL_SERVER, database_name
try:
    connection_string = f'mssql://{username}:{password}@{SQL_SERVER}/{database_name}?driver=ODBC+Driver+17+for+SQL+Server'
engine = create_engine(connection_string, pool_pre_ping=True)
connection = engine.raw_connection()
print(f"{Fore.CYAN}DB{Fore.RESET}: Database Connected. ")
SessionLocal = sessionmaker(autocommit=True, autoflush=False, bind=engine)
Base = declarative_base()
except Exception as e:
print(f"{Fore.RED}DB{Fore.RESET}: Couldn't connect to Database.", e)
I am trying to write tests for a web service, and I want a separate database to be created for the tests when you run them. Here are my pytest fixtures for this:
#pytest.fixture(scope="session")
def db_engine():
engine = create_engine(SQLALCHEMY_DATABASE_URL)
    if not database_exists(engine.url):
create_database(engine.url)
Base.metadata.create_all(bind=engine)
yield engine
#pytest.fixture(scope="function")
def db(db_engine):
connection = db_engine.connect()
connection.begin()
db = Session(bind=connection)
yield db
db.rollback()
connection.close()
#pytest.fixture(scope="function")
def client(db):
app.dependency_overrides[get_db] = lambda: db
with TestClient(app) as c:
yield c
But app.dependency_overrides[get_db] = lambda: db didn't work, and requests continue to be sent to the main database rather than the test one.
One of my endpoints
@router.get("/", response_model=List[RoomPayload])
def read(db: Session = Depends(get_db),
user=Depends(manager)):
q = db.query(Room).all()
if not q:
        raise HTTPException(status_code=404, detail="Rooms not found")
return q
If I want to use the database while processing a request, I use dependency injection like this:
@app.post("/sample_test")
async def sample_test(db: Session = Depends(get_db)):
return db.query(models.User.height).all()
But I cannot do it with events like this:
@app.on_event("startup")
async def sample_test(db: Session = Depends(get_db)):
return db.query(models.User.height).all()
because Starlette events don't support Depends.
This is my get_db() function:
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
just like in the FastAPI manual (https://fastapi.tiangolo.com/tutorial/sql-databases/).
How can I access get_db() inside my event function, so I can work with a Session?
I've tried:
@app.on_event("startup")
async def sample_test(db: Session = Depends(get_db)):
db = next(get_db())
return db.query(models.User.height).all()
but it doesn't work.
I use MSSQL, if it's important.
Instead of using a dependency, you can import the SessionLocal you've created (as shown in the FastAPI manual) and use a context manager to open and close the session:
@app.on_event("startup")
async def sample_test():
with SessionLocal() as db:
return db.query(models.User.height).all()
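Note that Session supports the context-manager protocol since SQLAlchemy 1.4, so the with block closes the session automatically when the handler returns.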