"""Pytest fixtures for testing the FastAPI app against an isolated database."""
import os
from typing import Any, Generator
import pytest
from src.db import Base, get_db
from src.main import app as _app
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Default to using sqlite in memory for fast tests.
# Can be overridden by environment variable for testing in CI against other
# database engines, e.g. TEST_DATABASE_URL=postgresql://user:pw@host/db
SQLALCHEMY_DATABASE_URL = os.getenv("TEST_DATABASE_URL", "sqlite://")

# `check_same_thread` is an option of the stdlib sqlite3 driver only; other
# DBAPI drivers (e.g. psycopg2) reject unknown connect arguments, so pass it
# only for SQLite URLs. It is presumably needed because the TestClient drives
# the app from a different thread than the one that opened the connection.
_connect_args = (
    {"check_same_thread": False}
    if SQLALCHEMY_DATABASE_URL.startswith("sqlite")
    else {}
)
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args=_connect_args)

# Session factory for the fixtures below. NOTE: this name shadows
# `sqlalchemy.orm.Session`; the fixture annotations refer to this factory.
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(autouse=True)
def app() -> Generator[FastAPI, Any, None]:
    """Create a fresh database schema for each test case.

    Creates all tables registered on ``Base.metadata`` before the test and
    drops them again afterwards, so every test starts from an empty schema
    even when the engine points at a persistent database instead of the
    default in-memory SQLite.

    Yields:
        The application object under test.
    """
    Base.metadata.create_all(engine)  # Create the tables.
    try:
        yield _app
    finally:
        # Drop the schema so nothing (tables or rows committed outside the
        # per-test transaction) carries over into the next test.
        Base.metadata.drop_all(engine)
@pytest.fixture
def db_session(app: FastAPI) -> Generator[Session, Any, None]:
    """Yield a SQLAlchemy session whose work is rolled back after the test.

    The session is bound to a single connection on which an outer, non-ORM
    transaction has been started. When the test finishes, that transaction
    is rolled back, so the database is left in a clean state.

    Args:
        app: The ``app`` fixture; requested so the tables exist before the
            session is opened.

    Yields:
        A session bound to the transactional connection.
    """
    # connect to the database
    connection = engine.connect()
    # begin a non-ORM transaction
    transaction = connection.begin()
    # bind an individual Session to the connection
    session = Session(bind=connection)
    try:
        yield session  # use the session in tests.
    finally:
        session.close()
        # rollback - everything that happened with the
        # Session above (including calls to commit())
        # is rolled back.
        transaction.rollback()
        # return connection to the Engine
        connection.close()
@pytest.fixture
def client(app: FastAPI, db_session: Session) -> Generator[TestClient, Any, None]:
    """Yield a TestClient whose routes use the test's ``db_session``.

    Overrides the ``get_db`` dependency injected into routes so that every
    request made during the test uses ``db_session`` (and therefore shares
    its end-of-test rollback). The override is removed afterwards.

    Args:
        app: The application under test (from the ``app`` fixture).
        db_session: The per-test transactional session.

    Yields:
        A ``TestClient`` for ``app``, opened as a context manager.
    """

    def _get_test_db():
        yield db_session

    app.dependency_overrides[get_db] = _get_test_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # `app` is the shared module-level application object, so remove our
        # override; otherwise later tests would still be handed this test's
        # (by then closed) session.
        app.dependency_overrides.pop(get_db, None)