At FereAI.xyz, we use Celery at scale. Every time the trading agent wants to perform a trade, a manual sell is executed, or emails need to go out for scheduled jobs, the work runs through Celery.
Celery is great, but I have always missed the Airflow feature where you can see the logs of each individual task run and use them to diagnose or debug a problem. So I decided to build something similar for Celery.
Requirements
I wanted a solution where:
- All logs from Celery jobs are stored in the database along with their task id, status and a few other parameters (a sample row is sketched just below)
- Logging works out of the box for Celery, FastAPI and standalone usage (e.g. Jupyter notebooks)
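Concretely, every task run should end up as one row that looks roughly like the sketch below. The field names mirror the model and log handler defined later in this post; the values themselves are made up for illustration.

# Illustrative row only - the structure follows the BatchedLog model and
# TaskLogHandler defined further down; all values here are invented.
example_row = {
    "task_id": "8f14e45f-ceea-467f-a8d5-91b5f641de3e",
    "task_name": "trader.tasks.test_task",
    "status": "success",
    "logs": [
        {
            "level": "INFO",
            "message": "Starting main task",
            "time": 1704110400.0,
            "filename": "tasks.py",
            "lineno": 12,
            "funcName": "test_task",
            "thread": "MainThread",
        }
    ],
}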
Here’s what I did
A database table for storing logs
from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.sql.expression import func

from .base import Base  # the project-wide declarative Base (declarative_base())
class BatchedLog(Base):
    __tablename__ = "batched_log"

    id = Column(Integer, primary_key=True, autoincrement=True)
    task_id = Column(String(255), nullable=True, index=True)
    task_name = Column(String(255), nullable=True, index=True)
    status = Column(String(255), nullable=True, index=True)
    logs = Column(Text)  # Store logs as JSON
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=func.now(),
        index=True,
    )

Logger config
import contextvars
import logging
import sys

from celery.app.log import TaskFormatter
# Create and configure the logger
logger = logging.getLogger("trader")
logger.setLevel(logging.DEBUG)
# Remove existing handlers to avoid duplication
if logger.hasHandlers():
    logger.handlers.clear()
# Stream handler for outputting to stdout
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.DEBUG)
# Formatter for the log messages
formatter = TaskFormatter(
    "%(asctime)s - %(task_id)s - %(task_name)s - %(name)s - %(threadName)s - %(levelname)s - %(message)s"
)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
# Prevent propagation to the root logger
logger.propagate = False
# Buffer of structured log records for the current task run
task_logs = contextvars.ContextVar("task_logs", default=[])
# Define a context variable to store the logger
logger_context = contextvars.ContextVar("logger_context", default=None)


def get_logger():
    """Retrieve the current logger from the context variable."""
    lc = logger_context.get()
    if lc is None:
        logger.name = "Jupyter" if "ipykernel" in sys.modules else "Standalone"
        return logger
    return lc

Celery app
import json
import logging
import traceback
from uuid import UUID

from celery.app import Celery
from celery.app.log import TaskFormatter
from celery.result import AsyncResult
from celery.schedules import crontab
from celery.signals import (
    after_setup_task_logger,
    task_failure,
    task_postrun,
    task_prerun,
)
from celery.utils.log import get_task_logger

from .logger_config import task_logs, logger_context, get_logger

# get_redis_url() and get_trade_db_session() are project helpers defined
# elsewhere in the codebase; BatchedLog is the model shown above.
app = Celery("trader", broker=get_redis_url(), backend=get_redis_url())
app.conf.task_logging_level = logging.DEBUG

# Celery-aware logger handed to tasks via the context variable
celery_task_logger = get_task_logger(__name__)
# Signal to initialize logs
@task_prerun.connect
def initialize_logs(task_id=None, task=None, args=None, kwargs=None, **extras):
    logger_context.set(celery_task_logger)
    task_logs.set([])
    celery_task_logger.info(f"Logger set for Celery task: {task_id}")
# Signal to persist logs on task completion
@task_postrun.connect
def persist_logs_on_completion(
    task_id=None, task=None, args=None, kwargs=None, retval=None, **extras
):
    result = AsyncResult(task_id)
    status = result.status
    task_name = task.name if task else "unknown"
    logs = task_logs.get()
    print(f"Persisting logs for task_id={task_id}")
    # Save logs to the database
    with get_trade_db_session() as session:
        session.merge(
            BatchedLog(
                task_id=task_id,
                status=status.lower(),
                task_name=task_name,
                logs=json.dumps(logs),
            )
        )
        session.commit()
    logger_context.set(None)
# Signal to handle task failure
@task_failure.connect
def handle_failure(
    task_id=None,
    sender=None,
    exception=None,
    args=None,
    kwargs=None,
    **extras,
):
    logger = get_logger()
    result = AsyncResult(task_id)
    status = result.status
    task_name = sender.name if sender else "unknown"
    logger.error(f"Task failed: {exception}")
    logger.error(
        "".join(
            traceback.format_exception(
                type(exception), exception, exception.__traceback__
            )
        )
    )
    logs = task_logs.get()
    print(f"Task failed. Persisting logs for task_id={task_id}")
    # Save logs to the database
    with get_trade_db_session() as session:
        session.merge(
            BatchedLog(
                task_id=task_id,
                status=status.lower(),
                task_name=task_name,
                logs=json.dumps(logs),
            )
        )
        session.commit()
class TaskLogHandler(logging.Handler):
    def emit(self, record):
        logs = task_logs.get()
        logs.append(
            {
                "level": record.levelname,
                "message": record.getMessage(),
                "time": record.created,
                "filename": record.filename,
                "lineno": record.lineno,
                "funcName": record.funcName,
                "thread": record.threadName,
            }
        )
        task_logs.set(logs)
@after_setup_task_logger.connect
def setup_task_logger(logger, *args, **kwargs):
    task_handler = TaskLogHandler()
    logger.addHandler(task_handler)
    for handler in logger.handlers:
        handler.setFormatter(
            TaskFormatter(
                "%(asctime)s - %(task_id)s - %(task_name)s - %(name)s - %(levelname)s - %(message)s"
            )
        )
@app.task
def test_task():
    celery_task_logger.info("Starting main task")
    try:
        agents_started = []  # placeholder for the real work (e.g. starting trading agents)
        celery_task_logger.info(
            f"Started {len(agents_started)} agents with UUIDs: {agents_started}"
        )
        return agents_started
    except Exception as e:
        celery_task_logger.error(f"Task failed: {e}")
        celery_task_logger.error(traceback.format_exc())
        raise

Any other custom classes or functions should also get their logger via get_logger:
from .logger_config import get_logger


class MyClass:
    def __init__(self):
        self.logger = get_logger()

    def foo(self):
        self.logger.info("Inside foo")
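Standalone scripts and Jupyter notebooks need nothing extra: when no logger has been set in the context, get_logger() falls back to the module-level logger (renamed "Jupyter" or "Standalone"), so the same classes work unchanged. A minimal sketch, with the import path adjusted to wherever logger_config lives in your project:

from logger_config import get_logger  # adjust to your package layout

logger = get_logger()
logger.info("Running a quick experiment outside Celery")

obj = MyClass()
obj.foo()  # logs through the same stdout logger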
FastAPI app
from fastapi import FastAPI, Request

from .logger_config import get_logger, logger_context

app = FastAPI(
    ...
)


@app.middleware("http")
async def set_fastapi_logger(request: Request, call_next):
    # Assign a request-specific logger to the logger context
    if not hasattr(request.state, "logger"):
        request.state.logger = get_logger()
    logger_context.set(request.state.logger)
    response = await call_next(request)
    # Clear the logger context after request handling
    logger_context.set(None)
    return response
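With the middleware in place, route handlers (and anything they call) keep using get_logger() as usual, and since every Celery run now leaves a batched_log row behind, pulling up the logs for a task is a single query. A rough sketch of both; the /run-task route, fetch_task_logs and the import paths are illustrative, while get_trade_db_session is the same session helper used in the signal handlers above:

import json

from .celery_app import test_task  # wherever the Celery app above lives
from .models import BatchedLog     # the model from the top of this post


@app.get("/run-task")
async def run_task():
    logger = get_logger()  # request-scoped logger set by the middleware
    result = test_task.delay()
    logger.info(f"Queued test_task with id {result.id}")
    return {"task_id": result.id}


def fetch_task_logs(task_id: str) -> list[dict]:
    """Load the stored log records for a finished Celery task run."""
    with get_trade_db_session() as session:
        row = (
            session.query(BatchedLog)
            .filter(BatchedLog.task_id == task_id)
            .order_by(BatchedLog.created_at.desc())
            .first()
        )
        return json.loads(row.logs) if row else []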


