import os
from typing import Any, Dict, Generator

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.engine import make_url
from sqlalchemy.orm import Session as OrmSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from src.db import Base, get_db
from src.main import app as _app

# Default to using sqlite in memory for fast tests.
# Can be overridden by environment variable for testing in CI against other
# database engines.
SQLALCHEMY_DATABASE_URL = os.getenv("TEST_DATABASE_URL", "sqlite://")


def _engine_options(url: str) -> Dict[str, Any]:
    """Return extra ``create_engine`` keyword arguments suited to *url*.

    ``check_same_thread`` is a pysqlite-only connect argument; passing it to
    another driver (e.g. a non-SQLite URL supplied via ``TEST_DATABASE_URL``)
    makes connecting fail, so it is only added for SQLite URLs.

    Each DBAPI connection to an in-memory SQLite database is its own separate
    database, so in that case ``StaticPool`` is used: the engine then hands out
    a single shared connection, and the tables created by the ``app`` fixture
    are visible regardless of which thread opens a connection.
    """
    parsed = make_url(url)
    if parsed.get_backend_name() != "sqlite":
        return {}
    options: Dict[str, Any] = {"connect_args": {"check_same_thread": False}}
    # "sqlite://" has no database component; ":memory:" is the explicit form.
    if parsed.database in (None, "", ":memory:"):
        options["poolclass"] = StaticPool
    return options


engine = create_engine(
    SQLALCHEMY_DATABASE_URL, **_engine_options(SQLALCHEMY_DATABASE_URL)
)
# Session factory bound to the test engine (name kept for existing importers).
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)


@pytest.fixture(autouse=True)
def app() -> Generator[FastAPI, Any, None]:
    """
    Create a fresh database on each test case.

    All tables known to ``Base.metadata`` are created before the test and
    dropped afterwards; the application object itself is yielded unchanged.
    """
    Base.metadata.create_all(engine)  # Create the tables.
    yield _app
    Base.metadata.drop_all(engine)


@pytest.fixture
def db_session(app: FastAPI) -> Generator[OrmSession, Any, None]:
    """
    Creates a fresh sqlalchemy session for each test that operates in a
    transaction. The transaction is rolled back at the end of each test ensuring
    a clean state.
    """
    # connect to the database
    connection = engine.connect()
    # begin a non-ORM transaction
    transaction = connection.begin()
    # bind an individual Session to the connection
    session = Session(bind=connection)

    yield session  # use the session in tests.

    session.close()
    # rollback - everything that happened with the
    # Session above (including calls to commit())
    # is rolled back.
    transaction.rollback()
    # return connection to the Engine
    connection.close()


@pytest.fixture()
def client(app: FastAPI, db_session: OrmSession) -> Generator[TestClient, Any, None]:
    """
    Create a new FastAPI TestClient that uses the `db_session` fixture to override
    the `get_db` dependency that is injected into routes.

    The override is removed again after the test so it does not leak (along
    with its reference to this test's closed session) into later tests that
    share the same application object.
    """

    def _get_test_db() -> Generator[OrmSession, None, None]:
        # Hand every request the test's transactional session.
        yield db_session

    app.dependency_overrides[get_db] = _get_test_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        app.dependency_overrides.pop(get_db, None)