diff --git a/api/src/api/route.py b/api/src/api/route.py index 6347fcd..40edd8b 100644 --- a/api/src/api/route.py +++ b/api/src/api/route.py @@ -36,15 +36,15 @@ class Route(RouteUserExtension): @classmethod async def _get_auth_type( - cls, request: Request, auth_header: str + cls, request: Request, auth_header: str ) -> Optional[Union[User, ApiKey]]: if auth_header.startswith("Bearer "): return await cls.get_user() elif auth_header.startswith("API-Key "): return await cls.get_api_key(request) elif ( - auth_header.startswith("DEV-User ") - and Environment.get_environment() == "development" + auth_header.startswith("DEV-User ") + and Environment.get_environment() == "development" ): return await cls.get_dev_user() return None @@ -66,7 +66,7 @@ class Route(RouteUserExtension): @classmethod async def get_authenticated_user_or_api_key_or_default( - cls, + cls, ) -> Optional[Union[User, ApiKey]]: request = get_request() if request is None: @@ -93,8 +93,8 @@ class Route(RouteUserExtension): elif auth_header.startswith("API-Key "): return await cls._verify_api_key(request) elif ( - auth_header.startswith("DEV-User ") - and Environment.get_environment() == "development" + auth_header.startswith("DEV-User ") + and Environment.get_environment() == "development" ): user = await cls.get_dev_user() return user is not None @@ -102,10 +102,10 @@ class Route(RouteUserExtension): @classmethod def authorize( - cls, - f: Callable = None, - skip_in_dev=False, - by_api_key=False, + cls, + f: Callable = None, + skip_in_dev=False, + by_api_key=False, ): if f is None: return functools.partial( diff --git a/api/src/api_graphql/abc/mutation_abc.py b/api/src/api_graphql/abc/mutation_abc.py index e47dfa3..5768935 100644 --- a/api/src/api_graphql/abc/mutation_abc.py +++ b/api/src/api_graphql/abc/mutation_abc.py @@ -13,11 +13,11 @@ class MutationABC(QueryABC): QueryABC.__init__(self, f"{name}Mutation") def add_mutation_type( - self, - name: str, - mutation_name: str, - require_any_permission=None, - public: bool = False, + self, + name: str, + mutation_name: str, + require_any_permission=None, + public: bool = False, ): """ Add mutation type (sub mutation) to the mutation object diff --git a/api/src/api_graphql/abc/query_abc.py b/api/src/api_graphql/abc/query_abc.py index 9c74ee4..bafce27 100644 --- a/api/src/api_graphql/abc/query_abc.py +++ b/api/src/api_graphql/abc/query_abc.py @@ -64,19 +64,19 @@ class QueryABC(ObjectType): @classmethod async def _require_any( - cls, - data: Any, - permissions: TRequireAnyPermissions, - resolvers: TRequireAnyResolvers, - *args, - **kwargs, + cls, + data: Any, + permissions: TRequireAnyPermissions, + resolvers: TRequireAnyResolvers, + *args, + **kwargs, ): info = args[0] if len(permissions) > 0: user = await Route.get_authenticated_user_or_api_key_or_default() if user is not None and all( - [await user.has_permission(x) for x in permissions] + [await user.has_permission(x) for x in permissions] ): return @@ -97,13 +97,13 @@ class QueryABC(ObjectType): raise AccessDenied() def field( - self, - builder: Union[ - DaoFieldBuilder, - CollectionFieldBuilder, - ResolverFieldBuilder, - MutationFieldBuilder, - ], + self, + builder: Union[ + DaoFieldBuilder, + CollectionFieldBuilder, + ResolverFieldBuilder, + MutationFieldBuilder, + ], ): """ Add a field to the query @@ -199,7 +199,7 @@ class QueryABC(ObjectType): elif isinstance(field, MutationField): async def input_wrapper( - mutation: QueryABC, info: GraphQLResolveInfo, **kwargs + mutation: QueryABC, info: GraphQLResolveInfo, **kwargs ): if field.input_type is None: return await resolver_wrapper(mutation, info, **kwargs) @@ -230,9 +230,9 @@ class QueryABC(ObjectType): await self._authorize() if ( - field.require_any is None - and not field.public - and field.require_any_permission + field.require_any is None + and not field.public + and field.require_any_permission ): await self._require_any_permission(field.require_any_permission) @@ -252,13 +252,13 @@ class QueryABC(ObjectType): @deprecated("Use field(FieldBuilder()) instead") def mutation( - self, - name: str, - f: Callable, - input_type: Type[InputABC] = None, - input_key: str = "input", - require_any_permission: list[Permissions] = None, - public: bool = False, + self, + name: str, + f: Callable, + input_type: Type[InputABC] = None, + input_key: str = "input", + require_any_permission: list[Permissions] = None, + public: bool = False, ): """ Adds a mutation to the query @@ -284,13 +284,13 @@ class QueryABC(ObjectType): @classmethod def _resolve_collection( - cls, - collection: list, - *_, - filters: list[CollectionFilterABC] = None, - sort: list[Sort] = None, - skip: int = None, - take: int = None, + cls, + collection: list, + *_, + filters: list[CollectionFilterABC] = None, + sort: list[Sort] = None, + skip: int = None, + take: int = None, ) -> CollectionResult: total_count = len(collection) @@ -313,7 +313,7 @@ class QueryABC(ObjectType): return attr for s in reversed( - sort + sort ): # Apply sorting in reverse order to make first primary "orderBy" and other secondary "thenBy" attrs = [a for a in dir(s) if not a.startswith("_")] for k in attrs: diff --git a/api/src/api_graphql/queries/group_query.py b/api/src/api_graphql/queries/group_query.py index 78b9c56..1e28ac6 100644 --- a/api/src/api_graphql/queries/group_query.py +++ b/api/src/api_graphql/queries/group_query.py @@ -16,9 +16,12 @@ class GroupQuery(DbModelQueryABC): self.field( ResolverFieldBuilder("shortUrls") .with_resolver(self._get_urls) - .with_require_any([ - Permissions.groups, - ], [group_by_assignment_resolver]) + .with_require_any( + [ + Permissions.groups, + ], + [group_by_assignment_resolver], + ) ) self.set_field("roles", self._get_roles) diff --git a/api/src/api_graphql/query.py b/api/src/api_graphql/query.py index cf16bdd..b3c01a1 100644 --- a/api/src/api_graphql/query.py +++ b/api/src/api_graphql/query.py @@ -120,7 +120,7 @@ class Query(QueryABC): Permissions.short_urls_create, Permissions.short_urls_update, ], - [group_by_assignment_resolver] + [group_by_assignment_resolver], ) ) self.field( diff --git a/api/src/api_graphql/require_any_resolvers.py b/api/src/api_graphql/require_any_resolvers.py index 81a98e4..0d31a90 100644 --- a/api/src/api_graphql/require_any_resolvers.py +++ b/api/src/api_graphql/require_any_resolvers.py @@ -12,11 +12,19 @@ async def group_by_assignment_resolver(ctx: QueryContext) -> bool: groups = [await x.group for x in ctx.data.nodes] role_ids = {x.id for x in await ctx.user.roles} filtered_groups = [ - g.id for g in groups if - g is not None and (roles := await groupDao.get_roles(g.id)) and all(r.id in role_ids for r in roles) + g.id + for g in groups + if g is not None + and (roles := await groupDao.get_roles(g.id)) + and all(r.id in role_ids for r in roles) ] - ctx.data.nodes = [node for node in ctx.data.nodes if (await node.group) is not None and (await node.group).id in filtered_groups] + ctx.data.nodes = [ + node + for node in ctx.data.nodes + if (await node.group) is not None + and (await node.group).id in filtered_groups + ] return True return True diff --git a/api/src/api_graphql/typing.py b/api/src/api_graphql/typing.py index bfc685f..61f7a7d 100644 --- a/api/src/api_graphql/typing.py +++ b/api/src/api_graphql/typing.py @@ -9,7 +9,7 @@ TRequireAnyResolvers = list[ Union[ Callable[[QueryContext], bool], Awaitable[[QueryContext], bool], - Callable[[QueryContext], Coroutine[Any, Any, bool]] + Callable[[QueryContext], Coroutine[Any, Any, bool]], ] ] TRequireAny = tuple[TRequireAnyPermissions, TRequireAnyResolvers] diff --git a/api/src/core/database/abc/data_access_object_abc.py b/api/src/core/database/abc/data_access_object_abc.py index 0a1f4dd..9c510bc 100644 --- a/api/src/core/database/abc/data_access_object_abc.py +++ b/api/src/core/database/abc/data_access_object_abc.py @@ -42,12 +42,12 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return self._table_name def attribute( - self, - attr_name: Attribute, - attr_type: type, - db_name: str = None, - ignore=False, - primary_key=False, + self, + attr_name: Attribute, + attr_type: type, + db_name: str = None, + ignore=False, + primary_key=False, ): """ Add an attribute for db and object mapping to the data access object @@ -77,11 +77,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): self.__date_attributes.add(db_name) def reference( - self, - attr: Attribute, - primary_attr: Attribute, - foreign_attr: Attribute, - table_name: str, + self, + attr: Attribute, + primary_attr: Attribute, + foreign_attr: Attribute, + table_name: str, ): """ Add a reference to another table for the given attribute @@ -164,11 +164,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return self.to_object(result[0]) async def get_by( - self, - filters: AttributeFilters = None, - sorts: AttributeSorts = None, - take: int = None, - skip: int = None, + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, ) -> list[T_DBM]: """ Get all objects by the given filters @@ -189,11 +189,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return [self.to_object(x) for x in result] async def get_single_by( - self, - filters: AttributeFilters = None, - sorts: AttributeSorts = None, - take: int = None, - skip: int = None, + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, ) -> T_DBM: """ Get a single object by the given filters @@ -214,11 +214,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return result[0] async def find_by( - self, - filters: AttributeFilters = None, - sorts: AttributeSorts = None, - take: int = None, - skip: int = None, + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, ) -> list[Optional[T_DBM]]: """ Find all objects by the given filters @@ -238,11 +238,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return [self.to_object(x) for x in result] async def find_single_by( - self, - filters: AttributeFilters = None, - sorts: AttributeSorts = None, - take: int = None, - skip: int = None, + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, ) -> Optional[T_DBM]: """ Find a single object by the given filters @@ -342,7 +342,7 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): await self._db.execute(query) async def _build_delete_statement( - self, obj: T_DBM, hard_delete: bool = False + self, obj: T_DBM, hard_delete: bool = False ) -> str: if hard_delete: return f""" @@ -458,11 +458,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return cast_type(value) def _build_conditional_query( - self, - filters: AttributeFilters = None, - sorts: AttributeSorts = None, - take: int = None, - skip: int = None, + self, + filters: AttributeFilters = None, + sorts: AttributeSorts = None, + take: int = None, + skip: int = None, ) -> str: query = f"SELECT {self._table_name}.* FROM {self._table_name}" @@ -506,7 +506,11 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): " OR ".join( self._build_fuzzy_conditions( [ - self.__db_names[x] if x in self.__db_names else self.__db_names[camel_to_snake(x)] + ( + self.__db_names[x] + if x in self.__db_names + else self.__db_names[camel_to_snake(x)] + ) for x in get_value(values, "fields", list[str]) ], get_value(values, "term", str), @@ -546,7 +550,7 @@ class DataAccessObjectABC(ABC, Database, Generic[T_DBM]): return " AND ".join(conditions) def _build_fuzzy_conditions( - self, fields: list[str], term: str, threshold: int = 10 + self, fields: list[str], term: str, threshold: int = 10 ) -> list[str]: conditions = [] for field in fields: diff --git a/api/src/core/get_value.py b/api/src/core/get_value.py index 6acec5a..e0022d3 100644 --- a/api/src/core/get_value.py +++ b/api/src/core/get_value.py @@ -4,11 +4,11 @@ from core.typing import T def get_value( - source: dict, - key: str, - cast_type: Type[T], - default: Optional[T] = None, - list_delimiter: str = ",", + source: dict, + key: str, + cast_type: Type[T], + default: Optional[T] = None, + list_delimiter: str = ",", ) -> Optional[T]: """ Get value from source dictionary and cast it to a specified type. @@ -26,8 +26,8 @@ def get_value( value = source[key] if isinstance( - value, - cast_type if not hasattr(cast_type, "__origin__") else cast_type.__origin__, + value, + cast_type if not hasattr(cast_type, "__origin__") else cast_type.__origin__, ): return value @@ -36,10 +36,15 @@ def get_value( return value.lower() in ["true", "1"] if ( - cast_type if not hasattr(cast_type, "__origin__") else cast_type.__origin__ + cast_type if not hasattr(cast_type, "__origin__") else cast_type.__origin__ ) == list: - if not (value.startswith("[") and value.endswith("]")) and list_delimiter not in value: - raise ValueError("List values must be enclosed in square brackets or use a delimiter.") + if ( + not (value.startswith("[") and value.endswith("]")) + and list_delimiter not in value + ): + raise ValueError( + "List values must be enclosed in square brackets or use a delimiter." + ) if value.startswith("[") and value.endswith("]"): value = value[1:-1] diff --git a/api/src/core/string.py b/api/src/core/string.py index 6adbe99..91fa109 100644 --- a/api/src/core/string.py +++ b/api/src/core/string.py @@ -4,5 +4,6 @@ import re def first_to_lower(s: str) -> str: return s[0].lower() + s[1:] if s else s + def camel_to_snake(s: str) -> str: - return re.sub(r'(?