pytest: rollback between tests using SQLAlchemy and FastAPI

I have a FastAPI application with several tests written with pytest.
Two particular tests are causing me issues. test_a calls a POST endpoint that creates a new entry in the database. test_b gets these entries, and its results include the entry created by test_a. This is not the desired behaviour.
When I run the tests individually (using VS Code's Testing tab) they pass. However, when I run all the tests together and test_a runs before test_b, test_b fails.
My conftest.py looks like this:
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from application.core.config import get_database_uri
from application.core.db import get_db
from application.main import app
@pytest.fixture(scope="module", name="engine")
def fixture_engine():
    engine = create_engine(
        get_database_uri(uri="postgresql://user:secret@localhost:5432/mydb")
    )
    SQLModel.metadata.create_all(bind=engine)
    yield engine
    SQLModel.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function", name="db")
def fixture_db(engine):
    connection = engine.connect()
    transaction = connection.begin()
    session = Session(bind=connection)
    yield session
    session.close()
    transaction.rollback()
    connection.close()
@pytest.fixture(scope="function", name="client")
def fixture_client(db):
    app.dependency_overrides[get_db] = lambda: db
    with TestClient(app) as client:
        yield client
The file containing test_a and test_b also has a module-scoped pytest fixture that seeds the data using the engine fixture:
@pytest.fixture(scope="module", autouse=True)
def seed(engine):
    connection = engine.connect()
    seed_data_session = Session(bind=connection)
    seed_data(seed_data_session)
    yield
    seed_data_session.rollback()
All tests use the client fixture, like so:
def test_a(client):
    ...
SQLAlchemy version is 1.4.41, FastAPI version is 0.78.0, and pytest version is 7.1.3.
My Observations
It seems the tests run fine on their own because SQLModel.metadata.drop_all(bind=engine) is called at the end of testing. However, I would like to avoid having to do this and instead rely only on a rollback between tests.
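A minimal sketch of the rollback-only pattern, using the standard-library sqlite3 module as a stand-in for SQLAlchemy and Postgres (an assumption made so the example runs on its own): each test opens an outer transaction that is never committed, the code under test works inside a SAVEPOINT, and its "commit" only releases and reopens that savepoint. This mirrors the shape of SQLAlchemy's documented recipe for joining a Session into an external transaction; the helper names (isolated_test, app_create_entry, count_entries) are hypothetical.

```python
import sqlite3
from contextlib import contextmanager


def make_connection() -> sqlite3.Connection:
    # isolation_level=None turns off sqlite3's implicit BEGIN, so transactions
    # are controlled explicitly with BEGIN / SAVEPOINT / ROLLBACK below.
    conn = sqlite3.connect(":memory:", isolation_level=None)
    conn.execute("CREATE TABLE entry (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
    conn.execute("INSERT INTO entry (name) VALUES ('seed')")  # seed row, autocommitted once
    return conn


@contextmanager
def isolated_test(conn: sqlite3.Connection):
    """Per-test isolation: outer transaction + savepoint, always rolled back."""
    conn.execute("BEGIN")          # outer transaction, never committed
    conn.execute("SAVEPOINT app")  # the "transaction" the code under test sees
    try:
        yield conn
    finally:
        # Discards everything since BEGIN, including released savepoints.
        conn.execute("ROLLBACK")


def app_create_entry(conn: sqlite3.Connection, name: str) -> None:
    # Hypothetical stand-in for the POST endpoint.
    conn.execute("INSERT INTO entry (name) VALUES (?)", (name,))
    # The endpoint's "commit": release the savepoint and open a fresh one,
    # so the outer transaction stays open and can still be rolled back.
    conn.execute("RELEASE SAVEPOINT app")
    conn.execute("SAVEPOINT app")


def count_entries(conn: sqlite3.Connection) -> int:
    return conn.execute("SELECT COUNT(*) FROM entry").fetchone()[0]
```

In pytest terms, isolated_test would become a function-scoped fixture, and everything the test touches (seed data and the overridden get_db) would need to share that one connection: a separate connection, such as the one the seed fixture opens, neither sees this transaction's uncommitted rows nor is undone by its rollback.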

What worked really well for me is using testcontainers: https://github.com/testcontainers/testcontainers-python.
import pytest
from sqlmodel import Session, create_engine
from testcontainers.postgres import PostgresContainer
@pytest.fixture(scope="module", name="session_for_db_in_testcontainer")
def db_engine():
    """
    Creates testcontainer with Postgres db
    """
    pg_container = PostgresContainer('postgres:latest')
    pg_container.start()
    # Fire up the SQLModel engine with the URI of the container
    db_engine = create_engine(pg_container.get_connection_url())
    sqlmodel_metadata.create_all(db_engine)  # e.g. SQLModel.metadata
    with Session(db_engine) as session_for_db_in_testcontainer:
        # add some rows to start, for test get requests and posting existing data
        add_data_to_test_db(database_input_path, session_for_db_in_testcontainer)
        yield session_for_db_in_testcontainer
        # Will be executed after the last test
        session_for_db_in_testcontainer.close()
    pg_container.stop()
This way a (Postgres) database is created during the test run, and it lives only for the session, module or function, depending on the scope of the fixture. If you want, you can also add test data to the database, as in the example.
In your case you might want to set the scope of this fixture to function. Then test_a and test_b should run independently.

Related

How To Setup a SQLAlchemy Asynchronous Scoped Session In Python Behave Synchronous Hooks?

I have also asked this question on behave GitHub discussions and SQLAlchemy GitHub discussions.
I am trying to hook up a SQLAlchemy 1.4 engine and a global scoped asynchronous session in the behave before_all and before_scenario hooks, to model testing similar to that outlined in the following blog article.
The approach is to have a parent transaction and each test running in a nested transaction that gets rolled back when the test completes.
Unfortunately the before_all, before_scenario hooks are synchronous.
The application under test uses an asynchronous engine and asynchronous session created using sessionmaker:
def _create_session_factory(engine) -> sessionmaker[AsyncSession]:
    factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    factory.configure(bind=engine)
    return factory
In the before_scenario test hook, the following line raises an error when I try to create a scoped session:
# THIS RAISES AN ERROR: RuntimeError: no running event loop
context.session = context.Session(bind=context.connection, loop=loop)
The full code listing for setting up the test environment is listed below.
How do I get an asynchronous scoped session created in the synchronous before_all, before_scenario test hooks of behave?
import asyncio
import logging
from behave.api.async_step import use_or_create_async_context
from behave.log_capture import capture
from behave.runner import Context
from behave.model import Scenario
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
from sqlalchemy.ext.asyncio.engine import async_engine_from_config
from sqlalchemy.orm.session import sessionmaker
from fastapi_graphql.server.config import Settings
logger = logging.getLogger()
@capture(level=logging.INFO)
def before_all(context: Context) -> None:
    """Setup database engine and session factory."""
    logging.info("Setting up logging for behave tests...")
    context.config.setup_logging()
    logging.info("Setting up async context...")
    use_or_create_async_context(context)
    loop = context.async_context.loop
    asyncio.set_event_loop(loop)
    logging.info("Configuring db engine...")
    settings = Settings()
    config = settings.dict()
    config["sqlalchemy_url"] = settings.db_url
    engine = async_engine_from_config(config, prefix="sqlalchemy_")
    logging.info(f"Db engine configured for connecting to: {settings.db_url}")
    logging.info("Creating a global session instance")
    factory = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)
    # factory.configure(bind=engine)
    Session = async_scoped_session(factory(), scopefunc=asyncio.current_task)
    context.engine = engine
    context.connection = loop.run_until_complete(engine.connect())
    context.factory = factory
    context.Session = Session
@capture(level=logging.INFO)
def after_all(context: Context) -> None:
    """Teardown database engine gracefully."""
    loop = context.async_context.loop
    logging.info("Closing connection")
    loop.run_until_complete(context.connection.close())
    logging.info("Closing database engine...")
    loop.run_until_complete(context.engine.dispose())
    logging.info("Database engine closed")
@capture(level=logging.INFO)
def before_scenario(context: Context, scenario: Scenario) -> None:
    """Create a database session."""
    loop = context.async_context.loop
    logging.info("Starting a transaction...")
    context.transaction = loop.run_until_complete(context.connection.begin())
    logging.info("Transaction started...")
    logging.info("Creating a db session...")
    breakpoint()
    # THIS RAISES AN ERROR RuntimeError: no running event loop
    context.session = context.Session(bind=context.connection, loop=loop)
    logging.info("Db session created")
    breakpoint()
    logging.info("Starting a nested transaction...")
    context.session.begin_nested()
    logging.info("Nested transaction started...")
    @event.listens_for(context.session, "after_transaction_end")
    def restart_savepoint(db_session, transaction):
        """Support tests with rollbacks.
        This is required for tests that call some services that issue
        rollbacks in try-except blocks.
        With this event the Session always runs all operations within
        the scope of a SAVEPOINT, which is established at the start of
        each transaction, so that tests can also rollback the
        “transaction” as well while still remaining in the scope of a
        larger “transaction” that’s never committed.
        """
        if context.transaction.nested and not context.transaction._parent.nested:
            # ensure that state is expired the way session.commit() at
            # the top level normally does
            context.session.expire_all()
            context.session.begin_nested()
@capture(level=logging.INFO)
def after_scenario(context: Context, scenario: Scenario) -> None:
    """Close the database session."""
    logging.info("Closing db session...")
    loop = asyncio.get_event_loop()
    loop.run_until_complete(context.Session.remove())
    logging.info("Db session closed")
    logging.info("Rolling back transaction...")
    loop.run_until_complete(context.transaction.rollback())
    logging.info("Rolled back transaction")
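A possible explanation for the RuntimeError, sketched with a hypothetical stand-in rather than SQLAlchemy itself: a scoped session looks up the current session by calling its scopefunc, here asyncio.current_task, and asyncio.current_task() raises RuntimeError: no running event loop when called from plain synchronous code such as a behave hook. Even inside loop.run_until_complete(), every call wraps the coroutine in a new Task, so a task-keyed registry hands each hook call a different session. Separately, async_scoped_session expects a session factory (the sessionmaker itself), whereas the code above passes factory(), which is already a session instance. TaskScopedRegistry, call_from_sync_hook and get_session are illustrative names only.

```python
import asyncio


class TaskScopedRegistry:
    """Illustrative stand-in for a scoped-session registry keyed by scopefunc()."""

    def __init__(self, factory, scopefunc=asyncio.current_task):
        self.factory = factory
        self.scopefunc = scopefunc
        self.registry = {}

    def __call__(self):
        key = self.scopefunc()  # asyncio.current_task() needs a running loop
        if key not in self.registry:
            self.registry[key] = self.factory()
        return self.registry[key]


def call_from_sync_hook(registry):
    # What a synchronous behave hook effectively does.
    try:
        registry()
    except RuntimeError as exc:
        return str(exc)
    return "ok"


async def get_session(registry):
    # Inside a coroutine a Task is running, so the lookup succeeds,
    # but each run_until_complete() call runs in a brand-new Task.
    return registry()
```

Given that, creating a plain (non-scoped) session from the factory in the synchronous hooks may be simpler than a task-scoped one, but that is a design suggestion to verify, not something the sketch proves.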

How to properly configure pytest with FastAPI and Tortoise ORM?

I am trying to configure the tests. According to the tortoise orm documentation I create this test configuration file:
import pytest
from fastapi.testclient import TestClient
from tortoise.contrib.test import finalizer, initializer
import app.main as main
from app.core.config import settings
@pytest.fixture(scope="session", autouse=True)
def initialize_tests(request):
    db_url = "postgres://USERNAME_HERE:SECRET_PASS_HERE@127.0.0.1:5432/test"
    initializer(
        [
            "app.models",
        ],
        db_url=db_url,
        app_label="models"
    )
    print("initialize_tests")
    request.addfinalizer(finalizer)
@pytest.fixture(scope="session")
def client():
    app = main.create_application()
    with TestClient(app) as client:
        print("client")
        yield client
And the test file looks like this:
def test_get(client):
    response = client.get("/v1/url/")
    assert response.status_code == 200
When I run the tests, I get this error:
asyncpg.exceptions._base.InterfaceError: cannot perform operation: another operation is in progress
I have found that some users don't use initializer and finalizer and do everything manually.
Testing in FastAPI using Tortoise-ORM
https://stackoverflow.com/a/66907531
But that doesn't look like a clean solution.
Question: Is there a way to make the tests work using initializer and finalizer?

Pytest Alembic initialize database with async migrations

The existing posts didn't provide a useful answer to me.
I'm trying to run asynchronous database tests using Pytest (db is Postgres with asyncpg), and I'd like to initialize my database using my Alembic migrations so that I can verify that they work properly in the meantime.
My first attempt was this:
import pytest
from alembic import command, config
@pytest.fixture(scope="session")
async def tables():
    """Initialize a database before the tests, and then tear it down again"""
    alembic_config: config.Config = config.Config('alembic.ini')
    command.upgrade(alembic_config, "head")
    yield
    command.downgrade(alembic_config, "base")
which didn't actually do anything at all (migrations were never applied to the database, tables not created).
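One possible reason the first attempt did nothing, assuming pytest-asyncio runs in strict mode: an async def decorated with plain @pytest.fixture is not driven by pytest-asyncio, and merely calling an async generator function creates an async generator object without executing any of its body. The standard-library sketch below shows that behaviour. For the fixture itself, @pytest_asyncio.fixture (which the answer further down uses) or, since command.upgrade is synchronous, a plain def fixture might avoid it. tables, drive and calls are illustrative names.

```python
import asyncio

calls = []


async def tables():
    # Same shape as the async fixture: setup, yield, teardown.
    calls.append("upgrade")
    yield
    calls.append("downgrade")


async def drive(agen_fn, body):
    # Roughly what an async-aware fixture runner does: run up to the yield,
    # run the test body, then resume the generator for teardown.
    agen = agen_fn()
    await agen.__anext__()
    body()
    try:
        await agen.__anext__()
    except StopAsyncIteration:
        pass
```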
Both Alembic's documentation and pytest-alembic's documentation say that async migrations should be run by configuring your env.py like this:
async def run_migrations_online() -> None:
    """Run migrations in 'online' mode.
    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine
    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()
asyncio.run(run_migrations_online())
but this doesn't resolve the issue (however it does work for production migrations outside of pytest).
I stumbled upon a library called pytest-alembic that provides some built-in tests for this.
When running pytest --test-alembic, I get the following exception:
got Future attached to a different loop
A few comments on pytest-asyncio's GitHub repository suggest that the following fixture might fix it:
import asyncio
from typing import Generator
import pytest
@pytest.fixture(scope="session")
def event_loop() -> Generator:
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
but it doesn't (same exception remains).
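The "got Future attached to a different loop" message comes from asyncio itself: a Future belongs to the loop it was created on, and a Task running on another loop refuses to await it. A standard-library reproduction follows. In a test suite the same thing can plausibly happen when an engine or connection pool created under one event loop (for example in a session-scoped fixture) is later used from a test running on a different loop; that mapping to this setup is an assumption. wait_for_future and await_on_other_loop are illustrative names.

```python
import asyncio


async def wait_for_future(fut):
    return await fut


def await_on_other_loop() -> str:
    loop_a = asyncio.new_event_loop()
    loop_b = asyncio.new_event_loop()
    try:
        fut = loop_a.create_future()  # this Future belongs to loop_a
        try:
            # The awaiting Task runs on loop_b, so asyncio raises RuntimeError.
            loop_b.run_until_complete(wait_for_future(fut))
        except RuntimeError as exc:
            return str(exc)
        return "no error"
    finally:
        loop_a.close()
        loop_b.close()
```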
Next I tried to run the upgrade test manually, using:
async def test_migrations(alembic_runner):
    alembic_runner.migrate_up_to("revision_tag_here")
which gives me
alembic_runner.migrate_up_to("revision_tag_here")
venv/lib/python3.9/site-packages/pytest_alembic/runner.py:264: in run_connection_task
return asyncio.run(run(engine))
RuntimeError: asyncio.run() cannot be called from a running event loop
However, this is an internal call made by pytest-alembic; I'm not calling asyncio.run() myself, so I can't apply any of the online fixes for this (try/except to check whether there is an existing event loop to use, etc.). I'm sure this isn't related to the asyncio.run() in my own Alembic env.py, because if I add a breakpoint, or just raise an exception above it, that line is never executed.
Lastly, I've also tried nest_asyncio.apply(), which just hangs forever.
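The RuntimeError from pytest-alembic is what asyncio.run() always raises when it is called while an event loop is already running in the same thread, which is the case inside an async def test. In the standard-library sketch below, sync_migrate is a hypothetical stand-in for a synchronous API that calls asyncio.run() internally (as the traceback shows pytest-alembic's runner doing). Calling it directly from a coroutine fails, while running it in a worker thread with asyncio.to_thread() (Python 3.9+) gives it a thread with no running loop. For the test above, declaring it as a plain def instead of async def might be enough, but that depends on how the other fixtures are set up.

```python
import asyncio


async def _work() -> str:
    await asyncio.sleep(0)
    return "migrated"


def sync_migrate() -> str:
    # Hypothetical synchronous entry point that uses asyncio.run() internally.
    coro = _work()
    try:
        return asyncio.run(coro)
    except RuntimeError:
        coro.close()  # avoid a "coroutine was never awaited" warning
        raise


async def call_directly() -> str:
    # A loop is already running here, so asyncio.run() refuses to start.
    try:
        return sync_migrate()
    except RuntimeError as exc:
        return f"error: {exc}"


async def call_in_thread() -> str:
    # The worker thread has no running event loop, so asyncio.run() is allowed there.
    return await asyncio.to_thread(sync_migrate)
```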
A few more blog posts suggest using this fixture code to initialize database tables for tests:
async with engine.begin() as connection:
    await connection.run_sync(Base.metadata.create_all)
This works for creating a database to run tests against, but it doesn't run through the migrations, so it doesn't help in my case.
I feel like I've tried everything there is & visited every docs page, but I've got no luck so far. Running an async migration test surely can't be this difficult?
If any extra info is required I'm happy to provide it.
I got this up and running pretty easily with the following.
env.py (the main idea here is that the migration can be run synchronously):
import asyncio
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import AsyncEngine
config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)
target_metadata = mymodel.Base.metadata  # mymodel: your module defining the declarative Base
def run_migrations_online():
    # Reuse a connection passed in via config.attributes (see init_db.py); otherwise build an async engine
    connectable = context.config.attributes.get("connection", None)
    if connectable is None:
        connectable = AsyncEngine(
            engine_from_config(
                context.config.get_section(context.config.config_ini_section),
                prefix="sqlalchemy.",
                poolclass=pool.NullPool,
                future=True
            )
        )
    if isinstance(connectable, AsyncEngine):
        asyncio.run(run_async_migrations(connectable))
    else:
        # A plain (sync) connection was handed in, so no event loop is needed here
        do_run_migrations(connectable)
async def run_async_migrations(connectable):
    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()
def do_run_migrations(connection):
    context.configure(
        connection=connection,
        target_metadata=target_metadata,
        compare_type=True,
    )
    with context.begin_transaction():
        context.run_migrations()
run_migrations_online()
Then I added a simple DB init script:
init_db.py
from alembic import command
from alembic.config import Config
from sqlalchemy.ext.asyncio import create_async_engine
__config_path__ = "/path/to/alembic.ini"
__migration_path__ = "/path/to/folder/with/env.py"
cfg = Config(__config_path__)
cfg.set_main_option("script_location", __migration_path__)
async def migrate_db(conn_url: str):
    async_engine = create_async_engine(conn_url, echo=True)
    async with async_engine.begin() as conn:
        await conn.run_sync(__execute_upgrade)
def __execute_upgrade(connection):
    cfg.attributes["connection"] = connection
    command.upgrade(cfg, "head")
Then your pytest fixture can look like this:
conftest.py
...
@pytest_asyncio.fixture(autouse=True)
async def migrate():
    await migrate_db(conn_url)
    yield
...
Note: I don't scope my migrate fixture to the test session; I tend to drop and migrate after each test.

sqlalchemy session for pytest vs session for production

I'm looking for a way to be able to test code using pytest as well as use that code in production, and I'm struggling with session handling.
For pytest, I have a conftest.py that includes:
@pytest.fixture
def session(setup_database, connection):
    transaction = connection.begin()
    yield scoped_session(
        sessionmaker(autocommit=False, autoflush=False, bind=connection)
    )
    transaction.rollback()
That allows me to write low-level tests using a test database along the lines of:
def test_create(session):
    thing = Things(session, "my thing")
    assert thing
...where Things is a SQLAlchemy declarative class defining a database table. This works fine.
The problem I'm trying to solve arises when testing higher levels of the code. models.py includes:
engine = sqlalchemy.create_engine(
    Config.MYSQL_CONNECT,
    encoding='utf-8',
    pool_pre_ping=True)
Session = scoped_session(sessionmaker(bind=engine))
...and the usage in the code is typically:
def fn():
    with Session() as session:
        thing = Things(session, "my thing")
I want fn() to use the Session defined in models.py in production, but use the pytest Session in testing.
I clearly have this architected incorrectly but I'm struggling to find a way forwards for what must be quite a common problem.
How do others handle this?
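One common way to handle this, sketched here under assumptions: leave fn() using the module-level Session and let the test suite rebind that Session to the test connection. SQLAlchemy's scoped_session has a configure() method that reconfigures its underlying sessionmaker (e.g. Session.configure(bind=connection)); the sketch below only imitates that shape with a small hypothetical SessionRegistry over the standard-library sqlite3 module so it runs on its own. It is not the SQLAlchemy API itself, and the in-memory database standing in for Config.MYSQL_CONNECT is an assumption.

```python
import sqlite3
from typing import Callable


class SessionRegistry:
    """Hypothetical stand-in for models.Session: returns whatever its current factory produces."""

    def __init__(self, factory: Callable[[], sqlite3.Connection]):
        self._factory = factory

    def configure(self, factory: Callable[[], sqlite3.Connection]) -> None:
        # Tests call this to point production code at the test connection.
        self._factory = factory

    def __call__(self) -> sqlite3.Connection:
        return self._factory()


# models.py: the production binding (in-memory SQLite stands in for Config.MYSQL_CONNECT).
Session = SessionRegistry(lambda: sqlite3.connect(":memory:"))


def fn(name: str) -> int:
    # Production code only ever talks to the module-level Session.
    conn = Session()
    conn.execute("INSERT INTO things (name) VALUES (?)", (name,))
    return conn.execute("SELECT COUNT(*) FROM things").fetchone()[0]
```

In a conftest.py this rebinding would sit inside the function-scoped session fixture: configure before yielding, then roll back the transaction and restore the production binding afterwards.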

Using a DB dependency in FastAPI without having to pass it through a function tree

I am currently working on a POC using FastAPI on a complex system. This project is heavy in business logic and will interact with 50+ different database tables when completed. Each model has a service, and some of the more complex business logic has its own service (which then interacts with the different tables through the model-specific services).
While everything works, I've gotten some push-back from members of my team regarding dependency injection of the Session object. The biggest issue is having to pass the Session from the controller to a service, to a second service and, in a few cases, to a third service further in. In those cases, the intermediary service functions tend to have no database queries, but the functions they call on other services might. The complaint is mainly that this is harder to maintain, and having to pass the DB object everywhere seems needlessly repetitive.
Example as code:
databases/mysql.py (one of 3 dbs in the project)
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
def get_uri():
    return 'the mysql uri'
engine = create_engine(get_uri())
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
    db: Session = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise  # re-raise so the original error is not silently swallowed
    finally:
        db.close()
controllers/controller1.py
from fastapi import APIRouter, HTTPException, Path, Depends
from sqlalchemy.orm import Session
from services.mysql.bar import get_bar_by_id
from services.mysql.process_x import bar_process
from databases.mysql import get_db
router = APIRouter(prefix='/foo')
@router.get('/bar/{bar_id}')
def process_bar(bar_id: int = Path(..., title='The ID of the bar to process', ge=1),
                mysql_session: Session = Depends(get_db)):
    # From the controller, to a service which only runs a query. This is fine.
    bar = get_bar_by_id(bar_id, mysql_session)
    if bar is None:
        raise HTTPException(status_code=404,
                            detail='Bar not found for id: {bar_id}'.format(bar_id=bar_id))
    # This one calls a function in a service which has a lot of business logic but no queries
    processed_bar = bar_process(bar, mysql_session)
    return processed_bar
services/mysql/process_x.py
from .process_bar import process_the_bar
from models.mysql.w import W
from models.mysql.bar import Bar
from models.mysql.y import Y
from models.mysql.z import Z
from sqlalchemy.orm import Session
def w_process(w: W, mysql_session: Session):
    ...
def bar_process(bar: Bar, mysql_session: Session):
    # Very simplified, there's actually 5 conditional branching service calls here
    return process_the_bar(bar, mysql_session)
def y_process(y: Y, mysql_session: Session):
    ...
def z_process(z: Z, mysql_session: Session):
    ...
services/mysql/process_bar.py
from . import model_service1
from . import model_service2
from . import model_service3
from . import additional_rules_service
from libraries.bar_functions import do_thing_to_bar
from models.mysql.bar import Bar
from sqlalchemy.orm import Session
def process_the_bar(bar: Bar, mysql_session: Session):
    process_result = list()
    # Many processing steps, not all of them require db and might work on the bar directly
    process_result.append(process1(bar, mysql_session))
    process_result.append(process2(bar, mysql_session))
    process_result.append(process3(bar, mysql_session))
    process_result.append(process4(bar))
    process_result.append(...(bar))
    process_result.append(processY(bar))
    return process_result
def process1(bar: Bar, mysql_session: Session):
    return model_service1.do_something(bar.val, mysql_session)
def process2(bar: Bar, mysql_session: Session):
    return model_service2.do_something(bar.val, mysql_session)
def process3(bar: Bar, mysql_session: Session):
    return model_service3.do_something(bar.val, mysql_session)
def process4-Y(bar: Bar, mysql_session: Session):
    # do something using the bar library, or maybe on another service with no queries
    return list()
As you can see, with this approach we're stuck passing mysql_session around and repeating it everywhere.
Here are two solutions I have thought of:
Adding the DB session to the Starlette request state
I could do this either through the app.startup event ( https://fastapi.tiangolo.com/advanced/events/ ) or a middleware. However, it does mean passing the request state back and forth in a similar fashion (if my understanding of it is correct).
Session scope approach using Context Manager
Pretty much, I would turn the get_db function into a context manager instead and not inject it as a dependency. This gives by far the cleanest end result; however, it goes completely against the concept of sharing a single DB session across the request.
I've considered the fully async approach using encode/databases as shown in the FastAPI documentation ( https://fastapi.tiangolo.com/advanced/async-sql-databases/ ); however, one of the databases we are working with through SQLAlchemy is used via a plugin, and I am assuming it does not support async out of the box (Vertica). If I'm wrong, then I could consider the fully async approach.
So in the end, what I'm wondering is if it's possible to accomplish something "cleaner" without compromising the single session per request approach?
I got some help directly on the FastAPI GitHub.
As user Insomnes mentioned, what I am looking to do can be achieved using a ContextVar. I have tried it in my code and it seems to work just fine.
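A minimal, framework-free sketch of the ContextVar approach, assuming the session is set once per request (for example in a middleware or dependency) and read wherever it is needed: each asyncio task runs in its own copy of the context, so concurrent requests don't see each other's session, and intermediate layers no longer need a session parameter. FakeSession, db_session, handle_request and model_service_do_something are hypothetical names, and the FastAPI wiring is not shown; whether a value set inside a dependency is visible in the endpoint depends on how FastAPI runs sync and async dependencies, so that part is worth checking against your version.

```python
import asyncio
from contextvars import ContextVar

# Holds the current request's session; each task sees its own value.
db_session = ContextVar("db_session")


class FakeSession:
    """Stand-in for a SQLAlchemy Session; records which request used it."""

    def __init__(self, request_id):
        self.request_id = request_id
        self.queries = []


def model_service_do_something(value):
    # Deep in the call tree: fetch the request's session instead of taking a parameter.
    session = db_session.get()
    session.queries.append(value)
    return session.request_id


def process_the_bar(value):
    # Intermediate layer: no session parameter needed any more.
    return model_service_do_something(value)


async def handle_request(request_id, value):
    # What a middleware or dependency would do: create, set, and finally reset.
    session = FakeSession(request_id)
    token = db_session.set(session)
    try:
        await asyncio.sleep(0)  # let the other "request" run in between
        return process_the_bar(value), session.queries
    finally:
        db_session.reset(token)


async def main():
    # gather() wraps each coroutine in its own Task, each with a copied context.
    return await asyncio.gather(handle_request("r1", 1), handle_request("r2", 2))
```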
