Added graphql from prototype #162

This commit is contained in:
Sven Heidemann 2023-01-15 02:28:28 +01:00
parent b95a951a1b
commit 95b9eea236
17 changed files with 134 additions and 181 deletions

View File

@ -16,7 +16,6 @@ from modules.technician.technician_module import TechnicianModule
class ModuleList: class ModuleList:
@staticmethod @staticmethod
def get_modules(): def get_modules():
# core modules (modules out of modules folder) should be loaded first! # core modules (modules out of modules folder) should be loaded first!

View File

@ -1,4 +1,3 @@
import re
import sys import sys
import textwrap import textwrap
import uuid import uuid
@ -16,29 +15,25 @@ from werkzeug.exceptions import NotFound
from bot_api.configuration.api_settings import ApiSettings from bot_api.configuration.api_settings import ApiSettings
from bot_api.configuration.authentication_settings import AuthenticationSettings from bot_api.configuration.authentication_settings import AuthenticationSettings
from bot_api.configuration.frontend_settings import FrontendSettings
from bot_api.exception.service_error_code_enum import ServiceErrorCode from bot_api.exception.service_error_code_enum import ServiceErrorCode
from bot_api.exception.service_exception import ServiceException from bot_api.exception.service_exception import ServiceException
from bot_api.logging.api_logger import ApiLogger from bot_api.logging.api_logger import ApiLogger
from bot_api.model.error_dto import ErrorDTO from bot_api.model.error_dto import ErrorDTO
from bot_api.route.route import Route from bot_api.route.route import Route
from bot_graphql.graphql_service import GraphQLService
class Api(Flask): class Api(Flask):
def __init__( def __init__(
self, self,
logger: ApiLogger, logger: ApiLogger,
services: ServiceProviderABC, services: ServiceProviderABC,
api_settings: ApiSettings, api_settings: ApiSettings,
frontend_settings: FrontendSettings,
auth_settings: AuthenticationSettings, auth_settings: AuthenticationSettings,
graphql: GraphQLService, *args,
*args, **kwargs **kwargs,
): ):
if not args: if not args:
kwargs.setdefault('import_name', __name__) kwargs.setdefault("import_name", __name__)
Flask.__init__(self, *args, **kwargs) Flask.__init__(self, *args, **kwargs)
@ -58,17 +53,21 @@ class Api(Flask):
self.register_error_handler(exc_class, self.handle_exception) self.register_error_handler(exc_class, self.handle_exception)
# websockets # websockets
self._socketio = SocketIO(self, cors_allowed_origins='*', path='/api/socket.io') self._socketio = SocketIO(self, cors_allowed_origins="*", path="/api/socket.io")
self._socketio.on_event('connect', self.on_connect) self._socketio.on_event("connect", self.on_connect)
self._socketio.on_event('disconnect', self.on_disconnect) self._socketio.on_event("disconnect", self.on_disconnect)
self._requests = {} self._requests = {}
@staticmethod @staticmethod
def _get_methods_from_registered_route() -> Union[list[str], str]: def _get_methods_from_registered_route() -> Union[list[str], str]:
methods = ['Unknown'] methods = ["Unknown"]
if request.path in Route.registered_routes and len(Route.registered_routes[request.path]) >= 1 and 'methods' in Route.registered_routes[request.path][1]: if (
methods = Route.registered_routes[request.path][1]['methods'] request.path in Route.registered_routes
and len(Route.registered_routes[request.path]) >= 1
and "methods" in Route.registered_routes[request.path][1]
):
methods = Route.registered_routes[request.path][1]["methods"]
if len(methods) == 1: if len(methods) == 1:
return methods[0] return methods[0]
@ -79,7 +78,7 @@ class Api(Flask):
route = f[0] route = f[0]
kwargs = f[1] kwargs = f[1]
cls = None cls = None
qual_name_split = route.__qualname__.split('.') qual_name_split = route.__qualname__.split(".")
if len(qual_name_split) > 0: if len(qual_name_split) > 0:
cls_type = vars(sys.modules[route.__module__])[qual_name_split[0]] cls_type = vars(sys.modules[route.__module__])[qual_name_split[0]]
cls = self._services.get_service(cls_type) cls = self._services.get_service(cls_type)
@ -89,7 +88,7 @@ class Api(Flask):
self.route(path, **kwargs)(partial_f) self.route(path, **kwargs)(partial_f)
def handle_exception(self, e: Exception): def handle_exception(self, e: Exception):
self._logger.error(__name__, f'Caught error', e) self._logger.error(__name__, f"Caught error", e)
if isinstance(e, ServiceException): if isinstance(e, ServiceException):
ex: ServiceException = e ex: ServiceException = e
@ -102,7 +101,7 @@ class Api(Flask):
return jsonify(error.to_dict()), 404 return jsonify(error.to_dict()), 404
else: else:
tracking_id = uuid.uuid4() tracking_id = uuid.uuid4()
user_message = f'Tracking Id: {tracking_id}' user_message = f"Tracking Id: {tracking_id}"
self._logger.error(__name__, user_message, e) self._logger.error(__name__, user_message, e)
error = ErrorDTO(None, user_message) error = ErrorDTO(None, user_message)
return jsonify(error.to_dict()), 400 return jsonify(error.to_dict()), 400
@ -112,47 +111,48 @@ class Api(Flask):
self._requests[request] = request_id self._requests[request] = request_id
method = request.access_control_request_method method = request.access_control_request_method
self._logger.info(__name__, f'Received {request_id} @ {self._get_methods_from_registered_route() if method is None else method} {request.url} from {request.remote_addr}') self._logger.info(
__name__,
f"Received {request_id} @ {self._get_methods_from_registered_route() if method is None else method} {request.url} from {request.remote_addr}",
)
headers = str(request.headers).replace('\n', '\n\t\t') headers = str(request.headers).replace("\n", "\n\t\t")
data = request.get_data() data = request.get_data()
data = '' if len(data) == 0 else str(data.decode(encoding="utf-8")) data = "" if len(data) == 0 else str(data.decode(encoding="utf-8"))
text = textwrap.dedent(f'Request: {request_id}:\n\tHeader:\n\t\t{headers}\n\tUser-Agent: {request.user_agent.string}\n\tBody: {data}') text = textwrap.dedent(
f"Request: {request_id}:\n\tHeader:\n\t\t{headers}\n\tUser-Agent: {request.user_agent.string}\n\tBody: {data}"
)
self._logger.trace(__name__, text) self._logger.trace(__name__, text)
def after_request_hook(self, response: Response): def after_request_hook(self, response: Response):
method = request.access_control_request_method method = request.access_control_request_method
request_id = f'{self._get_methods_from_registered_route() if method is None else method} {request.url} from {request.remote_addr}' request_id = f"{self._get_methods_from_registered_route() if method is None else method} {request.url} from {request.remote_addr}"
if request in self._requests: if request in self._requests:
request_id = self._requests[request] request_id = self._requests[request]
self._logger.info(__name__, f'Answered {request_id}') self._logger.info(__name__, f"Answered {request_id}")
headers = str(request.headers).replace('\n', '\n\t\t') headers = str(request.headers).replace("\n", "\n\t\t")
data = request.get_data() data = request.get_data()
data = '' if len(data) == 0 else str(data.decode(encoding="utf-8")) data = "" if len(data) == 0 else str(data.decode(encoding="utf-8"))
text = textwrap.dedent(f'Request: {request_id}:\n\tHeader:\n\t\t{headers}\n\tResponse: {data}') text = textwrap.dedent(f"Request: {request_id}:\n\tHeader:\n\t\t{headers}\n\tResponse: {data}")
self._logger.trace(__name__, text) self._logger.trace(__name__, text)
return response return response
def start(self): def start(self):
self._logger.info(__name__, f'Starting API {self._api_settings.host}:{self._api_settings.port}') self._logger.info(__name__, f"Starting API {self._api_settings.host}:{self._api_settings.port}")
self._register_routes() self._register_routes()
self.secret_key = CredentialManager.decrypt(self._auth_settings.secret_key) self.secret_key = CredentialManager.decrypt(self._auth_settings.secret_key)
# from waitress import serve # from waitress import serve
# https://docs.pylonsproject.org/projects/waitress/en/stable/arguments.html # https://docs.pylonsproject.org/projects/waitress/en/stable/arguments.html
# serve(self, host=self._apt_settings.host, port=self._apt_settings.port, threads=10, connection_limit=1000, channel_timeout=10) # serve(self, host=self._apt_settings.host, port=self._apt_settings.port, threads=10, connection_limit=1000, channel_timeout=10)
wsgi.server( wsgi.server(eventlet.listen((self._api_settings.host, self._api_settings.port)), self, log_output=False)
eventlet.listen((self._api_settings.host, self._api_settings.port)),
self,
log_output=False
)
def on_connect(self): def on_connect(self):
self._logger.info(__name__, f'Client connected') self._logger.info(__name__, f"Client connected")
def on_disconnect(self): def on_disconnect(self):
self._logger.info(__name__, f'Client disconnected') self._logger.info(__name__, f"Client disconnected")

View File

@ -11,7 +11,7 @@ from bot_graphql.schema import Schema
class GraphQLController: class GraphQLController:
BasePath = f'/api/graphql' BasePath = f"/api/graphql"
def __init__( def __init__(
self, self,
@ -25,20 +25,16 @@ class GraphQLController:
self._logger = logger self._logger = logger
self._schema = schema self._schema = schema
@Route.get(f'{BasePath}/playground') @Route.get(f"{BasePath}/playground")
async def playground(self): async def playground(self):
return PLAYGROUND_HTML, 200 return PLAYGROUND_HTML, 200
@Route.post(f'{BasePath}') @Route.post(f"{BasePath}")
async def graphql(self): async def graphql(self):
data = request.get_json() data = request.get_json()
# Note: Passing the request to the context is optional. # Note: Passing the request to the context is optional.
# In Flask, the current request is always accessible as flask.request # In Flask, the current request is always accessible as flask.request
success, result = graphql_sync( success, result = graphql_sync(self._schema.schema, data, context_value=request)
self._schema.schema,
data,
context_value=request
)
return jsonify(result), 200 if success else 400 return jsonify(result), 200 if success else 400

View File

@ -28,8 +28,9 @@ class ClientRepositoryService(ClientRepositoryABC):
self._logger.trace(__name__, f"Send SQL command: {Client.get_select_all_string()}") self._logger.trace(__name__, f"Send SQL command: {Client.get_select_all_string()}")
results = self._context.select(Client.get_select_all_string()) results = self._context.select(Client.get_select_all_string())
for result in results: for result in results:
self._logger.trace(__name__, f'Get client with id {result[0]}') self._logger.trace(__name__, f"Get client with id {result[0]}")
clients.append(Client( clients.append(
Client(
result[1], result[1],
result[2], result[2],
result[3], result[3],
@ -39,8 +40,9 @@ class ClientRepositoryService(ClientRepositoryABC):
self._servers.get_server_by_id(result[7]), self._servers.get_server_by_id(result[7]),
result[8], result[8],
result[9], result[9],
id=result[0] id=result[0],
)) )
)
return clients return clients
@ -57,7 +59,7 @@ class ClientRepositoryService(ClientRepositoryABC):
self._servers.get_server_by_id(result[7]), self._servers.get_server_by_id(result[7]),
result[8], result[8],
result[9], result[9],
id=result[0] id=result[0],
) )
def get_client_by_discord_id(self, discord_id: int) -> Client: def get_client_by_discord_id(self, discord_id: int) -> Client:
@ -76,7 +78,7 @@ class ClientRepositoryService(ClientRepositoryABC):
self._servers.get_server_by_id(result[7]), self._servers.get_server_by_id(result[7]),
result[8], result[8],
result[9], result[9],
id=result[0] id=result[0],
) )
def find_client_by_discord_id(self, discord_id: int) -> Optional[Client]: def find_client_by_discord_id(self, discord_id: int) -> Optional[Client]:
@ -100,7 +102,7 @@ class ClientRepositoryService(ClientRepositoryABC):
self._servers.get_server_by_id(result[7]), self._servers.get_server_by_id(result[7]),
result[8], result[8],
result[9], result[9],
id=result[0] id=result[0],
) )
def find_client_by_server_id(self, discord_id: int) -> Optional[Client]: def find_client_by_server_id(self, discord_id: int) -> Optional[Client]:
@ -124,7 +126,7 @@ class ClientRepositoryService(ClientRepositoryABC):
self._servers.get_server_by_id(result[7]), self._servers.get_server_by_id(result[7]),
result[8], result[8],
result[9], result[9],
id=result[0] id=result[0],
) )
def find_client_by_discord_id_and_server_id(self, discord_id: int, server_id: int) -> Optional[Client]: def find_client_by_discord_id_and_server_id(self, discord_id: int, server_id: int) -> Optional[Client]:
@ -148,7 +150,7 @@ class ClientRepositoryService(ClientRepositoryABC):
self._servers.get_server_by_id(result[7]), self._servers.get_server_by_id(result[7]),
result[8], result[8],
result[9], result[9],
id=result[0] id=result[0],
) )
def add_client(self, client: Client): def add_client(self, client: Client):
@ -166,13 +168,13 @@ class ClientRepositoryService(ClientRepositoryABC):
def _get_client_and_server(self, id: int, server_id: int) -> Client: def _get_client_and_server(self, id: int, server_id: int) -> Client:
server = self._servers.find_server_by_discord_id(server_id) server = self._servers.find_server_by_discord_id(server_id)
if server is None: if server is None:
self._logger.warn(__name__, f'Cannot find server by id {server_id}') self._logger.warn(__name__, f"Cannot find server by id {server_id}")
raise Exception('Value not found') raise Exception("Value not found")
client = self.find_client_by_discord_id_and_server_id(id, server.server_id) client = self.find_client_by_discord_id_and_server_id(id, server.server_id)
if client is None: if client is None:
self._logger.warn(__name__, f'Cannot find client by ids {id}@{server.server_id}') self._logger.warn(__name__, f"Cannot find client by ids {id}@{server.server_id}")
raise Exception('Value not found') raise Exception("Value not found")
return client return client

View File

@ -22,12 +22,7 @@ class ServerRepositoryService(ServerRepositoryABC):
self._logger.trace(__name__, f"Send SQL command: {Server.get_select_all_string()}") self._logger.trace(__name__, f"Send SQL command: {Server.get_select_all_string()}")
results = self._context.select(Server.get_select_all_string()) results = self._context.select(Server.get_select_all_string())
for result in results: for result in results:
servers.append(Server( servers.append(Server(result[1], result[2], result[3], id=result[0]))
result[1],
result[2],
result[3],
id=result[0]
))
return servers return servers
@ -59,12 +54,7 @@ class ServerRepositoryService(ServerRepositoryABC):
def get_server_by_id(self, server_id: int) -> Server: def get_server_by_id(self, server_id: int) -> Server:
self._logger.trace(__name__, f"Send SQL command: {Server.get_select_by_id_string(server_id)}") self._logger.trace(__name__, f"Send SQL command: {Server.get_select_by_id_string(server_id)}")
result = self._context.select(Server.get_select_by_id_string(server_id))[0] result = self._context.select(Server.get_select_by_id_string(server_id))[0]
return Server( return Server(result[1], result[2], result[3], id=result[0])
result[1],
result[2],
result[3],
id=result[0]
)
def get_server_by_discord_id(self, discord_id: int) -> Server: def get_server_by_discord_id(self, discord_id: int) -> Server:
self._logger.trace( self._logger.trace(
@ -72,12 +62,7 @@ class ServerRepositoryService(ServerRepositoryABC):
f"Send SQL command: {Server.get_select_by_discord_id_string(discord_id)}", f"Send SQL command: {Server.get_select_by_discord_id_string(discord_id)}",
) )
result = self._context.select(Server.get_select_by_discord_id_string(discord_id))[0] result = self._context.select(Server.get_select_by_discord_id_string(discord_id))[0]
return Server( return Server(result[1], result[2], result[3], id=result[0])
result[1],
result[2],
result[3],
id=result[0]
)
def find_server_by_discord_id(self, discord_id: int) -> Optional[Server]: def find_server_by_discord_id(self, discord_id: int) -> Optional[Server]:
self._logger.trace( self._logger.trace(
@ -90,12 +75,7 @@ class ServerRepositoryService(ServerRepositoryABC):
result = result[0] result = result[0]
return Server( return Server(result[1], result[2], result[3], id=result[0])
result[1],
result[2],
result[3],
id=result[0]
)
def add_server(self, server: Server): def add_server(self, server: Server):
self._logger.trace(__name__, f"Send SQL command: {server.insert_string}") self._logger.trace(__name__, f"Send SQL command: {server.insert_string}")

View File

@ -4,12 +4,11 @@ from bot_graphql.abc.query_abc import QueryABC
class DataQueryABC(QueryABC): class DataQueryABC(QueryABC):
def __init__(self, name: str): def __init__(self, name: str):
QueryABC.__init__(self, name) QueryABC.__init__(self, name)
self.set_field('created_at', self.resolve_created_at) self.set_field("created_at", self.resolve_created_at)
self.set_field('modified_at', self.resolve_modified_at) self.set_field("modified_at", self.resolve_modified_at)
@staticmethod @staticmethod
def resolve_created_at(entry: TableABC, *_): def resolve_created_at(entry: TableABC, *_):

View File

@ -7,7 +7,6 @@ from cpl_query.extension import List
class FilterABC(ABC): class FilterABC(ABC):
def __init__(self): def __init__(self):
ABC.__init__(self) ABC.__init__(self)
@ -45,7 +44,7 @@ class FilterABC(ABC):
sig = signature(f) sig = signature(f)
for param in sig.parameters.items(): for param in sig.parameters.items():
parameter = param[1] parameter = param[1]
if parameter.name == 'self' or parameter.name == 'cls' or parameter.annotation == Parameter.empty: if parameter.name == "self" or parameter.name == "cls" or parameter.annotation == Parameter.empty:
continue continue
if issubclass(parameter.annotation, FilterABC): if issubclass(parameter.annotation, FilterABC):
@ -60,8 +59,8 @@ class FilterABC(ABC):
@functools.wraps(f) @functools.wraps(f)
def decorator(*args, **kwargs): def decorator(*args, **kwargs):
if 'filter' in kwargs: if "filter" in kwargs:
kwargs['filter'] = cls.get_filter(f, kwargs['filter']) kwargs["filter"] = cls.get_filter(f, kwargs["filter"])
return f(*args, **kwargs) return f(*args, **kwargs)

View File

@ -9,7 +9,6 @@ from bot_graphql.abc.filter_abc import FilterABC
class LevelFilter(FilterABC): class LevelFilter(FilterABC):
def __init__(self): def __init__(self):
FilterABC.__init__(self) FilterABC.__init__(self)
@ -18,8 +17,8 @@ class LevelFilter(FilterABC):
# self._server_id = None # self._server_id = None
def from_dict(self, values: dict): def from_dict(self, values: dict):
if 'id' in values: if "id" in values:
self._id = values['id'] self._id = values["id"]
def filter(self, query: List[Level]) -> List[Level]: def filter(self, query: List[Level]) -> List[Level]:
if self._id is not None: if self._id is not None:

View File

@ -8,7 +8,6 @@ from bot_graphql.abc.filter_abc import FilterABC
class ServerFilter(FilterABC): class ServerFilter(FilterABC):
def __init__(self): def __init__(self):
FilterABC.__init__(self) FilterABC.__init__(self)
@ -17,8 +16,8 @@ class ServerFilter(FilterABC):
self._name = None self._name = None
def from_dict(self, values: dict): def from_dict(self, values: dict):
if 'id' in values: if "id" in values:
self._id = int(values['id']) self._id = int(values["id"])
@ServiceProviderABC.inject @ServiceProviderABC.inject
def filter(self, query: List[Server], bot: DiscordBotServiceABC) -> List[Server]: def filter(self, query: List[Server], bot: DiscordBotServiceABC) -> List[Server]:
@ -29,9 +28,12 @@ class ServerFilter(FilterABC):
query = query.where(lambda x: x.discord_server_id == self._discord_id) query = query.where(lambda x: x.discord_server_id == self._discord_id)
if self._name is not None: if self._name is not None:
def where_guild(x: Guild): def where_guild(x: Guild):
guild = bot.get_guild(x.discord_server_id) guild = bot.get_guild(x.discord_server_id)
return guild is not None and (self._name.lower() == guild.name.lower() or self._name.lower() in guild.name.lower()) return guild is not None and (
self._name.lower() == guild.name.lower() or self._name.lower() in guild.name.lower()
)
query = query.where(where_guild) query = query.where(where_guild)

View File

@ -17,7 +17,6 @@ from bot_graphql.schema import Schema
class GraphQLModule(ModuleABC): class GraphQLModule(ModuleABC):
def __init__(self, dc: DiscordCollectionABC): def __init__(self, dc: DiscordCollectionABC):
ModuleABC.__init__(self, dc, FeatureFlagsEnum.data_module) ModuleABC.__init__(self, dc, FeatureFlagsEnum.data_module)

View File

@ -2,6 +2,5 @@ from bot_graphql.abc.query_abc import QueryABC
class GraphQLService: class GraphQLService:
def __init__(self, queries: list[QueryABC]): def __init__(self, queries: list[QueryABC]):
self._queries = queries self._queries = queries

View File

@ -4,15 +4,11 @@ from bot_graphql.mutations.level_mutation import LevelMutation
class Mutation(MutationType): class Mutation(MutationType):
def __init__(self, level_mutation: LevelMutation):
def __init__(
self,
level_mutation: LevelMutation
):
MutationType.__init__(self) MutationType.__init__(self)
self._level_mutation = level_mutation self._level_mutation = level_mutation
self.set_field('level', self.resolve_level) self.set_field("level", self.resolve_level)
def resolve_level(self, *_): def resolve_level(self, *_):
return self._level_mutation return self._level_mutation

View File

@ -5,28 +5,23 @@ from bot_graphql.abc.query_abc import QueryABC
class LevelMutation(QueryABC): class LevelMutation(QueryABC):
def __init__(self, servers: ServerRepositoryABC, levels: LevelRepositoryABC):
def __init__( QueryABC.__init__(self, "LevelMutation")
self,
servers: ServerRepositoryABC,
levels: LevelRepositoryABC
):
QueryABC.__init__(self, 'LevelMutation')
self._servers = servers self._servers = servers
self._levels = levels self._levels = levels
self.set_field('create_level', self.resolve_create_level) self.set_field("create_level", self.resolve_create_level)
self.set_field('update_level', self.resolve_create_level) self.set_field("update_level", self.resolve_create_level)
self.set_field('delete_level', self.resolve_create_level) self.set_field("delete_level", self.resolve_create_level)
def resolve_create_level(self, *_, input: dict): def resolve_create_level(self, *_, input: dict):
level = Level( level = Level(
input['name'], input["name"],
input['color'], input["color"],
int(input['min_xp']), int(input["min_xp"]),
int(input['permissions']), int(input["permissions"]),
self._servers.get_server_by_id(input['server_id']) self._servers.get_server_by_id(input["server_id"]),
) )
return level return level

View File

@ -3,16 +3,15 @@ from bot_graphql.abc.data_query_abc import DataQueryABC
class LevelQuery(DataQueryABC): class LevelQuery(DataQueryABC):
def __init__(self): def __init__(self):
DataQueryABC.__init__(self, 'Level') DataQueryABC.__init__(self, "Level")
self.set_field('id', self.resolve_id) self.set_field("id", self.resolve_id)
self.set_field('name', self.resolve_name) self.set_field("name", self.resolve_name)
self.set_field('color', self.resolve_color) self.set_field("color", self.resolve_color)
self.set_field('min_xp', self.resolve_min_xp) self.set_field("min_xp", self.resolve_min_xp)
self.set_field('permissions', self.resolve_permissions) self.set_field("permissions", self.resolve_permissions)
self.set_field('server', self.resolve_server) self.set_field("server", self.resolve_server)
@staticmethod @staticmethod
def resolve_id(level: Level, *_): def resolve_id(level: Level, *_):

View File

@ -8,21 +8,20 @@ from bot_graphql.filter.level_filter import LevelFilter
class ServerQuery(DataQueryABC): class ServerQuery(DataQueryABC):
def __init__( def __init__(
self, self,
bot: DiscordBotServiceABC, bot: DiscordBotServiceABC,
levels: LevelRepositoryABC, levels: LevelRepositoryABC,
): ):
DataQueryABC.__init__(self, 'Server') DataQueryABC.__init__(self, "Server")
self._bot = bot self._bot = bot
self._levels = levels self._levels = levels
self.set_field('id', self.resolve_id) self.set_field("id", self.resolve_id)
self.set_field('discord_id', self.resolve_discord_id) self.set_field("discord_id", self.resolve_discord_id)
self.set_field('name', self.resolve_name) self.set_field("name", self.resolve_name)
self.set_field('levels', self.resolve_levels) self.set_field("levels", self.resolve_levels)
@staticmethod @staticmethod
def resolve_id(server: Server, *_): def resolve_id(server: Server, *_):

View File

@ -6,16 +6,12 @@ from bot_graphql.filter.server_filter import ServerFilter
class Query(QueryType): class Query(QueryType):
def __init__(self, servers: ServerRepositoryService):
def __init__(
self,
servers: ServerRepositoryService
):
QueryType.__init__(self) QueryType.__init__(self)
self._servers = servers self._servers = servers
self.set_field('servers', self.resolve_servers) self.set_field("servers", self.resolve_servers)
self.set_field('server_count', self.resolve_server_count) self.set_field("server_count", self.resolve_server_count)
@FilterABC.resolve_filter_annotation @FilterABC.resolve_filter_annotation
def resolve_servers(self, *_, filter: ServerFilter = None): def resolve_servers(self, *_, filter: ServerFilter = None):

View File

@ -9,14 +9,8 @@ from bot_graphql.query import Query
class Schema: class Schema:
def __init__(self, query: Query, mutation: Mutation, queries: list[QueryABC]):
def __init__( type_defs = load_schema_from_path(os.path.join(os.path.dirname(os.path.realpath(__file__)), "model.gql"))
self,
query: Query,
mutation: Mutation,
queries: list[QueryABC]
):
type_defs = load_schema_from_path(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model.gql'))
self._schema = make_executable_schema(type_defs, query, mutation, *queries) self._schema = make_executable_schema(type_defs, query, mutation, *queries)
@property @property