refactor: remove depreciated repo call (#1370)
* ingredient parser hot fixes (float equality) * remove `get` in favor of `get_one` & `multi_query`
This commit is contained in:
parent
b904b161eb
commit
932f4a72df
10 changed files with 25 additions and 44 deletions
|
@ -13,7 +13,7 @@ from sqlalchemy.orm.session import Session
|
|||
from mealie.core.config import get_app_dirs, get_app_settings
|
||||
from mealie.db.db_setup import generate_session
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.schema.user import LongLiveTokenInDB, PrivateUser, TokenData
|
||||
from mealie.schema.user import PrivateUser, TokenData
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
|
||||
oauth2_scheme_soft_fail = OAuth2PasswordBearer(tokenUrl="/api/auth/token", auto_error=False)
|
||||
|
@ -76,7 +76,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme), session=Depends(
|
|||
|
||||
repos = get_repositories(session)
|
||||
|
||||
user = repos.users.get(token_data.user_id, "id", any_case=False)
|
||||
user = repos.users.get_one(token_data.user_id, "id", any_case=False)
|
||||
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
@ -89,16 +89,15 @@ async def get_admin_user(current_user: PrivateUser = Depends(get_current_user))
|
|||
return current_user
|
||||
|
||||
|
||||
def validate_long_live_token(session: Session, client_token: str, id: int) -> PrivateUser:
|
||||
def validate_long_live_token(session: Session, client_token: str, user_id: str) -> PrivateUser:
|
||||
repos = get_repositories(session)
|
||||
|
||||
tokens: list[LongLiveTokenInDB] = repos.api_tokens.get(id, "user_id", limit=9999)
|
||||
token = repos.api_tokens.multi_query({"token": client_token, "user_id": user_id})
|
||||
|
||||
for token in tokens:
|
||||
if token.token == client_token:
|
||||
return token.user
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Token")
|
||||
try:
|
||||
return token[0].user
|
||||
except IndexError as e:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED) from e
|
||||
|
||||
|
||||
def validate_file_token(token: Optional[str] = None) -> Path:
|
||||
|
|
|
@ -122,33 +122,6 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
|||
eff_schema = override_schema or self.schema
|
||||
return eff_schema.from_orm(result)
|
||||
|
||||
def get(
|
||||
self, match_value: str | int | UUID4, match_key: str = None, limit=1, any_case=False, override_schema=None
|
||||
) -> Schema | list[Schema] | None:
|
||||
self.logger.info("DEPRECATED: use get_one or get_all instead")
|
||||
match_key = match_key or self.primary_key
|
||||
|
||||
if any_case:
|
||||
search_attr = getattr(self.model, match_key)
|
||||
result = (
|
||||
self.session.query(self.model)
|
||||
.filter(func.lower(search_attr) == match_value.lower()) # type: ignore
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
result = self.session.query(self.model).filter_by(**{match_key: match_value}).limit(limit).all()
|
||||
|
||||
eff_schema = override_schema or self.schema
|
||||
|
||||
if limit == 1:
|
||||
try:
|
||||
return eff_schema.from_orm(result[0])
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
return [eff_schema.from_orm(x) for x in result]
|
||||
|
||||
def create(self, data: Schema | BaseModel | dict) -> Schema:
|
||||
data = data if isinstance(data, dict) else data.dict()
|
||||
new_document = self.model(session=self.session, **data) # type: ignore
|
||||
|
|
|
@ -42,7 +42,7 @@ class GroupSelfServiceController(BaseUserController):
|
|||
def set_member_permissions(self, permissions: SetPermissions):
|
||||
self.checks.can_manage()
|
||||
|
||||
target_user = self.repos.users.get(permissions.user_id)
|
||||
target_user = self.repos.users.get_one(permissions.user_id)
|
||||
|
||||
if not target_user:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
|
|
@ -38,7 +38,7 @@ class UserApiTokensController(BaseUserController):
|
|||
@router.delete("/api-tokens/{token_id}", response_model=DeleteTokenResponse)
|
||||
def delete_api_token(self, token_id: int):
|
||||
"""Delete api_token from the Database"""
|
||||
token: LongLiveTokenInDB = self.repos.api_tokens.get(token_id)
|
||||
token: LongLiveTokenInDB = self.repos.api_tokens.get_one(token_id)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Could not locate token with id '{token_id}' in database")
|
||||
|
|
|
@ -13,7 +13,7 @@ class UserFavoritesController(BaseUserController):
|
|||
@router.get("/{id}/favorites", response_model=UserFavorites)
|
||||
async def get_favorites(self, id: UUID4):
|
||||
"""Get user's favorite recipes"""
|
||||
return self.repos.users.get(id, override_schema=UserFavorites)
|
||||
return self.repos.users.get_one(id, override_schema=UserFavorites)
|
||||
|
||||
@router.post("/{id}/favorites/{slug}")
|
||||
def add_favorite(self, id: UUID4, slug: str):
|
||||
|
|
|
@ -88,6 +88,15 @@ class IngredientConfidence(MealieModel):
|
|||
quantity: NoneFloat = None
|
||||
food: NoneFloat = None
|
||||
|
||||
@validator("quantity", pre=True)
|
||||
@classmethod
|
||||
def validate_quantity(cls, value, values) -> NoneFloat:
|
||||
if isinstance(value, float):
|
||||
return round(value, 3)
|
||||
if value is None or value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
class ParsedIngredient(MealieModel):
|
||||
input: Optional[str]
|
||||
|
|
|
@ -59,7 +59,7 @@ class RegistrationService:
|
|||
|
||||
if self.repos.users.get_by_username(registration.username):
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, {"message": self.t("exceptions.username-conflict-error")})
|
||||
elif self.repos.users.get(registration.email, "email"):
|
||||
elif self.repos.users.get_one(registration.email, "email"):
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, {"message": self.t("exceptions.email-conflict-error")})
|
||||
|
||||
self.logger.info(f"Registering user {registration.username}")
|
||||
|
|
2
tests/fixtures/fixture_shopping_lists.py
vendored
2
tests/fixtures/fixture_shopping_lists.py
vendored
|
@ -75,7 +75,7 @@ def list_with_items(database: AllRepositories, unique_user: TestUser):
|
|||
)
|
||||
|
||||
# refresh model
|
||||
list_model = database.group_shopping_lists.get(list_model.id)
|
||||
list_model = database.group_shopping_lists.get_one(list_model.id)
|
||||
|
||||
yield list_model
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ class Routes:
|
|||
|
||||
|
||||
def assert_ingredient(api_response: dict, test_ingredient: TestIngredient):
|
||||
assert api_response["ingredient"]["quantity"] == test_ingredient.quantity
|
||||
assert api_response["ingredient"]["quantity"] == pytest.approx(test_ingredient.quantity)
|
||||
assert api_response["ingredient"]["unit"]["name"] == test_ingredient.unit
|
||||
assert api_response["ingredient"]["food"]["name"] == test_ingredient.food
|
||||
assert api_response["ingredient"]["note"] == test_ingredient.comments
|
||||
|
|
|
@ -31,7 +31,7 @@ test_ingredients = [
|
|||
# Small Fraction Tests - PR #1369
|
||||
# Reported error is was for 1/8 - new lowest expected threshold is 1/32
|
||||
TestIngredient("1/8 cup all-purpose flour", 0.125, "cup", "all-purpose flour", ""),
|
||||
TestIngredient("1/32 cup all-purpose flour", 0.03125, "cup", "all-purpose flour", ""),
|
||||
TestIngredient("1/32 cup all-purpose flour", 0.031, "cup", "all-purpose flour", ""),
|
||||
]
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ def test_nlp_parser():
|
|||
|
||||
# Itterate over mdoels and test_ingreidnets to gether
|
||||
for model, test_ingredient in zip(models, test_ingredients):
|
||||
assert float(sum(Fraction(s) for s in model.qty.split())) == test_ingredient.quantity
|
||||
assert round(float(sum(Fraction(s) for s in model.qty.split())), 3) == pytest.approx(test_ingredient.quantity)
|
||||
|
||||
assert model.comment == test_ingredient.comments
|
||||
assert model.name == test_ingredient.food
|
||||
|
|
Loading…
Reference in a new issue