refactor: remove depreciated repo call (#1370)

* ingredient parser hot fixes (float equality)

* remove `get` in favor of `get_one` & `multi_query`
This commit is contained in:
Hayden 2022-06-10 19:01:14 -08:00 committed by GitHub
parent b904b161eb
commit 932f4a72df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 25 additions and 44 deletions

View file

@ -13,7 +13,7 @@ from sqlalchemy.orm.session import Session
from mealie.core.config import get_app_dirs, get_app_settings
from mealie.db.db_setup import generate_session
from mealie.repos.all_repositories import get_repositories
from mealie.schema.user import LongLiveTokenInDB, PrivateUser, TokenData
from mealie.schema.user import PrivateUser, TokenData
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
oauth2_scheme_soft_fail = OAuth2PasswordBearer(tokenUrl="/api/auth/token", auto_error=False)
@ -76,7 +76,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme), session=Depends(
repos = get_repositories(session)
user = repos.users.get(token_data.user_id, "id", any_case=False)
user = repos.users.get_one(token_data.user_id, "id", any_case=False)
if user is None:
raise credentials_exception
@ -89,16 +89,15 @@ async def get_admin_user(current_user: PrivateUser = Depends(get_current_user))
return current_user
def validate_long_live_token(session: Session, client_token: str, id: int) -> PrivateUser:
def validate_long_live_token(session: Session, client_token: str, user_id: str) -> PrivateUser:
repos = get_repositories(session)
tokens: list[LongLiveTokenInDB] = repos.api_tokens.get(id, "user_id", limit=9999)
token = repos.api_tokens.multi_query({"token": client_token, "user_id": user_id})
for token in tokens:
if token.token == client_token:
return token.user
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Token")
try:
return token[0].user
except IndexError as e:
raise HTTPException(status.HTTP_401_UNAUTHORIZED) from e
def validate_file_token(token: Optional[str] = None) -> Path:

View file

@ -122,33 +122,6 @@ class RepositoryGeneric(Generic[Schema, Model]):
eff_schema = override_schema or self.schema
return eff_schema.from_orm(result)
def get(
self, match_value: str | int | UUID4, match_key: str = None, limit=1, any_case=False, override_schema=None
) -> Schema | list[Schema] | None:
self.logger.info("DEPRECATED: use get_one or get_all instead")
match_key = match_key or self.primary_key
if any_case:
search_attr = getattr(self.model, match_key)
result = (
self.session.query(self.model)
.filter(func.lower(search_attr) == match_value.lower()) # type: ignore
.limit(limit)
.all()
)
else:
result = self.session.query(self.model).filter_by(**{match_key: match_value}).limit(limit).all()
eff_schema = override_schema or self.schema
if limit == 1:
try:
return eff_schema.from_orm(result[0])
except IndexError:
return None
return [eff_schema.from_orm(x) for x in result]
def create(self, data: Schema | BaseModel | dict) -> Schema:
data = data if isinstance(data, dict) else data.dict()
new_document = self.model(session=self.session, **data) # type: ignore

View file

@ -42,7 +42,7 @@ class GroupSelfServiceController(BaseUserController):
def set_member_permissions(self, permissions: SetPermissions):
self.checks.can_manage()
target_user = self.repos.users.get(permissions.user_id)
target_user = self.repos.users.get_one(permissions.user_id)
if not target_user:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="User not found")

View file

@ -38,7 +38,7 @@ class UserApiTokensController(BaseUserController):
@router.delete("/api-tokens/{token_id}", response_model=DeleteTokenResponse)
def delete_api_token(self, token_id: int):
"""Delete api_token from the Database"""
token: LongLiveTokenInDB = self.repos.api_tokens.get(token_id)
token: LongLiveTokenInDB = self.repos.api_tokens.get_one(token_id)
if not token:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Could not locate token with id '{token_id}' in database")

View file

@ -13,7 +13,7 @@ class UserFavoritesController(BaseUserController):
@router.get("/{id}/favorites", response_model=UserFavorites)
async def get_favorites(self, id: UUID4):
"""Get user's favorite recipes"""
return self.repos.users.get(id, override_schema=UserFavorites)
return self.repos.users.get_one(id, override_schema=UserFavorites)
@router.post("/{id}/favorites/{slug}")
def add_favorite(self, id: UUID4, slug: str):

View file

@ -88,6 +88,15 @@ class IngredientConfidence(MealieModel):
quantity: NoneFloat = None
food: NoneFloat = None
@validator("quantity", pre=True)
@classmethod
def validate_quantity(cls, value, values) -> NoneFloat:
if isinstance(value, float):
return round(value, 3)
if value is None or value == "":
return None
return value
class ParsedIngredient(MealieModel):
input: Optional[str]

View file

@ -59,7 +59,7 @@ class RegistrationService:
if self.repos.users.get_by_username(registration.username):
raise HTTPException(status.HTTP_409_CONFLICT, {"message": self.t("exceptions.username-conflict-error")})
elif self.repos.users.get(registration.email, "email"):
elif self.repos.users.get_one(registration.email, "email"):
raise HTTPException(status.HTTP_409_CONFLICT, {"message": self.t("exceptions.email-conflict-error")})
self.logger.info(f"Registering user {registration.username}")

View file

@ -75,7 +75,7 @@ def list_with_items(database: AllRepositories, unique_user: TestUser):
)
# refresh model
list_model = database.group_shopping_lists.get(list_model.id)
list_model = database.group_shopping_lists.get_one(list_model.id)
yield list_model

View file

@ -12,7 +12,7 @@ class Routes:
def assert_ingredient(api_response: dict, test_ingredient: TestIngredient):
assert api_response["ingredient"]["quantity"] == test_ingredient.quantity
assert api_response["ingredient"]["quantity"] == pytest.approx(test_ingredient.quantity)
assert api_response["ingredient"]["unit"]["name"] == test_ingredient.unit
assert api_response["ingredient"]["food"]["name"] == test_ingredient.food
assert api_response["ingredient"]["note"] == test_ingredient.comments

View file

@ -31,7 +31,7 @@ test_ingredients = [
# Small Fraction Tests - PR #1369
# Reported error is was for 1/8 - new lowest expected threshold is 1/32
TestIngredient("1/8 cup all-purpose flour", 0.125, "cup", "all-purpose flour", ""),
TestIngredient("1/32 cup all-purpose flour", 0.03125, "cup", "all-purpose flour", ""),
TestIngredient("1/32 cup all-purpose flour", 0.031, "cup", "all-purpose flour", ""),
]
@ -41,7 +41,7 @@ def test_nlp_parser():
# Itterate over mdoels and test_ingreidnets to gether
for model, test_ingredient in zip(models, test_ingredients):
assert float(sum(Fraction(s) for s in model.qty.split())) == test_ingredient.quantity
assert round(float(sum(Fraction(s) for s in model.qty.split())), 3) == pytest.approx(test_ingredient.quantity)
assert model.comment == test_ingredient.comments
assert model.name == test_ingredient.food