How to set up and tear down a database between tests in FastAPI?

Question:

I have set up my unit tests following the FastAPI documentation, but it only covers the case where the database is persisted across tests.

What if I want to build and tear down the database for each test? (For example, the second test below will fail because the database will no longer be empty after the first test.)

I am currently doing it by calling create_all and drop_all (commented out in the code below) at the beginning and end of each test, but this is obviously not ideal: if a test fails, the database will never be torn down, which affects the result of the next test.

How can I do it properly? Should I create some kind of pytest fixture around the override_get_db dependency?

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from main import app, get_db
from database import Base

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

# check_same_thread=False lets the SQLite connection be used from threads
# other than the one that opened it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)
# Factory for sessions on the test database.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Base.metadata.create_all(bind=engine)

def override_get_db():
    """Yield a session on the test database and close it after the request.

    The session is created before the ``try`` block. If creating it fails,
    the original error propagates instead of an UnboundLocalError raised by
    ``db.close()`` in the ``finally`` clause.
    """
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()

# Every route depending on get_db now receives a test-database session.
app.dependency_overrides.update({get_db: override_get_db})

client = TestClient(app)

def test_get_todos():
    # Base.metadata.create_all(bind=engine)

    # create two todos through the API
    first = client.post('/todos/', json={'text': 'some new todo'}).json()
    second = client.post('/todos/', json={'text': 'some even newer todo'}).json()

    # both todos belong to the same user
    assert first['user_id'] == second['user_id']

    listing = client.get('/todos/')
    assert listing.status_code == 200

    # the listing contains exactly the two created todos, in creation order
    fields = ('id', 'user_id', 'text')
    assert listing.json() == [
        {name: todo[name] for name in fields} for todo in (first, second)
    ]

    # Base.metadata.drop_all(bind=engine)

def test_get_empty_todos_list():
    # Base.metadata.create_all(bind=engine)

    # a database nothing has been written to should list no todos
    listing = client.get('/todos/')
    assert listing.status_code == 200
    assert listing.json() == []

    # Base.metadata.drop_all(bind=engine)
Asked By: barciewicz

||

Answers:

For cleaning up after tests even when they fail (and setting up before tests), pytest provides pytest.fixture.

In your case you want to create all tables before each test, and drop them again afterwards. This can be achieved with the following fixture:

@pytest.fixture()
def test_db():
    """Create every table before the test and drop them all after it."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    yield None
    metadata.drop_all(bind=engine)

And then use it in your tests like so:

def test_get_empty_todos_list(test_db):
    """With freshly created tables, GET /todos/ returns an empty list."""
    reply = client.get('/todos/')
    assert reply.status_code == 200
    assert reply.json() == []

For each test that has test_db in its argument list, pytest first runs Base.metadata.create_all(bind=engine), then yields to the test code, and afterwards makes sure that Base.metadata.drop_all(bind=engine) gets run, even when the test fails.

The full code:

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from main import app, get_db
from database import Base


SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

# check_same_thread=False lets the SQLite connection be used from threads
# other than the one that opened it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)
# Factory for sessions on the test database.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


def override_get_db():
    """Yield a session on the test database and close it after the request.

    The session is created before the ``try`` block. If creating it fails,
    the original error propagates instead of an UnboundLocalError raised by
    ``db.close()`` in the ``finally`` clause.
    """
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


@pytest.fixture()
def test_db():
    """Create every table before the test and drop them all after it."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    yield None
    metadata.drop_all(bind=engine)

# Every route depending on get_db now receives a test-database session.
app.dependency_overrides.update({get_db: override_get_db})

client = TestClient(app)


def test_get_todos(test_db):
    """Two todos created via POST come back from GET /todos/ in order."""
    first = client.post("/todos/", json={"text": "some new todo"}).json()
    second = client.post("/todos/", json={"text": "some even newer todo"}).json()

    # both todos belong to the same user
    assert first["user_id"] == second["user_id"]

    listing = client.get("/todos/")
    assert listing.status_code == 200

    fields = ("id", "user_id", "text")
    assert listing.json() == [
        {name: todo[name] for name in fields} for todo in (first, second)
    ]


def test_get_empty_todos_list(test_db):
    """With freshly created tables, GET /todos/ returns an empty list."""
    reply = client.get("/todos/")
    assert reply.status_code == 200
    assert reply.json() == []

As your application grows, setting up and tearing down the whole database for each test might get slow.

A solution for that is to only set up the db once and then never actually commit anything to it.
This can be achieved using nested transactions and rollbacks:

import pytest
import sqlalchemy as sa
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker

from database import Base
from main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = sa.create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Set up the database once
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
# The DDL above runs before the SQLite "connect" listener further down in this
# module is registered. That listener only fires when the pool opens a NEW
# DBAPI connection, so a connection pooled by the DDL (e.g. SQLAlchemy 2.0
# pools file-based SQLite connections) would be reused without it and break
# the SAVEPOINT recipe. Discard the pool so every later connection is fresh.
engine.dispose()


# These two event listeners are only needed for SQLite, for proper
# SAVEPOINT / nested transaction support. Other databases such as PostgreSQL
# don't need them.
# From: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
@sa.event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
    """Switch each newly opened pysqlite connection to driver autocommit.

    Registered for the engine's "connect" event, so it runs when the pool
    creates a new DBAPI connection. ``connection_record`` is unused. With the
    driver no longer issuing BEGIN itself, ``do_begin`` below emits it.
    """
    # disable pysqlite's emitting of the BEGIN statement entirely.
    # also stops it from emitting COMMIT before any DDL.
    dbapi_connection.isolation_level = None


@sa.event.listens_for(engine, "begin")
def do_begin(conn):
    """Emit an explicit BEGIN whenever SQLAlchemy starts a transaction.

    Required because ``do_connect`` turned off pysqlite's implicit BEGIN.
    """
    # emit our own BEGIN
    conn.exec_driver_sql("BEGIN")


# This fixture is the main difference from the previous version. It opens an
# outer transaction plus a nested one (SAVEPOINT), starts a new SAVEPOINT
# whenever the application code's session.commit ends the current one, and
# rolls everything back at the end.
# Based on: https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
@pytest.fixture()
def session():
    """Yield a Session whose work is rolled back after the test.

    The session is bound to one connection that holds an outer transaction
    for the whole test; that transaction is rolled back at teardown, so
    nothing the test or the app did is left committed.
    """
    connection = engine.connect()
    transaction = connection.begin()
    session = TestingSessionLocal(bind=connection)

    # Begin a nested transaction (using SAVEPOINT).
    nested = connection.begin_nested()

    # If the application code calls session.commit, it will end the nested
    # transaction. Need to start a new one when that happens.
    @sa.event.listens_for(session, "after_transaction_end")
    def end_savepoint(session, transaction):
        """Re-open the SAVEPOINT once the previous one is no longer active."""
        # The parameters shadow the outer `session`/`transaction`; only
        # `nested` and `connection` from the enclosing scope are used here.
        nonlocal nested
        if not nested.is_active:
            nested = connection.begin_nested()

    yield session

    # Teardown: close the session, then roll back the outer transaction,
    # restoring the state before the test ran, and release the connection.
    session.close()
    transaction.rollback()
    connection.close()


# A fixture for the FastAPI test client that depends on the session fixture
# above. Instead of creating a new session in the dependency override as
# before, it hands out the one provided by the session fixture.
@pytest.fixture()
def client(session):
    """Yield a TestClient whose get_db dependency returns the test session."""

    def _shared_session():
        yield session

    app.dependency_overrides[get_db] = _shared_session
    test_client = TestClient(app)
    yield test_client
    # The test is over: stop overriding get_db.
    del app.dependency_overrides[get_db]


def test_get_empty_todos_list(client):
    """With nothing committed to the test database, no todos are listed."""
    todos_response = client.get("/todos/")
    assert todos_response.status_code == 200
    assert todos_response.json() == []

Having two fixtures (session and client) here has an additional advantage:

If a test only talks to the API, you don't need to remember to add the session fixture explicitly (the client fixture still pulls it in implicitly).
And if you want to write a test that talks to the db directly, you can do that as well:

def test_something(session):
    """Query the database directly through the ``session`` fixture."""
    # `...` is a placeholder: pass the model / columns you want to query.
    session.query(...)

Or both, if you for example want to prepare the db state before an API call:

def test_something_else(client, session):
    """Prepare database state with ``session``, then call the API via ``client``."""
    # Placeholders: add your own objects and request path.
    session.add(...)
    session.commit()
    client.get(...)

Both the application code and test code will see the same state of the db.

Answered By: mihi

You can also truncate the tables after each test. This clears all their data without removing the schema, so it is not as slow as Base.metadata.drop_all(bind=engine). (The TRUNCATE ... RESTART IDENTITY CASCADE statement used below is PostgreSQL syntax.)

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from contextlib import contextmanager

# NOTE(review): clear_tables() iterates Base.metadata, so this must be the
# same declarative Base your models are defined on.
Base = declarative_base()
engine = create_engine('postgresql://...')  # placeholder URL
Session = sessionmaker(bind=engine)


@contextmanager
def session_scope():
    """Provide a transactional scope around a series of operations.

    Commits when the ``with`` body finishes normally. On any exception it
    rolls back and re-raises. It always closes the session.
    """
    session = Session()
    try:
        yield session
        session.commit()
    except BaseException:
        # Catch everything (including KeyboardInterrupt) so the transaction is
        # never left open, then re-raise unchanged. This is the explicit
        # spelling of a bare `except:`.
        session.rollback()
        raise
    finally:
        session.close()


def clear_tables():
    """Remove all rows from every table in ``Base.metadata`` (PostgreSQL).

    Issues a single ``TRUNCATE ... RESTART IDENTITY CASCADE`` covering all
    tables, so identity/serial counters restart and foreign keys between the
    tables do not block the truncation. Does nothing when no tables are
    registered on the metadata.
    """
    # Raw SQL strings must be wrapped in text(): SQLAlchemy 2.0 rejects plain
    # strings in Session.execute() and 1.4 deprecates them.
    from sqlalchemy import text

    # Quote through the dialect so reserved-word names (e.g. "user") and
    # schema-qualified tables are rendered correctly.
    preparer = engine.dialect.identifier_preparer
    tables = [preparer.format_table(table) for table in Base.metadata.sorted_tables]
    if not tables:
        return  # an empty TRUNCATE would be a syntax error

    # session_scope() commits on exit, so no explicit commit is needed here.
    with session_scope() as session:
        session.execute(
            text(f"TRUNCATE {', '.join(tables)} RESTART IDENTITY CASCADE;")
        )


@pytest.fixture
def test_db_session():
    """Yield a Session for the test, then empty every table afterwards.

    Tests call ``test_db_session.query(...)``, which is a Session method, so
    a Session (not the Engine) is yielded. The ``finally`` runs the
    cleanup even if the generator is closed early.
    """
    session = Session()
    try:
        yield session
    finally:
        # Release this session's connection first so its transaction cannot
        # hold locks that would block the TRUNCATE.
        session.close()
        engine.dispose()
        clear_tables()


def test_some_feature(test_db_session):
    """Placeholder test; replace the `...` parts with real queries and assertions."""
    # `.query` requires test_db_session to yield a Session (an Engine has no
    # `.query` method).
    test_db_session.query(...)
    (...)

Answered By: luisgc93

Here’s a solution for a full FastAPI test environment, including database setup and teardown. Despite the fact that there is already an accepted answer, I’d like to contribute my thoughts.

When configuring a test environment, you'll want to put these fixtures in your conftest.py file. Fixtures defined in it are automatically available, without any import, to every test in the same directory and its subdirectories.

a) First of all, do the imports.

Remember that your imports path may differ from mine, so double-check that as well.

import pytest
from fastapi.testclient import TestClient

# Import the SQLAlchemy parts
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from app.main import app
from app.database import get_db,Base

# Engine and session factory for the test database

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

# SQLite connections may only be used from the thread that created them
# unless check_same_thread is disabled. FastAPI can touch the same session
# from more than one thread while handling a request, so its SQL tutorial
# turns the check off for SQLite.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)

TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Following that, we’ll use Pytest fixtures, which are functions that run before each test function to which they’re applied.

b) Session fixture

@pytest.fixture()
def session():
    """Yield a session on freshly recreated tables; close it after the test."""
    # Start from an empty schema: drop whatever is left over, then recreate.
    metadata = Base.metadata
    metadata.drop_all(bind=engine)
    metadata.create_all(bind=engine)

    db_session = TestingSessionLocal()
    try:
        yield db_session
    finally:
        db_session.close()

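The above session fixture ensures that before every test we connect to the testing database, drop any existing tables, and create them again, so each test starts from empty tables. The tables are not dropped after the test (the next test's setup does that); only the session is closed once the test finishes.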

c) client fixture

@pytest.fixture()
def client(session):
    """Yield a TestClient whose ``get_db`` dependency returns the test session.

    The override is removed after the test, so it cannot leak into tests
    that use the app without this fixture.
    """

    # Dependency override

    def override_get_db():
        try:
            yield session
        finally:
            # Closed after each request; the session fixture closes it again
            # at teardown.
            session.close()

    app.dependency_overrides[get_db] = override_get_db
    try:
        yield TestClient(app)
    finally:
        # pop() tolerates a test that already removed the override itself.
        app.dependency_overrides.pop(get_db, None)

The above fixture overrides the app's get_db dependency, so requests use the test-database session instead of the app's own database connection. It depends on the session fixture, which pytest therefore sets up first.

After that, you can use the fixtures in your tests without importing anything, as shown below.

def test_index(client):
    """The root route answers with HTTP 200."""
    assert client.get("/").status_code == 200

Your complete conftest.py file should now look like this:

import pytest
from fastapi.testclient import TestClient

# Import the SQLAlchemy parts
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from app.main import app
from app.database import get_db, Base

# Engine and session factory for the test database

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

# SQLite connections may only be used from the thread that created them
# unless check_same_thread is disabled. FastAPI can touch the same session
# from more than one thread while handling a request, so its SQL tutorial
# turns the check off for SQLite.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)

TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


@pytest.fixture()
def session():
    """Yield a session on freshly recreated tables; close it after the test."""
    # Create the database: drop whatever is left over, then recreate it.
    metadata = Base.metadata
    metadata.drop_all(bind=engine)
    metadata.create_all(bind=engine)

    db_session = TestingSessionLocal()
    try:
        yield db_session
    finally:
        db_session.close()


@pytest.fixture()
def client(session):
    """Yield a TestClient whose ``get_db`` dependency returns the test session.

    The override is removed after the test, so it cannot leak into tests
    that use the app without this fixture.
    """

    # Dependency override

    def override_get_db():
        try:
            yield session
        finally:
            # Closed after each request; the session fixture closes it again
            # at teardown.
            session.close()

    app.dependency_overrides[get_db] = override_get_db
    try:
        yield TestClient(app)
    finally:
        # pop() tolerates a test that already removed the override itself.
        app.dependency_overrides.pop(get_db, None)


Answered By: Ondiek Elijah

Create a file named conftest.py in the tests folder. In this file we keep the test database settings and define the fixtures used by our API tests. I'm using a PostgreSQL database.

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.main import app
from app.db.database import Base, get_db
import pytest

# Connection URL of the dedicated test database (replace the credentials and
# the database-name placeholder with your own).
TEST_DATABASE_URL = 'postgresql://postgres:admin@localhost:5432/yourdatabsename'
# Keep a reference to the FastAPI app imported from app.main: the `app`
# fixture defined below rebinds the module-level name `app`, so this alias
# must be taken before that definition.
main_app = app

def start_application():
    """Return the FastAPI application under test (the app imported from app.main)."""
    return main_app

SQLALCHEMY_DATABASE_URL = TEST_DATABASE_URL
engine = create_engine(SQLALCHEMY_DATABASE_URL)  # engine for the test database

# Factory for test sessions bound to the test engine.
SessionTesting = sessionmaker(bind=engine, autoflush=False, autocommit=False)


@pytest.fixture(scope="function")
def app():
    """Create all tables, hand the app to the test, then drop the tables again."""
    Base.metadata.create_all(engine)
    application = start_application()
    yield application
    Base.metadata.drop_all(engine)


@pytest.fixture(scope="function")
def db_session(app: "FastAPI"):
    """Yield a session whose work is rolled back after the test.

    Depends on the ``app`` fixture so the tables exist before the session is
    opened. The annotation is a string because ``FastAPI`` is never imported
    in this module; an unquoted name would raise NameError when the module
    is imported.
    """
    connection = engine.connect()
    transaction = connection.begin()
    session = SessionTesting(bind=connection)
    yield session  # use the session in tests.
    # Teardown: discard everything done on this connection.
    session.close()
    transaction.rollback()
    connection.close()

@pytest.fixture(scope="function")
def client(app: "FastAPI", db_session):
    """Yield a TestClient whose ``get_db`` dependency returns ``db_session``.

    Routes therefore run inside db_session's outer transaction, which that
    fixture rolls back after the test. The override is removed afterwards.
    """

    def _get_test_db():
        # NOTE(review): if route code calls session.rollback(), depending on
        # the SQLAlchemy version this may end the outer test transaction; the
        # SAVEPOINT-based recipe handles that case.
        yield db_session

    app.dependency_overrides[get_db] = _get_test_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        app.dependency_overrides.pop(get_db, None)

pytest automatically makes the fixtures in this conftest.py available to every test file in the folder.

For example

In my Organization Type Api, I create a file named test_organization_type in tests folder

from app.models.models import OrganizationType


def test_get_organization_type(client, db_session):
    """Create an organization type through the API, then list organization types."""
    # Placeholder routes: replace them with the URLs of your endpoints.
    create_url = "/organization-types/"
    list_url = "/organization-types/"

    response_post = client.post(create_url, json={'type_name': 'test_organization'})
    assert response_post.status_code == 201

    response_get = client.get(list_url)  # get_request
    data = response_get.json()  # TODO: assert on `data` once the response schema is known
    assert response_get.status_code == 200

Note: your local development database and the test database are separate. Objects created in the test database by POST requests are removed automatically after each test, because the app fixture drops all tables at teardown.

Answered By: Tanveer Ahmad
Categories: questions Tags: , ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.