# swimtracker-app/backend/tests/conftest.py
#
# 74 lines
# 2.2 KiB
# Python

import os
from typing import Any, Generator
import pytest
import databases
from src.db import Base, get_db
from src import get_app
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Default to an in-memory SQLite database for fast tests. Set the
# TEST_DATABASE_URL environment variable to run the suite against another
# database engine (e.g. in CI).
SQLALCHEMY_DATABASE_URL = os.getenv('TEST_DATABASE_URL', "sqlite://")

# `check_same_thread` is an option understood only by Python's sqlite3 driver;
# it allows a SQLite connection to be used from a thread other than the one
# that created it (presumably needed because the test client may drive the app
# from another thread — confirm). Other DBAPI drivers (e.g. psycopg2) reject
# unknown connection options, so it is only passed for SQLite URLs; otherwise
# overriding TEST_DATABASE_URL would fail at the first connection attempt.
_CONNECT_ARGS = (
    {"check_same_thread": False}
    if SQLALCHEMY_DATABASE_URL.startswith("sqlite")
    else {}
)
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args=_CONNECT_ARGS)

# Session factory; the `db_session` fixture binds each session it creates to
# an explicit connection rather than to the engine directly.
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(autouse=True)
def app() -> Generator[FastAPI, Any, None]:
    """
    Create the database schema and a fresh application for each test case.

    Applied to every test automatically (``autouse=True``). All tables
    registered on ``Base.metadata`` are created before the test runs and
    dropped again afterwards, so every test starts from an empty schema.

    Yields:
        The FastAPI application returned by ``get_app``, constructed around a
        ``databases.Database`` pointing at the test database URL.
    """
    Base.metadata.create_all(engine)  # Create the tables.
    _app = get_app(databases.Database(SQLALCHEMY_DATABASE_URL))
    yield _app
    # Teardown: pytest resumes the fixture here after the test has finished,
    # whether it passed or failed, so the schema never leaks between tests.
    Base.metadata.drop_all(engine)
@pytest.fixture
def db_session(app: FastAPI) -> Generator[Session, Any, None]:
    """
    Provide a SQLAlchemy session that runs inside an outer transaction.

    A dedicated connection is opened and a Core-level (non-ORM) transaction is
    begun on it before the session is bound to that connection. When the test
    is done, the outer transaction is rolled back, discarding what the test
    wrote through the session. Depends on the ``app`` fixture so the tables
    already exist, and is torn down before it.

    NOTE(review): the expectation that ``session.commit()`` inside the test is
    also undone relies on SQLAlchemy's "join a Session into an external
    transaction" behaviour; confirm it holds for the installed version.
    """
    conn = engine.connect()
    outer_tx = conn.begin()
    test_session = Session(bind=conn)

    yield test_session

    # Teardown, in order: release the session, undo everything done in the
    # outer transaction, then return the connection to the engine's pool.
    test_session.close()
    outer_tx.rollback()
    conn.close()
@pytest.fixture()
def client(app: FastAPI, db_session: Session) -> Generator[TestClient, Any, None]:
    """
    Create a FastAPI TestClient whose routes use the ``db_session`` fixture.

    The ``get_db`` dependency injected into routes is overridden so that every
    request made through the client shares the test's session. The override
    is removed again on teardown, even if entering the client fails, so it
    cannot leak into any later use of the same ``app`` object.

    Yields:
        A ``TestClient`` for ``app``, opened as a context manager.
    """
    def _get_test_db():
        # Hand out the shared test session; closing it is db_session's job,
        # so there is nothing to clean up here.
        yield db_session

    app.dependency_overrides[get_db] = _get_test_db
    try:
        # Named test_client to avoid shadowing this fixture function.
        with TestClient(app) as test_client:
            yield test_client
    finally:
        app.dependency_overrides.pop(get_db, None)