feat(backend): Add Generic Type Hint Support for Data Access Layer

This commit is contained in:
hay-kot 2021-08-27 20:27:20 -08:00
parent 0675c570ce
commit a266a244d9
4 changed files with 37 additions and 21 deletions

View file

@ -1,20 +1,30 @@
from typing import Callable, Union
from typing import Callable, Generic, TypeVar, Union
from mealie.core.root_logger import get_logger
from mealie.db.models._model_base import SqlAlchemyBase
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import load_only
from sqlalchemy.orm.session import Session
logger = get_logger()
T = TypeVar("T")
D = TypeVar("D")
class BaseAccessModel:
def __init__(self, primary_key, sql_model, schema) -> None:
self.primary_key: str = primary_key
self.sql_model: SqlAlchemyBase = sql_model
self.schema: BaseModel = schema
class BaseAccessModel(Generic[T, D]):
"""A Generic BaseAccess Model method to perform common operations on the database
Args:
Generic ([T]): Represents the Pydantic Model
Generic ([D]): Represents the SqlAlchemyModel Model
"""
def __init__(self, primary_key: Union[str, int], sql_model: D, schema: T) -> None:
self.primary_key = primary_key
self.sql_model = sql_model
self.schema = schema
self.observers: list = []
@ -29,7 +39,7 @@ class BaseAccessModel:
def get_all(
self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None
) -> list[dict]:
) -> list[T]:
eff_schema = override_schema or self.schema
if order_by:
@ -42,7 +52,7 @@ class BaseAccessModel:
return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()]
def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[SqlAlchemyBase]:
def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[D]:
"""Queries the database for the selected model. Restricts return responses to the
keys specified under "fields"
@ -70,7 +80,7 @@ class BaseAccessModel:
results_as_dict = [x.dict() for x in results]
return [x.get(self.primary_key) for x in results_as_dict]
def _query_one(self, session: Session, match_value: str, match_key: str = None) -> SqlAlchemyBase:
def _query_one(self, session: Session, match_value: str, match_key: str = None) -> D:
"""Query the sql database for one item an return the sql alchemy model
object. If no match key is provided the primary_key attribute will be used.
@ -89,7 +99,7 @@ class BaseAccessModel:
def get(
self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False, override_schema=None
) -> Union[BaseModel, list[BaseModel]]:
) -> Union[T, list[T]]:
"""Retrieves an entry from the database by matching a key/value pair. If no
key is provided the class objects primary key will be used to match against.
@ -121,9 +131,10 @@ class BaseAccessModel:
return eff_schema.from_orm(result[0])
except IndexError:
return None
return [eff_schema.from_orm(x) for x in result]
def create(self, session: Session, document: dict) -> BaseModel:
def create(self, session: Session, document: T) -> T:
"""Creates a new database entry for the given SQL Alchemy Model.
Args:
@ -134,17 +145,17 @@ class BaseAccessModel:
dict: A dictionary representation of the database entry
"""
document = document if isinstance(document, dict) else document.dict()
new_document = self.sql_model(session=session, **document)
session.add(new_document)
session.commit()
session.refresh(new_document)
if self.observers:
self.update_observers()
return self.schema.from_orm(new_document)
def update(self, session: Session, match_value: str, new_data: dict) -> BaseModel:
def update(self, session: Session, match_value: str, new_data: dict) -> T:
"""Update a database entry.
Args:
session (Session): Database Session
@ -165,7 +176,7 @@ class BaseAccessModel:
session.commit()
return self.schema.from_orm(entry)
def patch(self, session: Session, match_value: str, new_data: dict) -> BaseModel:
def patch(self, session: Session, match_value: str, new_data: dict) -> T:
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(session=session, match_value=match_value)
@ -178,7 +189,7 @@ class BaseAccessModel:
return self.update(session, match_value, entry_as_dict)
def delete(self, session: Session, primary_key_value) -> dict:
def delete(self, session: Session, primary_key_value) -> D:
result = session.query(self.sql_model).filter_by(**{self.primary_key: primary_key_value}).one()
results_as_model = self.schema.from_orm(result)
@ -205,7 +216,7 @@ class BaseAccessModel:
def _count_attribute(
self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None
) -> Union[int, BaseModel]:
) -> Union[int, T]:
eff_schema = override_schema or self.schema
# attr_filter = getattr(self.sql_model, attribute_name)

View file

@ -1,3 +1,4 @@
from mealie.db.models.group import Group
from mealie.schema.meal_plan.meal import MealPlanOut
from mealie.schema.user.user import GroupInDB
from sqlalchemy.orm.session import Session
@ -5,7 +6,7 @@ from sqlalchemy.orm.session import Session
from ._base_access_model import BaseAccessModel
class GroupDataAccessModel(BaseAccessModel):
class GroupDataAccessModel(BaseAccessModel[GroupInDB, Group]):
def get_meals(self, session: Session, match_value: str, match_key: str = "name") -> list[MealPlanOut]:
"""A Helper function to get the group from the database and return a sorted list of

View file

@ -2,12 +2,13 @@ from random import randint
from mealie.db.models.recipe.recipe import RecipeModel
from mealie.db.models.recipe.settings import RecipeSettings
from mealie.schema.recipe import Recipe
from sqlalchemy.orm.session import Session
from ._base_access_model import BaseAccessModel
class RecipeDataAccessModel(BaseAccessModel):
class RecipeDataAccessModel(BaseAccessModel[Recipe, RecipeModel]):
def get_all_public(self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None):
eff_schema = override_schema or self.schema

View file

@ -1,7 +1,10 @@
from mealie.db.models.users import User
from mealie.schema.user.user import UserInDB
from ._base_access_model import BaseAccessModel
class UserDataAccessModel(BaseAccessModel):
class UserDataAccessModel(BaseAccessModel[UserInDB, User]):
def update_password(self, session, id, password: str):
entry = self._query_one(session=session, match_value=id)
entry.update_password(password)