From 5beb1fadf0a61105b015f3dbe901ee719433cda4 Mon Sep 17 00:00:00 2001 From: AmineDiro Date: Fri, 4 Oct 2024 11:34:58 +0200 Subject: [PATCH] feat: structlog parseable (#3319) # Description - better logging --------- Co-authored-by: Stan Girard --- .env.example | 7 + .gitignore | 2 +- backend/api/quivr_api/logger.py | 264 ++++++++++++++++-- backend/api/quivr_api/main.py | 34 +-- .../quivr_api/middlewares/auth/auth_bearer.py | 10 +- .../middlewares/logging_middleware.py | 95 +++++++ backend/api/quivr_api/models/settings.py | 24 +- backend/api/quivr_api/modules/dependencies.py | 2 +- .../modules/misc/controller/misc_routes.py | 9 +- backend/api/quivr_api/utils/__init__.py | 2 - .../utils/handle_request_validation_error.py | 24 -- backend/pyproject.toml | 5 +- backend/requirements-dev.lock | 6 +- backend/requirements.lock | 6 +- backend/worker/quivr_worker/celery_monitor.py | 5 +- backend/worker/quivr_worker/celery_worker.py | 18 +- 16 files changed, 405 insertions(+), 108 deletions(-) create mode 100644 backend/api/quivr_api/middlewares/logging_middleware.py delete mode 100644 backend/api/quivr_api/utils/handle_request_validation_error.py diff --git a/.env.example b/.env.example index d69382ebe..4ae028019 100644 --- a/.env.example +++ b/.env.example @@ -64,6 +64,13 @@ BACKEND_URL=http://localhost:5050 EMBEDDING_DIM=1536 DEACTIVATE_STRIPE=true + +# PARSEABLE LOGGING +USE_PARSEABLE=False +PARSEABLE_STREAM_NAME=quivr-api +PARSEABLE_URL= +PARSEABLE_AUTH= + #RESEND RESEND_API_KEY= RESEND_EMAIL_ADDRESS=onboarding@resend.dev diff --git a/.gitignore b/.gitignore index c034793f4..89fbda293 100644 --- a/.gitignore +++ b/.gitignore @@ -80,7 +80,7 @@ paulgraham.py .env_test supabase/seed-airwallex.sql airwallexpayouts.py -application.log +**/application.log* backend/celerybeat-schedule.db backend/application.log.* diff --git a/backend/api/quivr_api/logger.py b/backend/api/quivr_api/logger.py index b839e9aef..5b12765c1 100644 --- a/backend/api/quivr_api/logger.py +++ 
b/backend/api/quivr_api/logger.py @@ -1,45 +1,247 @@ import logging import os +import queue +import sys +import threading from logging.handlers import RotatingFileHandler +from typing import List -from colorlog import ( - ColoredFormatter, -) +import orjson +import requests +import structlog + +from quivr_api.models.settings import parseable_settings + +# Thread-safe queue for log messages +log_queue = queue.Queue() +stop_log_queue = threading.Event() -def get_logger(logger_name, log_file="application.log"): - log_level = os.getenv("LOG_LEVEL", "WARNING").upper() - logger = logging.getLogger(logger_name) - logger.setLevel(log_level) - logger.propagate = False # Prevent log propagation to avoid double logging +class ParseableLogHandler(logging.Handler): + def __init__( + self, + base_parseable_url: str, + auth_token: str, + stream_name: str, + batch_size: int = 10, + flush_interval: float = 1, + ): + super().__init__() + self.base_url = base_parseable_url + self.stream_name = stream_name + self.url = self.base_url + self.stream_name + self.batch_size = batch_size + self.flush_interval = flush_interval + self._worker_thread = threading.Thread(target=self._process_log_queue) + self._worker_thread.daemon = True + self._worker_thread.start() + self.headers = { + "Authorization": f"Basic {auth_token}", # base64 encoding user:mdp + "Content-Type": "application/json", + } - formatter = logging.Formatter( - "[%(levelname)s] %(name)s [%(filename)s:%(lineno)d]: %(message)s" + def emit(self, record: logging.LogRecord): + # FIXME (@AmineDiro): This ping-pong of serialization/deserialization is a limitation of logging formatter + # The formatter should return a 'str' for the logger to print + if isinstance(record.msg, str): + return + elif isinstance(record.msg, dict): + logger_name = record.msg.get("logger", None) + if logger_name and ( + logger_name.startswith("quivr_api.access") + or logger_name.startswith("quivr_api.error") + ): + url = record.msg.get("url", None) + # Filter 
on healthz + if url and "healthz" not in url: + fmt = orjson.loads(self.format(record)) + log_queue.put(fmt) + else: + return + + def _process_log_queue(self): + """Background thread that processes the log queue and sends logs to Parseable.""" + logs_batch = [] + while not stop_log_queue.is_set(): + try: + # Collect logs for batch processing + log_data = log_queue.get(timeout=self.flush_interval) + logs_batch.append(log_data) + + # Send logs if batch size is reached + if len(logs_batch) >= self.batch_size: + self._send_logs_to_parseable(logs_batch) + logs_batch.clear() + + except queue.Empty: + # If the queue is empty, send any remaining logs + if logs_batch: + self._send_logs_to_parseable(logs_batch) + logs_batch.clear() + + def _send_logs_to_parseable(self, logs: List[str]): + payload = orjson.dumps(logs) + try: + response = requests.post(self.url, headers=self.headers, data=payload) + if response.status_code != 200: + print(f"Failed to send logs to Parseable server: {response.text}") + except Exception as e: + print(f"Error sending logs to Parseable: {e}") + + def stop(self): + """Stop the background worker thread and process any remaining logs.""" + stop_log_queue.set() + self._worker_thread.join() + # Process remaining logs before shutting down + remaining_logs = list(log_queue.queue) + if remaining_logs: + self._send_logs_to_parseable(remaining_logs) + + +def extract_from_record(_, __, event_dict): + """ + Extract thread and process names and add them to the event dict. + """ + record = event_dict["_record"] + event_dict["thread_name"] = record.threadName + event_dict["process_name"] = record.processName + return event_dict + + +def drop_http_context(_, __, event_dict): + """ + Extract thread and process names and add them to the event dict. 
+ """ + keys = ["msg", "logger", "level", "timestamp", "exc_info"] + return {k: event_dict.get(k, None) for k in keys} + + +def setup_logger( + log_file="application.log", send_log_server: bool = parseable_settings.use_parseable +): + structlog.reset_defaults() + # Shared handlers + shared_processors = [ + structlog.contextvars.merge_contextvars, + structlog.stdlib.add_log_level, + structlog.stdlib.add_logger_name, + structlog.stdlib.PositionalArgumentsFormatter(), + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.UnicodeDecoder(), + structlog.processors.EventRenamer("msg"), + ] + structlog.configure( + processors=shared_processors + + [ + structlog.stdlib.ProcessorFormatter.wrap_for_formatter, + ], + # Use standard logging compatible logger + logger_factory=structlog.stdlib.LoggerFactory(), + wrapper_class=structlog.stdlib.BoundLogger, + # Use Python's logging configuration + cache_logger_on_first_use=True, + ) + # Set Formatters + plain_fmt = structlog.stdlib.ProcessorFormatter( + foreign_pre_chain=shared_processors, + processors=[ + extract_from_record, + structlog.processors.format_exc_info, + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + structlog.dev.ConsoleRenderer( + colors=False, exception_formatter=structlog.dev.plain_traceback + ), + ], + ) + color_fmt = structlog.stdlib.ProcessorFormatter( + processors=[ + drop_http_context, + structlog.dev.ConsoleRenderer( + colors=True, + exception_formatter=structlog.dev.RichTracebackFormatter( + show_locals=False + ), + ), + ], + foreign_pre_chain=shared_processors, + ) + parseable_fmt = structlog.stdlib.ProcessorFormatter( + processors=[ + # TODO: Which one gets us the better debug experience ? 
+ # structlog.processors.ExceptionRenderer( + # exception_formatter=structlog.tracebacks.ExceptionDictTransformer( + # show_locals=False + # ) + # ), + structlog.processors.format_exc_info, + structlog.stdlib.ProcessorFormatter.remove_processors_meta, + structlog.processors.JSONRenderer(), + ], + foreign_pre_chain=shared_processors + + [ + structlog.processors.CallsiteParameterAdder( + { + structlog.processors.CallsiteParameter.FUNC_NAME, + structlog.processors.CallsiteParameter.LINENO, + } + ), + ], ) - color_formatter = ColoredFormatter( - "%(log_color)s[%(levelname)s]%(reset)s %(name)s [%(filename)s:%(lineno)d]: %(message)s", - log_colors={ - "DEBUG": "cyan", - "INFO": "green", - "WARNING": "yellow", - "ERROR": "red", - "CRITICAL": "red,bg_white", - }, - reset=True, - style="%", - ) - - console_handler = logging.StreamHandler() - console_handler.setFormatter(color_formatter) - + # Set handlers + console_handler = logging.StreamHandler(sys.stdout) file_handler = RotatingFileHandler( log_file, maxBytes=5000000, backupCount=5 ) # 5MB file - file_handler.setFormatter(formatter) + console_handler.setFormatter(color_fmt) + file_handler.setFormatter(plain_fmt) + handlers: list[logging.Handler] = [console_handler, file_handler] + if ( + send_log_server + and parseable_settings.parseable_url is not None + and parseable_settings.parseable_auth is not None + and parseable_settings.parseable_stream_name + ): + parseable_handler = ParseableLogHandler( + auth_token=parseable_settings.parseable_auth, + base_parseable_url=parseable_settings.parseable_url, + stream_name=parseable_settings.parseable_stream_name, + ) + parseable_handler.setFormatter(parseable_fmt) + handlers.append(parseable_handler) - if not logger.handlers: - logger.addHandler(console_handler) - logger.addHandler(file_handler) + # Configure logger + log_level = os.getenv("LOG_LEVEL", "INFO").upper() + root_logger = logging.getLogger() + root_logger.setLevel(log_level) + root_logger.handlers = [] + for handler in 
handlers: + root_logger.addHandler(handler) - return logger + _clear_uvicorn_logger() + + +def _clear_uvicorn_logger(): + for _log in [ + "uvicorn", + "httpcore", + "uvicorn.error", + "uvicorn.access", + "urllib3", + "httpx", + ]: + # Clear the log handlers for uvicorn loggers, and enable propagation + # so the messages are caught by our root logger and formatted correctly + # by structlog + logging.getLogger(_log).setLevel(logging.WARNING) + logging.getLogger(_log).handlers.clear() + logging.getLogger(_log).propagate = True + + +setup_logger() + + +def get_logger(name: str | None = None): + assert structlog.is_configured() + return structlog.get_logger(name) diff --git a/backend/api/quivr_api/main.py b/backend/api/quivr_api/main.py index bf2753a03..5d2696051 100644 --- a/backend/api/quivr_api/main.py +++ b/backend/api/quivr_api/main.py @@ -1,17 +1,17 @@ import logging import os -import litellm import sentry_sdk from dotenv import load_dotenv # type: ignore -from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse, JSONResponse +from fastapi import FastAPI, Request +from fastapi.responses import HTMLResponse from pyinstrument import Profiler from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.starlette import StarletteIntegration -from quivr_api.logger import get_logger +from quivr_api.logger import get_logger, stop_log_queue from quivr_api.middlewares.cors import add_cors_middleware +from quivr_api.middlewares.logging_middleware import LoggingMiddleware from quivr_api.modules.analytics.controller.analytics_routes import analytics_router from quivr_api.modules.api_key.controller import api_key_router from quivr_api.modules.assistant.controller import assistant_router @@ -27,21 +27,13 @@ from quivr_api.modules.upload.controller import upload_router from quivr_api.modules.user.controller import user_router from quivr_api.routes.crawl_routes import crawl_router from 
quivr_api.routes.subscription_routes import subscription_router -from quivr_api.utils import handle_request_validation_error from quivr_api.utils.telemetry import maybe_send_telemetry load_dotenv() -# Set the logging level for all loggers to WARNING + logging.basicConfig(level=logging.INFO) -logging.getLogger("httpx").setLevel(logging.WARNING) -logging.getLogger("LiteLLM").setLevel(logging.WARNING) -logging.getLogger("litellm").setLevel(logging.WARNING) -get_logger("quivr_core") -litellm.set_verbose = False # type: ignore - - -logger = get_logger(__name__) +logger = get_logger("quivr_api") def before_send(event, hint): @@ -72,6 +64,9 @@ if sentry_dsn: app = FastAPI() add_cors_middleware(app) +app.add_middleware(LoggingMiddleware) + + app.include_router(brain_router) app.include_router(chat_router) app.include_router(crawl_router) @@ -106,16 +101,11 @@ if PROFILING: return await call_next(request) -@app.exception_handler(HTTPException) -async def http_exception_handler(_, exc): - return JSONResponse( - status_code=exc.status_code, - content={"detail": exc.detail}, - ) +@app.on_event("shutdown") +def shutdown_event(): + stop_log_queue.set() -handle_request_validation_error(app) - if os.getenv("TELEMETRY_ENABLED") == "true": logger.info("Telemetry enabled, we use telemetry to collect anonymous usage data.") logger.info( diff --git a/backend/api/quivr_api/middlewares/auth/auth_bearer.py b/backend/api/quivr_api/middlewares/auth/auth_bearer.py index 73e3867cf..a5ce9af9c 100644 --- a/backend/api/quivr_api/middlewares/auth/auth_bearer.py +++ b/backend/api/quivr_api/middlewares/auth/auth_bearer.py @@ -1,6 +1,7 @@ import os from typing import Optional +import structlog from fastapi import Depends, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -13,6 +14,8 @@ from quivr_api.modules.user.entity.user_identity import UserIdentity api_key_service = ApiKeyService() +logger = structlog.stdlib.get_logger("quivr_api.access") + class 
AuthBearer(HTTPBearer): def __init__(self, auto_error: bool = True): @@ -66,5 +69,10 @@ class AuthBearer(HTTPBearer): auth_bearer = AuthBearer() -def get_current_user(user: UserIdentity = Depends(auth_bearer)) -> UserIdentity: +async def get_current_user(user: UserIdentity = Depends(auth_bearer)) -> UserIdentity: + # Due to context switch in FastAPI executor we can't get this id back + # We log it as an additional log so we can get information if exception was raised + # https://www.structlog.org/en/stable/contextvars.html + structlog.contextvars.bind_contextvars(client_id=str(user.id)) + logger.info("Authentication success") return user diff --git a/backend/api/quivr_api/middlewares/logging_middleware.py b/backend/api/quivr_api/middlewares/logging_middleware.py new file mode 100644 index 000000000..b9911f30d --- /dev/null +++ b/backend/api/quivr_api/middlewares/logging_middleware.py @@ -0,0 +1,95 @@ +import os +import time +import uuid + +import structlog +from fastapi import Request, Response, status +from starlette.middleware.base import BaseHTTPMiddleware +from structlog.contextvars import ( + bind_contextvars, + clear_contextvars, +) + +logger = structlog.stdlib.get_logger("quivr_api.access") + + +git_sha = os.getenv("PORTER_IMAGE_TAG", None) + + +def clean_dict(d): + """Remove None values from a dictionary.""" + return {k: v for k, v in d.items() if v is not None} + + +class LoggingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + clear_contextvars() + # Generate a unique request ID + request_id = str(uuid.uuid4()) + + client_addr = ( + f"{request.client.host}:{request.client.port}" if request.client else None + ) + url = request.url.path + http_version = request.scope["http_version"] + + bind_contextvars( + **clean_dict( + { + "git_head": git_sha, + "request_id": request_id, + "method": request.method, + "query_params": dict(request.query_params), + "client_addr": client_addr, + "request_user_agent": 
request.headers.get("user-agent"), + "request_content_type": request.headers.get("content-type"), + "url": url, + "http_version": http_version, + } + ) + ) + + # Start time + start_time = time.perf_counter() + response = Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + try: + # Process the request + response: Response = await call_next(request) + process_time = time.perf_counter() - start_time + bind_contextvars( + **clean_dict( + { + "response_content_type": response.headers.get("content-type"), + "response_status": response.status_code, + "response_headers": dict(response.headers), + "timing_request_total_ms": round(process_time * 1e3, 3), + } + ) + ) + + logger.info( + f"""{client_addr} - "{request.method} {url} HTTP/{http_version}" {response.status_code}""", + ) + except Exception: + process_time = time.perf_counter() - start_time + bind_contextvars( + **clean_dict( + { + "response_status": response.status_code, + "timing_request_total_ms": round(process_time * 1000, 3), + } + ) + ) + structlog.stdlib.get_logger("quivr_api.error").exception( + "Request failed with exception" + ) + raise + + finally: + clear_contextvars() + + # Add X-Request-ID to response headers + response.headers["X-Request-ID"] = request_id + response.headers["X-Process-Time"] = str(process_time) + + return response diff --git a/backend/api/quivr_api/models/settings.py b/backend/api/quivr_api/models/settings.py index 5987addc7..5f4050a99 100644 --- a/backend/api/quivr_api/models/settings.py +++ b/backend/api/quivr_api/models/settings.py @@ -1,15 +1,7 @@ -from typing import Optional from uuid import UUID from posthog import Posthog from pydantic_settings import BaseSettings, SettingsConfigDict -from sqlalchemy import Engine - -from quivr_api.logger import get_logger -from quivr_api.models.databases.supabase.supabase import SupabaseDB -from supabase.client import AsyncClient, Client - -logger = get_logger(__name__) class BrainRateLimiting(BaseSettings): @@ -122,7 +114,7 @@ class 
BrainSettings(BaseSettings): langfuse_secret_key: str | None = None pg_database_url: str pg_database_async_url: str - embedding_dim: int + embedding_dim: int = 1536 class ResendSettings(BaseSettings): @@ -134,11 +126,13 @@ class ResendSettings(BaseSettings): quivr_smtp_password: str = "" -# Global variables to store the Supabase client and database instances -_supabase_client: Optional[Client] = None -_supabase_async_client: Optional[AsyncClient] = None -_supabase_db: Optional[SupabaseDB] = None -_db_engine: Optional[Engine] = None -_embedding_service = None +class ParseableSettings(BaseSettings): + model_config = SettingsConfigDict(validate_default=False) + use_parseable: bool = False + parseable_url: str | None = None + parseable_auth: str | None = None + parseable_stream_name: str | None = None + settings = BrainSettings() # type: ignore +parseable_settings = ParseableSettings() diff --git a/backend/api/quivr_api/modules/dependencies.py b/backend/api/quivr_api/modules/dependencies.py index fd71696cd..e74bdee45 100644 --- a/backend/api/quivr_api/modules/dependencies.py +++ b/backend/api/quivr_api/modules/dependencies.py @@ -30,7 +30,7 @@ _embedding_service = None settings = BrainSettings() # type: ignore -logger = get_logger(__name__) +logger = get_logger("quivr_api") class BaseRepository: diff --git a/backend/api/quivr_api/modules/misc/controller/misc_routes.py b/backend/api/quivr_api/modules/misc/controller/misc_routes.py index 054798b34..ed085fe79 100644 --- a/backend/api/quivr_api/modules/misc/controller/misc_routes.py +++ b/backend/api/quivr_api/modules/misc/controller/misc_routes.py @@ -4,16 +4,21 @@ from quivr_api.modules.dependencies import get_async_session from sqlmodel import text from sqlmodel.ext.asyncio.session import AsyncSession -logger = get_logger(__name__) - +logger = get_logger() misc_router = APIRouter() +@misc_router.get("/excp") +async def excp(): + raise ValueError + + @misc_router.get("/") async def root(): """ Root endpoint to check the 
status of the API. """ + logger.info("this is a test", a=10) return {"status": "OK"} diff --git a/backend/api/quivr_api/utils/__init__.py b/backend/api/quivr_api/utils/__init__.py index c9e648c90..e69de29bb 100644 --- a/backend/api/quivr_api/utils/__init__.py +++ b/backend/api/quivr_api/utils/__init__.py @@ -1,2 +0,0 @@ -from .handle_request_validation_error import handle_request_validation_error -from .parse_message_time import parse_message_time diff --git a/backend/api/quivr_api/utils/handle_request_validation_error.py b/backend/api/quivr_api/utils/handle_request_validation_error.py deleted file mode 100644 index d539c7885..000000000 --- a/backend/api/quivr_api/utils/handle_request_validation_error.py +++ /dev/null @@ -1,24 +0,0 @@ -from fastapi import FastAPI, Request, status -from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse - -from quivr_api.logger import get_logger - -logger = get_logger(__name__) - - -def handle_request_validation_error(app: FastAPI): - @app.exception_handler(RequestValidationError) - async def validation_exception_handler( - request: Request, exc: RequestValidationError - ): - exc_str = f"{exc}".replace("\n", " ").replace(" ", " ") - logger.error(request, exc_str) - content = { - "status_code": status.HTTP_422_UNPROCESSABLE_ENTITY, - "message": exc_str, - "data": None, - } - return JSONResponse( - content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY - ) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 1565ff6c9..80298fb94 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -10,7 +10,10 @@ authors = [ ] dependencies = [ "packaging>=22.0", - "langchain-anthropic>=0.1.23", + # Logging packages + "structlog>=24.4.0", + "python-json-logger>=2.0.7", + "orjson>=3.10.7", ] readme = "README.md" requires-python = ">= 3.11" diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock index d96ff2965..498ff5d20 100644 --- 
a/backend/requirements-dev.lock +++ b/backend/requirements-dev.lock @@ -423,7 +423,6 @@ langchain==0.2.16 # via quivr-diff-assistant langchain-anthropic==0.1.23 # via quivr-core - # via quivr-monorepo langchain-cohere==0.2.2 # via quivr-api langchain-community==0.2.12 @@ -735,6 +734,7 @@ opentelemetry-semantic-conventions==0.48b0 # via opentelemetry-sdk orjson==3.10.7 # via langsmith + # via quivr-monorepo packaging==23.2 # via black # via chainlit @@ -982,6 +982,8 @@ python-iso639==2024.4.27 # via unstructured python-jose==3.3.0 # via quivr-api +python-json-logger==2.0.7 + # via quivr-monorepo python-magic==0.4.27 # via quivr-diff-assistant # via unstructured @@ -1135,6 +1137,8 @@ strenum==0.4.15 # via postgrest striprtf==0.0.26 # via llama-index-readers-file +structlog==24.4.0 + # via quivr-monorepo supabase==2.7.2 # via quivr-api supafunc==0.5.1 diff --git a/backend/requirements.lock b/backend/requirements.lock index ff4e2f9fb..7bb40f61e 100644 --- a/backend/requirements.lock +++ b/backend/requirements.lock @@ -374,7 +374,6 @@ langchain==0.2.16 # via quivr-diff-assistant langchain-anthropic==0.1.23 # via quivr-core - # via quivr-monorepo langchain-cohere==0.2.2 # via quivr-api langchain-community==0.2.12 @@ -645,6 +644,7 @@ openpyxl==3.1.5 # via unstructured orjson==3.10.7 # via langsmith + # via quivr-monorepo packaging==24.1 # via deprecation # via faiss-cpu @@ -849,6 +849,8 @@ python-iso639==2024.4.27 # via unstructured python-jose==3.3.0 # via quivr-api +python-json-logger==2.0.7 + # via quivr-monorepo python-magic==0.4.27 # via quivr-diff-assistant # via unstructured @@ -991,6 +993,8 @@ strenum==0.4.15 # via postgrest striprtf==0.0.26 # via llama-index-readers-file +structlog==24.4.0 + # via quivr-monorepo supabase==2.7.2 # via quivr-api supafunc==0.5.1 diff --git a/backend/worker/quivr_worker/celery_monitor.py b/backend/worker/quivr_worker/celery_monitor.py index 5ce7f0e96..795ea3d5d 100644 --- a/backend/worker/quivr_worker/celery_monitor.py +++ 
b/backend/worker/quivr_worker/celery_monitor.py @@ -7,7 +7,7 @@ from uuid import UUID from attr import dataclass from celery.result import AsyncResult from quivr_api.celery_config import celery -from quivr_api.logger import get_logger +from quivr_api.logger import get_logger, setup_logger from quivr_api.modules.dependencies import async_engine from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.assistant.repository.tasks import TasksRepository @@ -21,7 +21,8 @@ from quivr_api.modules.notification.service.notification_service import ( from quivr_core.models import KnowledgeStatus from sqlmodel.ext.asyncio.session import AsyncSession -logger = get_logger("notifier_service", "notifier_service.log") +setup_logger("notifier.log", send_log_server=False) +logger = get_logger("notifier_service") notification_service = NotificationService() queue = Queue() diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index 10e983322..dce4a301d 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -2,12 +2,15 @@ import asyncio import os from uuid import UUID +import structlog import torch +from celery import signals from celery.schedules import crontab from celery.signals import worker_process_init +from celery.utils.log import get_task_logger from dotenv import load_dotenv from quivr_api.celery_config import celery -from quivr_api.logger import get_logger +from quivr_api.logger import setup_logger from quivr_api.models.settings import settings from quivr_api.modules.assistant.repository.tasks import TasksRepository from quivr_api.modules.assistant.services.tasks_service import TasksService @@ -49,11 +52,11 @@ from quivr_worker.utils.utils import _patch_json torch.set_num_threads(1) - +setup_logger("worker.log", send_log_server=False) load_dotenv() -get_logger("quivr_core") -logger = get_logger("celery_worker") +logger = 
structlog.wrap_logger(get_task_logger(__name__)) + _patch_json() @@ -73,6 +76,13 @@ async_engine: AsyncEngine | None = None engine: Engine | None = None +@signals.task_prerun.connect +def on_task_prerun(sender, task_id, task, args, kwargs, **_): + structlog.contextvars.bind_contextvars(task_id=task_id, task_name=task.name) + if vars := kwargs.get("contextvars", None): + structlog.contextvars.bind_contextvars(**vars) + + @worker_process_init.connect def init_worker(**kwargs): global async_engine