Backend: first commit
This commit is contained in:
parent
ee6d7d2e96
commit
e5a5e141a9
|
@ -0,0 +1,2 @@
|
|||
DATABASE_URL = "sqlite:///db.sqlite"
|
||||
JWT_SECRET = "4SmRyfsvG86R9jZQfTshfoDlcxYlueHmkMXJbszp"
|
|
@ -0,0 +1,21 @@
|
|||
import databases
|
||||
import sqlalchemy
|
||||
from starlette import requests
|
||||
from config import DATABASE_URL
|
||||
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from starlette.requests import Request
|
||||
|
||||
|
||||
database = databases.Database(DATABASE_URL)
|
||||
Base: DeclarativeMeta = declarative_base()
|
||||
|
||||
engine = sqlalchemy.create_engine(
|
||||
DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
DbSession = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def get_db(request: Request):
|
||||
return request.state.db
|
|
@ -0,0 +1,38 @@
|
|||
from fastapi import FastAPI
|
||||
from users import add_user_routers, User
|
||||
from db import database, engine, Base, DbSession
|
||||
from starlette.requests import Request
|
||||
from routes import router as api_router
|
||||
import models
|
||||
|
||||
|
||||
def get_app() -> FastAPI:
|
||||
application = FastAPI(title="swimtracker", debug=True, version="0.1")
|
||||
application.include_router(api_router)
|
||||
add_user_routers(application)
|
||||
return application
|
||||
|
||||
|
||||
app = get_app()
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def db_session_middleware(request: Request, call_next):
|
||||
request.state.db = DbSession()
|
||||
response = await call_next(request)
|
||||
request.state.db.close()
|
||||
return response
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup() -> None:
|
||||
print("creating")
|
||||
Base.metadata.create_all(engine)
|
||||
if not database.is_connected:
|
||||
await database.connect()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown() -> None:
|
||||
if database.is_connected:
|
||||
await database.disconnect()
|
|
@ -0,0 +1,42 @@
|
|||
from db import Base
|
||||
from sqlalchemy import Column, Integer, Index, LargeBinary, ForeignKey, and_, or_
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class Session(Base):
|
||||
__tablename__ = "session"
|
||||
|
||||
device_id = Column(Integer, primary_key=True)
|
||||
start_time = Column(Integer, primary_key=True)
|
||||
data = Column(LargeBinary(1024 * 1024 * 2), nullable=False)
|
||||
user = Column(ForeignKey("user.id"), nullable=False)
|
||||
|
||||
value_right_shift = Column(Integer)
|
||||
tare_value = Column(Integer)
|
||||
kg_factor = Column(Integer)
|
||||
|
||||
Index('device_id', 'start_time', unique=True)
|
||||
|
||||
|
||||
class FriendRequest(Base):
|
||||
__tablename__ = "friend_request"
|
||||
requesting_user = Column(ForeignKey("user.id"), primary_key=True)
|
||||
receiving_user = Column(ForeignKey("user.id"), primary_key=True)
|
||||
|
||||
|
||||
class Friendship(Base):
|
||||
__tablename__ = "friendship"
|
||||
user_id = Column(ForeignKey("user.id"), primary_key=True)
|
||||
friend_id = Column(ForeignKey("user.id"), primary_key=True)
|
||||
|
||||
@staticmethod
|
||||
def befriend(db, userid1, userid2):
|
||||
db.add(Friendship(user_id=userid1, friend_id=userid2))
|
||||
|
||||
@staticmethod
|
||||
def are_friends(db, userid1, userid2):
|
||||
query_filter = or_(
|
||||
and_(Friendship.user_id == userid1, Friendship.friend_id == userid2),
|
||||
and_(Friendship.user_id == userid2, Friendship.friend_id == userid1),
|
||||
)
|
||||
return db.query(Friendship).filter(query_filter).count() > 0
|
|
@ -0,0 +1,94 @@
|
|||
import base64
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import List
|
||||
import schemas
|
||||
from users import User, UserDB, UserTable, current_user
|
||||
from db import get_db
|
||||
import models
|
||||
from sqlalchemy.orm import Session as DbSession, lazyload
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from fastapi import status
|
||||
from sqlalchemy.sql import select
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/sessions",
|
||||
response_model=schemas.Session,
|
||||
tags=["sessions"],
|
||||
status_code=status.HTTP_201_CREATED)
|
||||
def create_session(session: schemas.SessionBase,
|
||||
db: DbSession = Depends(get_db),
|
||||
user: User = Depends(current_user)):
|
||||
session_props = session.dict()
|
||||
session_props['user'] = user.id
|
||||
session_props['data'] = base64.b64decode(session_props['data'])
|
||||
db_obj = models.Session(**session_props)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
return db_obj
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=List[schemas.Session], tags=["sessions"])
|
||||
def list_sessions(skip=0,
|
||||
limit=100,
|
||||
db: DbSession = Depends(get_db),
|
||||
user: User = Depends(current_user)):
|
||||
return db.query(models.Session).filter(models.Session.user == user.id).order_by(
|
||||
models.Session.start_time.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
@router.post("/friends/request_friendship/{user_id}", tags=["friends"])
|
||||
def create_friend_request(other_user_id: str,
|
||||
db: DbSession = Depends(get_db),
|
||||
user: User = Depends(current_user)):
|
||||
if models.Friendship.are_friends(db, other_user_id, user.id):
|
||||
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="already friends")
|
||||
|
||||
FR = models.FriendRequest
|
||||
friend_request_from_other_user = db.query(FR).filter(FR.requesting_user == other_user_id,
|
||||
FR.receiving_user == user.id).count()
|
||||
if friend_request_from_other_user > 0:
|
||||
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE,
|
||||
detail="Friend request exist from other user, accept it")
|
||||
else:
|
||||
try:
|
||||
new_friend_request = FR(requesting_user=user.id, receiving_user=other_user_id)
|
||||
db.add(new_friend_request)
|
||||
db.commit()
|
||||
return {"msg": "ok"}
|
||||
except IntegrityError:
|
||||
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE,
|
||||
detail="Friend request already exists")
|
||||
|
||||
|
||||
@router.post("/friends/accept_friendship/{user_id}", tags=["friends"])
|
||||
def accept_friend_request(other_user_id: str,
|
||||
db: DbSession = Depends(get_db),
|
||||
user: User = Depends(current_user)):
|
||||
FR = models.FriendRequest
|
||||
try:
|
||||
friend_request = db.query(FR).filter(FR.requesting_user == other_user_id,
|
||||
FR.receiving_user == user.id).one()
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="No matching friend request found")
|
||||
|
||||
models.Friendship.befriend(db, other_user_id, user.id)
|
||||
db.delete(friend_request)
|
||||
db.commit()
|
||||
return {"msg": "ok"}
|
||||
|
||||
|
||||
@router.get("/friends", tags=["friends"], response_model=schemas.FriendsInfo)
|
||||
def list_friends_info(db: DbSession = Depends(get_db), user: User = Depends(current_user)):
|
||||
user_obj = db.query(UserTable).filter(UserTable.id == user.id).one()
|
||||
return schemas.FriendsInfo(incoming_requests=user_obj.friend_requests_in,
|
||||
outgoing_requests=user_obj.friend_requests_out)
|
||||
|
||||
|
||||
# todo: remove friend requests
|
||||
# todo: remove friendship
|
||||
# todo: search user by email
|
||||
# todo: add usernames to users
|
||||
# todo: search by username
|
|
@ -0,0 +1,40 @@
|
|||
from typing import Optional, List
|
||||
from pydantic import BaseModel, conint, UUID4
|
||||
from pydantic.networks import EmailStr
|
||||
|
||||
|
||||
class SessionBase(BaseModel):
|
||||
device_id: int
|
||||
start_time: conint(gt=1546297200)
|
||||
data: str
|
||||
|
||||
value_right_shift: Optional[conint(ge=0, le=32)]
|
||||
tare_value: Optional[conint(ge=0)]
|
||||
kg_factor: Optional[conint(ge=0)]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: UUID4
|
||||
email: EmailStr
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class Session(SessionBase):
|
||||
user: UUID4
|
||||
|
||||
|
||||
class FriendRequestCreate(BaseModel):
|
||||
other_user_id: int
|
||||
|
||||
|
||||
class FriendsInfo(BaseModel):
|
||||
incoming_requests: List[UserInfo]
|
||||
outgoing_requests: List[UserInfo]
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
|
@ -0,0 +1,92 @@
|
|||
from fastapi_users import FastAPIUsers, models
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase
|
||||
from fastapi_users.authentication import JWTAuthentication
|
||||
from config import JWT_SECRET
|
||||
from fastapi import Request
|
||||
from db import database, Base
|
||||
from sqlalchemy.orm import relationship, backref
|
||||
from sqlalchemy import Integer, Column
|
||||
from fastapi_users.models import BaseUser
|
||||
|
||||
|
||||
class User(models.BaseUser):
|
||||
pass
|
||||
|
||||
|
||||
class UserCreate(models.BaseUserCreate):
|
||||
pass
|
||||
|
||||
|
||||
class UserUpdate(User, models.BaseUserUpdate):
|
||||
pass
|
||||
|
||||
|
||||
class UserDB(User, models.BaseUserDB):
|
||||
pass
|
||||
|
||||
|
||||
class UserTable(Base, SQLAlchemyBaseUserTable):
|
||||
#id = Column(Integer, primary_key=True)
|
||||
sessions = relationship("Session")
|
||||
friend_requests_in = relationship(
|
||||
"UserTable",
|
||||
secondary="friend_request",
|
||||
primaryjoin=("UserTable.id == FriendRequest.receiving_user"),
|
||||
secondaryjoin=("UserTable.id == FriendRequest.requesting_user"),
|
||||
backref=backref("friend_requests_out"))
|
||||
friends = relationship('UserTable',
|
||||
secondary="friendship",
|
||||
primaryjoin=("UserTable.id == Friendship.user_id"),
|
||||
secondaryjoin=("UserTable.id == Friendship.friend_id"))
|
||||
|
||||
|
||||
user_db = SQLAlchemyUserDatabase(UserDB, database, UserTable.__table__)
|
||||
jwt_authentication = JWTAuthentication(secret=JWT_SECRET,
|
||||
lifetime_seconds=60 * 60 * 8,
|
||||
tokenUrl="auth/jwt/login")
|
||||
|
||||
fastapi_users = FastAPIUsers(
|
||||
user_db,
|
||||
[jwt_authentication],
|
||||
User,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
current_user = fastapi_users.current_user(active=True, verified=True)
|
||||
current_superuser = fastapi_users.current_user(active=True, superuser=True, verified=True)
|
||||
|
||||
|
||||
def on_after_register(user: UserDB, request: Request):
|
||||
print(f"User {user.id} has registered.")
|
||||
|
||||
|
||||
def on_after_forgot_password(user: UserDB, token: str, request: Request):
|
||||
print(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
|
||||
def after_verification_request(user: UserDB, token: str, request: Request):
|
||||
print(f"Verification requested for user {user.id}. Verification token: {token}")
|
||||
|
||||
|
||||
def add_user_routers(app):
|
||||
app.include_router(fastapi_users.get_auth_router(jwt_authentication),
|
||||
prefix="/auth/jwt",
|
||||
tags=["auth"])
|
||||
|
||||
app.include_router(fastapi_users.get_register_router(on_after_register),
|
||||
prefix="/auth",
|
||||
tags=["auth"])
|
||||
app.include_router(
|
||||
fastapi_users.get_reset_password_router(JWT_SECRET,
|
||||
after_forgot_password=on_after_forgot_password),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_verify_router(JWT_SECRET,
|
||||
after_verification_request=after_verification_request),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"])
|
|
@ -0,0 +1,73 @@
|
|||
import os
|
||||
from typing import Any, Generator
|
||||
|
||||
import pytest
|
||||
from src.db import Base, get_db
|
||||
from src.main import app as _app
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Default to using sqlite in memory for fast tests.
|
||||
# Can be overridden by environment variable for testing in CI against other
|
||||
# database engines
|
||||
SQLALCHEMY_DATABASE_URL = os.getenv('TEST_DATABASE_URL', "sqlite://")
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def app() -> Generator[FastAPI, Any, None]:
|
||||
"""
|
||||
Create a fresh database on each test case.
|
||||
"""
|
||||
Base.metadata.create_all(engine) # Create the tables.
|
||||
yield _app
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(app: FastAPI) -> Generator[Session, Any, None]:
|
||||
"""
|
||||
Creates a fresh sqlalchemy session for each test that operates in a
|
||||
transaction. The transaction is rolled back at the end of each test ensuring
|
||||
a clean state.
|
||||
"""
|
||||
|
||||
# connect to the database
|
||||
connection = engine.connect()
|
||||
# begin a non-ORM transaction
|
||||
transaction = connection.begin()
|
||||
# bind an individual Session to the connection
|
||||
session = Session(bind=connection)
|
||||
yield session # use the session in tests.
|
||||
session.close()
|
||||
# rollback - everything that happened with the
|
||||
# Session above (including calls to commit())
|
||||
# is rolled back.
|
||||
transaction.rollback()
|
||||
# return connection to the Engine
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app: FastAPI, db_session: Session) -> Generator[TestClient, Any, None]:
|
||||
"""
|
||||
Create a new FastAPI TestClient that uses the `db_session` fixture to override
|
||||
the `get_db` dependency that is injected into routes.
|
||||
"""
|
||||
|
||||
def _get_test_db():
|
||||
try:
|
||||
yield db_session
|
||||
finally:
|
||||
pass
|
||||
|
||||
app.dependency_overrides[get_db] = _get_test_db
|
||||
with TestClient(app) as client:
|
||||
yield client
|
|
@ -0,0 +1,11 @@
|
|||
from backend.src.db import DbSession
|
||||
from src.schemas import Session
|
||||
from fastapi import FastAPI
|
||||
from src.db import DbSession
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_create(app: FastAPI, db_session: DbSession, client: TestClient):
|
||||
pass
|
Loading…
Reference in New Issue