feat(backend): ✨ Add Generic Type Hint Support for Data Access Layer
This commit is contained in:
parent
0675c570ce
commit
a266a244d9
4 changed files with 37 additions and 21 deletions
|
@ -1,20 +1,30 @@
|
|||
from typing import Callable, Union
|
||||
from typing import Callable, Generic, TypeVar, Union
|
||||
|
||||
from mealie.core.root_logger import get_logger
|
||||
from mealie.db.models._model_base import SqlAlchemyBase
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import load_only
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
T = TypeVar("T")
|
||||
D = TypeVar("D")
|
||||
|
||||
class BaseAccessModel:
|
||||
def __init__(self, primary_key, sql_model, schema) -> None:
|
||||
self.primary_key: str = primary_key
|
||||
self.sql_model: SqlAlchemyBase = sql_model
|
||||
self.schema: BaseModel = schema
|
||||
|
||||
class BaseAccessModel(Generic[T, D]):
|
||||
"""A Generic BaseAccess Model method to perform common operations on the database
|
||||
|
||||
Args:
|
||||
Generic ([T]): Represents the Pydantic Model
|
||||
Generic ([D]): Represents the SqlAlchemyModel Model
|
||||
"""
|
||||
|
||||
def __init__(self, primary_key: Union[str, int], sql_model: D, schema: T) -> None:
|
||||
self.primary_key = primary_key
|
||||
|
||||
self.sql_model = sql_model
|
||||
|
||||
self.schema = schema
|
||||
|
||||
self.observers: list = []
|
||||
|
||||
|
@ -29,7 +39,7 @@ class BaseAccessModel:
|
|||
|
||||
def get_all(
|
||||
self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None
|
||||
) -> list[dict]:
|
||||
) -> list[T]:
|
||||
eff_schema = override_schema or self.schema
|
||||
|
||||
if order_by:
|
||||
|
@ -42,7 +52,7 @@ class BaseAccessModel:
|
|||
|
||||
return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()]
|
||||
|
||||
def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[SqlAlchemyBase]:
|
||||
def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[D]:
|
||||
"""Queries the database for the selected model. Restricts return responses to the
|
||||
keys specified under "fields"
|
||||
|
||||
|
@ -70,7 +80,7 @@ class BaseAccessModel:
|
|||
results_as_dict = [x.dict() for x in results]
|
||||
return [x.get(self.primary_key) for x in results_as_dict]
|
||||
|
||||
def _query_one(self, session: Session, match_value: str, match_key: str = None) -> SqlAlchemyBase:
|
||||
def _query_one(self, session: Session, match_value: str, match_key: str = None) -> D:
|
||||
"""Query the sql database for one item an return the sql alchemy model
|
||||
object. If no match key is provided the primary_key attribute will be used.
|
||||
|
||||
|
@ -89,7 +99,7 @@ class BaseAccessModel:
|
|||
|
||||
def get(
|
||||
self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False, override_schema=None
|
||||
) -> Union[BaseModel, list[BaseModel]]:
|
||||
) -> Union[T, list[T]]:
|
||||
"""Retrieves an entry from the database by matching a key/value pair. If no
|
||||
key is provided the class objects primary key will be used to match against.
|
||||
|
||||
|
@ -121,9 +131,10 @@ class BaseAccessModel:
|
|||
return eff_schema.from_orm(result[0])
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
return [eff_schema.from_orm(x) for x in result]
|
||||
|
||||
def create(self, session: Session, document: dict) -> BaseModel:
|
||||
def create(self, session: Session, document: T) -> T:
|
||||
"""Creates a new database entry for the given SQL Alchemy Model.
|
||||
|
||||
Args:
|
||||
|
@ -134,17 +145,17 @@ class BaseAccessModel:
|
|||
dict: A dictionary representation of the database entry
|
||||
"""
|
||||
document = document if isinstance(document, dict) else document.dict()
|
||||
|
||||
new_document = self.sql_model(session=session, **document)
|
||||
session.add(new_document)
|
||||
session.commit()
|
||||
session.refresh(new_document)
|
||||
|
||||
if self.observers:
|
||||
self.update_observers()
|
||||
|
||||
return self.schema.from_orm(new_document)
|
||||
|
||||
def update(self, session: Session, match_value: str, new_data: dict) -> BaseModel:
|
||||
def update(self, session: Session, match_value: str, new_data: dict) -> T:
|
||||
"""Update a database entry.
|
||||
Args:
|
||||
session (Session): Database Session
|
||||
|
@ -165,7 +176,7 @@ class BaseAccessModel:
|
|||
session.commit()
|
||||
return self.schema.from_orm(entry)
|
||||
|
||||
def patch(self, session: Session, match_value: str, new_data: dict) -> BaseModel:
|
||||
def patch(self, session: Session, match_value: str, new_data: dict) -> T:
|
||||
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
|
||||
|
||||
entry = self._query_one(session=session, match_value=match_value)
|
||||
|
@ -178,7 +189,7 @@ class BaseAccessModel:
|
|||
|
||||
return self.update(session, match_value, entry_as_dict)
|
||||
|
||||
def delete(self, session: Session, primary_key_value) -> dict:
|
||||
def delete(self, session: Session, primary_key_value) -> D:
|
||||
result = session.query(self.sql_model).filter_by(**{self.primary_key: primary_key_value}).one()
|
||||
results_as_model = self.schema.from_orm(result)
|
||||
|
||||
|
@ -205,7 +216,7 @@ class BaseAccessModel:
|
|||
|
||||
def _count_attribute(
|
||||
self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None
|
||||
) -> Union[int, BaseModel]:
|
||||
) -> Union[int, T]:
|
||||
eff_schema = override_schema or self.schema
|
||||
# attr_filter = getattr(self.sql_model, attribute_name)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from mealie.db.models.group import Group
|
||||
from mealie.schema.meal_plan.meal import MealPlanOut
|
||||
from mealie.schema.user.user import GroupInDB
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
@ -5,7 +6,7 @@ from sqlalchemy.orm.session import Session
|
|||
from ._base_access_model import BaseAccessModel
|
||||
|
||||
|
||||
class GroupDataAccessModel(BaseAccessModel):
|
||||
class GroupDataAccessModel(BaseAccessModel[GroupInDB, Group]):
|
||||
def get_meals(self, session: Session, match_value: str, match_key: str = "name") -> list[MealPlanOut]:
|
||||
"""A Helper function to get the group from the database and return a sorted list of
|
||||
|
||||
|
|
|
@ -2,12 +2,13 @@ from random import randint
|
|||
|
||||
from mealie.db.models.recipe.recipe import RecipeModel
|
||||
from mealie.db.models.recipe.settings import RecipeSettings
|
||||
from mealie.schema.recipe import Recipe
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from ._base_access_model import BaseAccessModel
|
||||
|
||||
|
||||
class RecipeDataAccessModel(BaseAccessModel):
|
||||
class RecipeDataAccessModel(BaseAccessModel[Recipe, RecipeModel]):
|
||||
def get_all_public(self, session: Session, limit: int = None, order_by: str = None, start=0, override_schema=None):
|
||||
eff_schema = override_schema or self.schema
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from mealie.db.models.users import User
|
||||
from mealie.schema.user.user import UserInDB
|
||||
|
||||
from ._base_access_model import BaseAccessModel
|
||||
|
||||
|
||||
class UserDataAccessModel(BaseAccessModel):
|
||||
class UserDataAccessModel(BaseAccessModel[UserInDB, User]):
|
||||
def update_password(self, session, id, password: str):
|
||||
entry = self._query_one(session=session, match_value=id)
|
||||
entry.update_password(password)
|
||||
|
|
Loading…
Reference in a new issue