I am implementing queue management using Celery to get estimated time and queue position for any task though polling every 10 secs.
Teck Stack: Using Redis as Message Broker and Backend to store state and metadata of estimated time based on task_name
Issue: Unable to fetch estimated_times for request when used in Polling with end point /{request_id}/fetch
Unable to figure out if the issue is because of using Redis as Redis is in-memory database therefore it is not fetching estimated_times without celery task worker
I am able to correctly fetch ‘progress’ (in_queue/in_progress/success) correctly but it is returning empty on estimated_times
Celery_config.py:
from celery import Celery
celery_app = Celery(
'tasks',
broker='redis://localhost:6379/0',
backend='redis://localhost:6379/0',
include=["tasks"]
)
celery_app.conf.update(
task_serializer='json',
result_serializer='json',
accept_content=['json'],
result_expires=3600, # 1 hour
timezone="UTC",
enable_utc=True,
)
Tasks.py:
from handle import {
handle_square,
handle_cube,
handle_default
}
logger = logging.getLogger(__name__)
class CommonRequest:
setting: str
number: int
@celery_app.task(bind=True, name="tasks.process_request")
def process_request(self, request: CommonRequest):
task_id = self.request.id
try:
common_request = CommonRequest(**request) # Deserialize dict into CommonRequest model
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid request format: {e}")
setting = common_request.setting
number = common_request.number
if setting == "square":
return handle_square(task_id, number)
elif setting == "cube":
return handle_cube(task_id, number)
else:
return handle_default(task_id, number)
(All handlers) handle.py:
import logging
from datetime import datetime, timezone
import time
import redis
logger = logging.getLogger(__name__)
redis_client = redis.StrictRedis(host='localhost', port=6379, db=2)
def record_execution_time(task_id, task_name, start_time, end_time):
execution_time = (end_time - start_time).total_seconds()
logger.info(f"For {task_name}, execution time - {execution_time}")
key = f'{task_name}_execution_times'
redis_client.lpush(key, execution_time)
execution_times = redis_client.lrange(key, 0, -1)
logger.info(f"execution_times: {execution_times}")
execution_times_floats = [float(x.decode('utf-8')) for x in execution_times]
logger.info(f"execution_times: {execution_times_floats}")
return execution_time
def get_average_execution_time(task_name):
key = f'{task_name}_execution_times'
execution_times = redis_client.lrange(key, 0, -1)
logger.info(f"get_avg: execution_times = {execution_times}")
if execution_times:
execution_times_floats = [float(x.decode('utf-8')) for x in execution_times]
return mean(execution_times_floats)
else:
return 0
def handle_square(task_id, number: int):
start_time = datetime.now(timezone.utc)
task_name = 'square'
progress = 'in_process'
update_task_progress(request_id, task_name, progress)
responses = number**2
# Simulating time to process gpu and cpu bound process
time.sleep(30)
end_time = datetime.now(timezone.utc)
logger.info(f"For {task_name}, end time: {end_time}")
execution_time = record_execution_time(request_id, task_name, start_time, end_time)
redis_client.hset(request_id, 'execution_time', execution_time)
progress = 'completed'
update_task_progress(request_id, task_name, progress)
get_average_execution_time(task_name)
return {"predictions": responses}
def handle_cube(task_id, number: int):
start_time = datetime.now(timezone.utc)
task_name = 'cube'
progress = 'in_process'
update_task_progress(request_id, task_name, progress)
responses = number**3
# Simulating time to process gpu and cpu bound process
time.sleep(20)
end_time = datetime.now(timezone.utc)
logger.info(f"For {task_name}, end time: {end_time}")
execution_time = record_execution_time(request_id, task_name, start_time, end_time)
redis_client.hset(request_id, 'execution_time', execution_time)
progress = 'completed'
update_task_progress(request_id, task_name, progress)
get_average_execution_time(task_name)
return {"predictions": responses}
main.py:
from celery.result import AsyncResult
from celeryconfig import celery_app
from common import (
CommonRequest,
get_average_execution_time
)
from tasks import process_request
from fastapi import FastAPI, HTTPException, status
from fastapi.responses import JSONResponse
import logging, redis
logger = logging.getLogger(__name__)
redis_client = redis.StrictRedis(host='localhost', port=6379, db=2)
app = FastAPI(title="appcustom")
def estimate_remaining_time(task_name, position_in_queue):
average_time = get_average_execution_time(task_name)
if average_time is not None:
return position_in_queue * average_time
return 'Unknown'
@app.post(/process)
def process(request: CommonRequest):
task = process_request.delay(request.serialize())
task_id = task.id
redis_client.hset(task_id, 'task_type', request.setting)
redis_client.hset(task_id, 'progress', 'in_queue')
return {'request_id': task.id}
@app.get('/{request_id}/fetch')
def get_task_result(request_id: str):
result = AsyncResult(request_id, app=celery_app)
if result.ready():
if result.successful():
execution_time = redis_client.hget(request_id, 'execution_time').decode()
task_type = redis_client.hget(request_id, 'task_type').decode()
progress = redis_client.hget(request_id, 'progress').decode()
return {
"result": result.get(),
"metadata": {
'state': result.state,
'status': progress if progress else None,
'execution_time': execution_time if execution_time else None,
'task_type': task_type if task_type else None
# 'estimated_time_remaining': estimated_time_remaining
}
}
else:
raise HTTPException(status_code=400, detail='Task failed')
else:
task_state = result.state
if task_state == 'PENDING':
queue_position = estimate_queue_position(request_id)
task_type = redis_client.hget(request_id, 'task_type')
progress = redis_client.hget(request_id, 'progress')
estimated_time_remaining = estimate_remaining_time(task_type, queue_position)
task_type = task_type.decode() if task_type else None
progress = progress.decode() if progress else None
response = {
'state': task_state,
'status': progress,
'task_type': task_type,
'queue_position': queue_position,
'estimated_time_remaining': estimated_time_remaining
}
return JSONResponse(status_code=status.HTTP_201_CREATED, content=response)
if task_state == 'STARTED':
# Estimate queue position and time remaining
queue_position = estimate_queue_position(request_id)
task_type = redis_client.hget(request_id, 'task_type')
progress = redis_client.hget(request_id, 'progress')
estimated_time_remaining = estimate_remaining_time(task_type, queue_position)
task_type = task_type.decode() if task_type else None
progress = progress.decode() if progress else None
response = {
'state': task_state,
'status': progress,
'task_type': task_type,
'queue_position': queue_position,
'estimated_time_remaining': estimated_time_remaining
}
return JSONResponse(status_code=status.HTTP_201_CREATED, content=response)
def estimate_queue_position(task_id: str) -> str:
inspect = celery_app.control.inspect()
active_tasks = inspect.active()
if active_tasks:
active_tasks = active_tasks.values()
# Flatten the list of active tasks
active_task_ids = [task['id'] for sublist in active_tasks for task in sublist]
if task_id in active_task_ids:
return str(0)
else:
registered_tasks = inspect.reserved()
if registered_tasks:
registered_tasks = registered_tasks.values()
registered_tasks_ids = [task['id'] for sublist in registered_tasks for task in sublist]
index = registered_tasks_ids.index(task_id) + 1
return str(index)
else:
return 'Unknown'
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8080)
LOGS:
In celery Logs, it is logging correctly when called at the end of process completed :
INFO 71848-8640711680 2024-07-11T14:17:48.199121-0700 common:123 - get_avg: execution_times = [b'30.007132', b'30.009898', b'30.012423', b'30.011808', b'30.019006', b'30.014377', b'30.018964', b'30.01835', b'30.01196', b'30.009163'
But in uvicorn logs:
It is not fetching estimated times
INFO 73246-13404745728 2024-07-11T14:17:30.398629-0700 common:123 - get_avg: execution_times = []
It is returning task_type and progress correctly in Fast API endpoint during polling but failing to fetch estimated_times:
Response:
Polling request a72466b3-27f3-42a8-89a2-f5f91f34ffbf - Status: 201, result: {'state': 'PENDING', 'status': 'in_queue', 'task_type': 'square', 'queue_position': '1', 'estimated_time_remaining': ''}
Polling request a72466b3-27f3-42a8-89a2-f5f91f34ffbf - Status: 201, result: {'state': 'PENDING', 'status': 'in_process', 'task_type': 'square', 'queue_position': '0', 'estimated_time_remaining': ''}
Result for request a72466b3-27f3-42a8-89a2-f5f91f34ffbf: {'result': {'predictions': [{'result': ""4}]}, 'metadata': {'state': 'SUCCESS', 'status': 'completed', 'execution_time': '30.007132', 'task_type': 'square'}}
Any error in the code or suggestion on how to improve it or use any other library/backend would be helpful
Note: These APIs are called from mobile client so didn’t add flower for monitoring celery tasks as the only requirement is to fetch in progress or in queue tasks and their historical data for time to return estimated time for each type of task