74 lines
2.1 KiB
Python
74 lines
2.1 KiB
Python
import os
|
|
from typing import Any, Generator
|
|
|
|
import pytest
|
|
from src.db import Base, get_db
|
|
from src.main import app as _app
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
# Tests run against an in-memory SQLite database by default because it is
# fast and needs no setup. Set the TEST_DATABASE_URL environment variable
# (e.g. in CI) to run the same suite against a different database engine.
SQLALCHEMY_DATABASE_URL = os.environ.get('TEST_DATABASE_URL', "sqlite://")
|
|
|
|
# ``check_same_thread`` is a sqlite3-specific DBAPI argument. Setting it to
# False allows a sqlite connection to be used from a thread other than the one
# that created it. Other drivers (e.g. psycopg2) reject unknown connect
# arguments, so only pass it when TEST_DATABASE_URL actually targets sqlite.
_CONNECT_ARGS = (
    {"check_same_thread": False}
    if SQLALCHEMY_DATABASE_URL.startswith("sqlite")
    else {}
)

engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args=_CONNECT_ARGS)

# Factory for ORM sessions bound to the test engine.
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def app() -> Generator[FastAPI, Any, None]:
    """
    Provide the FastAPI application with a freshly created schema.

    Every table registered on ``Base.metadata`` is created on the test engine
    before the test body runs and dropped again afterwards, so each test starts
    with empty tables. Because the fixture is ``autouse``, it wraps every test.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    yield _app
    metadata.drop_all(bind=engine)
|
|
|
|
|
|
@pytest.fixture
def db_session(app: FastAPI) -> Generator[Session, Any, None]:
    """
    Yield a SQLAlchemy session that runs inside an outer transaction.

    A dedicated connection is opened and a transaction is begun on it before
    the test runs; the yielded session is bound to that connection. After the
    test, the session is closed, the outer transaction is rolled back, and the
    connection is returned to the engine. Teardown steps are guarded so that a
    failure in one of them does not skip the rest.

    Depends on ``app`` so the tables have been created before use.
    """
    # Connection.__exit__ closes the connection (returning it to the engine)
    # even if the teardown steps below raise.
    with engine.connect() as connection:
        # Begin a non-ORM transaction that wraps everything the test does.
        transaction = connection.begin()
        # Bind an individual Session to that connection.
        session = Session(bind=connection)
        try:
            yield session
        finally:
            try:
                session.close()
            finally:
                # Roll back everything done through the session above.
                # NOTE(review): whether session.commit() calls are also undone
                # depends on the SQLAlchemy version's behaviour when joining an
                # external transaction (2.0 needs
                # join_transaction_mode="create_savepoint") — confirm.
                transaction.rollback()
|
|
|
|
|
|
@pytest.fixture()
def client(app: FastAPI, db_session: Session) -> Generator[TestClient, Any, None]:
    """
    Create a FastAPI TestClient whose requests use the ``db_session`` fixture.

    The ``get_db`` dependency injected into routes is overridden to yield the
    test's transactional session. The override is removed again on teardown,
    because ``app`` is a shared module-level application object and a stale
    override would otherwise leak into later tests.
    """

    def _get_test_db():
        # Closing and rolling back the session is handled by the
        # ``db_session`` fixture, so this dependency has no cleanup to do.
        yield db_session

    app.dependency_overrides[get_db] = _get_test_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Drop the override so it does not keep pointing at this test's
        # (by then closed) session.
        app.dependency_overrides.pop(get_db, None)
|