diff --git a/example/general/src/application.py b/example/general/src/application.py index e901097f..5447c054 100644 --- a/example/general/src/application.py +++ b/example/general/src/application.py @@ -7,6 +7,7 @@ from cpl.core.environment import Environment from cpl.core.log import LoggerABC from cpl.core.pipes import IPAddressPipe from cpl.dependency import ServiceProvider +from cpl.dependency.typing import Modules from cpl.mail import EMail, EMailClientABC from cpl.query import List from scoped_service import ScopedService @@ -16,8 +17,8 @@ from test_settings import TestSettings class Application(ApplicationABC): - def __init__(self, services: ServiceProvider): - ApplicationABC.__init__(self, services) + def __init__(self, services: ServiceProvider, modules: Modules): + ApplicationABC.__init__(self, services, modules) self._logger = self._services.get_service(LoggerABC) self._mailer = self._services.get_service(EMailClientABC) diff --git a/src/cpl-application/cpl/application/abc/application_abc.py b/src/cpl-application/cpl/application/abc/application_abc.py index 2ed5c342..59c43b88 100644 --- a/src/cpl-application/cpl/application/abc/application_abc.py +++ b/src/cpl-application/cpl/application/abc/application_abc.py @@ -114,6 +114,9 @@ class ApplicationABC(ABC): Host.run_app(self.main) except KeyboardInterrupt: pass + finally: + logger = self._services.get_service(LoggerABC) + logger.info("Application shutdown") @abstractmethod def main(self): ... diff --git a/src/cpl-application/cpl/application/application_builder.py b/src/cpl-application/cpl/application/application_builder.py index 7abefddc..3d3d1529 100644 --- a/src/cpl-application/cpl/application/application_builder.py +++ b/src/cpl-application/cpl/application/application_builder.py @@ -6,7 +6,6 @@ from cpl.application.abc.application_extension_abc import ApplicationExtensionAB from cpl.application.abc.startup_abc import StartupABC from cpl.application.abc.startup_extension_abc import StartupExtensionABC from cpl.application.host import Host -from cpl.core.errors import dependency_error from cpl.dependency.context import get_provider, use_root_provider from cpl.dependency.service_collection import ServiceCollection diff --git a/src/cpl-core/cpl/core/log/logger.py b/src/cpl-core/cpl/core/log/logger.py index 31a0a707..117bb354 100644 --- a/src/cpl-core/cpl/core/log/logger.py +++ b/src/cpl-core/cpl/core/log/logger.py @@ -93,14 +93,13 @@ class Logger(LoggerABC): def _log(self, level: LogLevel, *messages: Messages): try: timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") - formatted_message = self._format_message(level.value, timestamp, *messages) - self._write_log_to_file(level, formatted_message) - self._write_to_console(level, formatted_message) + self._write_log_to_file(level, self._file_format_message(level.value, timestamp, *messages)) + self._write_to_console(level, self._console_format_message(level.value, timestamp, *messages)) except Exception as e: print(f"Error while logging: {e} -> {traceback.format_exc()}") - def _format_message(self, level: str, timestamp, *messages: Messages) -> str: + def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str: if isinstance(messages, tuple): messages = list(messages) @@ -119,6 +118,24 @@ class Logger(LoggerABC): return message + def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str: + if isinstance(messages, tuple): + messages = list(messages) + + if not isinstance(messages, list): + messages = [messages] + + messages = [str(message) for message in messages if message is not None] + + message = f"[{level.upper():^3}]" + message += f" [{self._file_prefix}]" + if self._source is not None: + message += f" - [{self._source}]" + + message += f": {' '.join(messages)}" + + return message + def header(self, string: str): self._log(LogLevel.info, string) diff --git a/src/cpl-core/cpl/core/log/logger_abc.py b/src/cpl-core/cpl/core/log/logger_abc.py index f4efb608..f0df5066 100644 --- a/src/cpl-core/cpl/core/log/logger_abc.py +++ b/src/cpl-core/cpl/core/log/logger_abc.py @@ -11,7 +11,10 @@ class LoggerABC(ABC): def set_level(self, level: LogLevel): ... @abstractmethod - def _format_message(self, level: str, timestamp, *messages: Messages) -> str: ... + def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str: ... + + @abstractmethod + def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str: ... @abstractmethod def header(self, string: str): diff --git a/src/cpl-core/cpl/core/log/structured_logger.py b/src/cpl-core/cpl/core/log/structured_logger.py index 41ae66de..2d1b9eca 100644 --- a/src/cpl-core/cpl/core/log/structured_logger.py +++ b/src/cpl-core/cpl/core/log/structured_logger.py @@ -1,15 +1,13 @@ import asyncio import importlib.util import json -import traceback from datetime import datetime from starlette.requests import Request -from cpl.core.log.log_level import LogLevel from cpl.core.log.logger import Logger from cpl.core.typing import Source, Messages -from cpl.dependency import get_provider +from cpl.dependency.context import get_provider class StructuredLogger(Logger): @@ -21,18 +19,7 @@ class StructuredLogger(Logger): def log_file(self): return f"logs/{self._file_prefix}_{datetime.now().strftime('%Y-%m-%d')}.jsonl" - def _log(self, level: LogLevel, *messages: Messages): - try: - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") - formatted_message = self._format_message(level.value, timestamp, *messages) - structured_message = self._get_structured_message(level.value, timestamp, formatted_message) - - self._write_log_to_file(level, structured_message) - self._write_to_console(level, formatted_message) - except Exception as e: - print(f"Error while logging: {e} -> {traceback.format_exc()}") - - def _get_structured_message(self, level: str, timestamp: str, messages: str) -> str: + def _file_format_message(self, level: str, timestamp: str, *messages: Messages) -> str: structured_message = { "timestamp": timestamp, "level": level.upper(), diff --git a/src/cpl-core/cpl/core/log/wrapped_logger.py b/src/cpl-core/cpl/core/log/wrapped_logger.py index e137637c..08441009 100644 --- a/src/cpl-core/cpl/core/log/wrapped_logger.py +++ b/src/cpl-core/cpl/core/log/wrapped_logger.py @@ -1,7 +1,7 @@ import inspect from typing import Type -from cpl.core.log import LoggerABC, LogLevel +from cpl.core.log import LoggerABC, LogLevel, StructuredLogger from cpl.core.typing import Messages from cpl.dependency.inject import inject from cpl.dependency.service_provider import ServiceProvider @@ -31,8 +31,11 @@ class WrappedLogger(LoggerABC): def set_level(self, level: LogLevel): self._logger.set_level(level) - def _format_message(self, level: str, timestamp, *messages: Messages) -> str: - return self._logger._format_message(level, timestamp, *messages) + def _file_format_message(self, level: str, timestamp, *messages: Messages) -> str: + return self._logger._file_format_message(level, timestamp, *messages) + + def _console_format_message(self, level: str, timestamp, *messages: Messages) -> str: + return self._logger._console_format_message(level, timestamp, *messages) @staticmethod def _get_source() -> str | None: @@ -48,6 +51,7 @@ class WrappedLogger(LoggerABC): ServiceCollection, WrappedLogger, WrappedLogger.__subclasses__(), + StructuredLogger, ] ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)] diff --git a/src/cpl-database/cpl/database/model/database_settings.py b/src/cpl-database/cpl/database/model/database_settings.py index e3862bea..ccf1ad44 100644 --- a/src/cpl-database/cpl/database/model/database_settings.py +++ b/src/cpl-database/cpl/database/model/database_settings.py @@ -21,4 +21,4 @@ class DatabaseSettings(ConfigurationModelABC): self.option("use_unicode", bool, False) self.option("buffered", bool, False) self.option("auth_plugin", str, "caching_sha2_password") - self.option("ssl_disabled", bool, False) + self.option("ssl_disabled", bool, True) diff --git a/src/cpl-database/cpl/database/mysql/mysql_pool.py b/src/cpl-database/cpl/database/mysql/mysql_pool.py index fe8110f0..a5422761 100644 --- a/src/cpl-database/cpl/database/mysql/mysql_pool.py +++ b/src/cpl-database/cpl/database/mysql/mysql_pool.py @@ -22,27 +22,27 @@ class MySQLPool: "use_unicode": database_settings.use_unicode, "buffered": database_settings.buffered, "auth_plugin": database_settings.auth_plugin, - "ssl_disabled": False, + "ssl_disabled": database_settings.ssl_disabled, } self._pool: Optional[MySQLConnectionPool] = None async def _get_pool(self): if self._pool is None: - self._pool = MySQLConnectionPool( - pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig - ) - await self._pool.initialize_pool() - - con = await self._pool.get_connection() try: + self._pool = MySQLConnectionPool( + pool_name="mypool", pool_size=Environment.get("DB_POOL_SIZE", int, 1), **self._dbconfig + ) + await self._pool.initialize_pool() + + con = await self._pool.get_connection() async with await con.cursor() as cursor: await cursor.execute("SELECT 1") await cursor.fetchall() + + await con.close() except Exception as e: logger = get_provider().get_service(DBLogger) - logger.fatal(f"Error connecting to the database: {e}") - finally: - await con.close() + logger.fatal(f"Error connecting to the database", e) return self._pool diff --git a/src/cpl-database/cpl/database/postgres/postgres_pool.py b/src/cpl-database/cpl/database/postgres/postgres_pool.py index 19cdd656..891fb7f1 100644 --- a/src/cpl-database/cpl/database/postgres/postgres_pool.py +++ b/src/cpl-database/cpl/database/postgres/postgres_pool.py @@ -7,7 +7,7 @@ from psycopg_pool import AsyncConnectionPool, PoolTimeout from cpl.core.environment import Environment from cpl.database.logger import DBLogger from cpl.database.model import DatabaseSettings -from cpl.dependency import ServiceProvider +from cpl.dependency.context import get_provider class PostgresPool: @@ -31,15 +31,16 @@ class PostgresPool: pool = AsyncConnectionPool( conninfo=self._conninfo, open=False, min_size=1, max_size=Environment.get("DB_POOL_SIZE", int, 1) ) - await pool.open() try: + await pool.open() async with pool.connection() as con: await pool.check_connection(con) + + self._pool = pool except PoolTimeout as e: await pool.close() logger = get_provider().get_service(DBLogger) logger.fatal(f"Failed to connect to the database", e) - self._pool = pool return self._pool diff --git a/src/cpl-dependency/cpl/dependency/service_provider.py b/src/cpl-dependency/cpl/dependency/service_provider.py index dc3250ab..23a4216d 100644 --- a/src/cpl-dependency/cpl/dependency/service_provider.py +++ b/src/cpl-dependency/cpl/dependency/service_provider.py @@ -1,4 +1,5 @@ import copy +import inspect import typing from contextlib import contextmanager from inspect import signature, Parameter, Signature @@ -77,6 +78,35 @@ class ServiceProvider: return implementations + def _get_source(self): + stack = inspect.stack() + if len(stack) <= 1: + return None + + from cpl.dependency.service_collection import ServiceCollection + + ignore_classes = [ + ServiceProvider, + ServiceProvider.__subclasses__(), + ServiceCollection, + ] + + ignore_modules = [x.__module__ for x in ignore_classes if isinstance(x, type)] + + for i, frame_info in enumerate(stack[1:]): + module = inspect.getmodule(frame_info.frame) + if module is None: + continue + + if module.__name__ in ignore_classes or module in ignore_classes: + continue + + if module in ignore_modules or module.__name__ in ignore_modules: + continue + + if module.__name__ != __name__: + return module.__name__ + def _build_by_signature(self, sig: Signature, origin_service_type: type = None) -> list[T]: params = [] for param in sig.parameters.items(): @@ -88,7 +118,11 @@ class ServiceProvider: ) elif parameter.annotation == Source: - params.append(origin_service_type.__name__) + params.append( + origin_service_type.__name__ + if inspect.isclass(origin_service_type) + else str(origin_service_type) + ) elif issubclass(parameter.annotation, ServiceProvider): params.append(self) @@ -116,6 +150,9 @@ class ServiceProvider: else: service_type = descriptor.service_type + if origin_service_type is None: + origin_service_type = self._get_source() + if origin_service_type is None: origin_service_type = service_type