Task: Upload multiple files coming from UI clients ( NextJS ) using FastAPI POST method. There are two parts to it: Validating each file and dumping it in s3. Below is the code which i am trying to make work ( For simplicity, this one includes a simple html file and dumping onto disk rather than s3 ). I have left comments and print statements as is
Problems and what’s needed:
- The code is quite dependent on
asyncio.sleep(value)
. If I change the value, sometimes the upload file method goes really fast and tries to upload bunch of files and then websocket sends their progress. What I want is that each progress update ( % completion ) should be sent to client i.e.one by one and should be consistent. - How to use
file_len
andcnt
variables inProgressTracker
Class itself. I tried using them in__init__
method but was getting some inconsistent results. - Also, during run vs debug modes in pycharm, the results are not always consistent. Not able to narrow down the reason behind that ( might by sleep methods ? )
main.py
from fastapi import FastAPI, UploadFile, WebSocket, WebSocketDisconnect, Form, Request, Depends
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from typing import List, Dict
import os
import asyncio
import json
app = FastAPI()
# Initialize Jinja2 templates
templates = Jinja2Templates(directory="templates")
class ProgressTracker:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(ProgressTracker, cls).__new__(cls)
cls._instance.client_progress = {}
return cls._instance
def set_progress(self, client_id: str, file_name: str, progress_value: int):
if client_id not in self.client_progress:
self.client_progress[client_id] = {}
self.client_progress[client_id][file_name] = progress_value
def get_progress(self, client_id: str):
return self.client_progress.get(client_id, {})
def clear_progress(self, client_id: str):
if client_id in self.client_progress:
del self.client_progress[client_id]
# Dependency injection function to provide ProgressTracker
def get_progress_tracker():
return ProgressTracker()
# Serve the upload page
@app.get("/", response_class=HTMLResponse)
async def get(request: Request):
return templates.TemplateResponse("upload.html", {"request": request})
# WebSocket endpoint for progress updates
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str, tracker: ProgressTracker = Depends(get_progress_tracker)):
global cnt
cnt = 0
await websocket.accept()
try:
while True:
client_progress = tracker.get_progress(client_id)
# if not client_progress:
# # If there's no progress yet, sleep and wait
# await asyncio.sleep(0.1)
# continue
all_done = True
# print('--')
for file_name, prog in client_progress.items():
if prog < 100:
await websocket.send_text(json.dumps({"file": file_name, "progress": prog}))
all_done = False
# Ensure the 100% update is sent
if prog == 100:
cnt += 1
await websocket.send_text(json.dumps({"file": file_name, "progress": 100}))
# del client_progress[file_name]
# print('here..')
print(tracker.get_progress(client_id))
tracker.clear_progress(client_id)
# all_done = False
tracker.get_progress(client_id)
print('--------------------------------')
if all_done and client_progress and cnt >= file_len:
break
await asyncio.sleep(0.1)
except WebSocketDisconnect:
print(f"Client {client_id} disconnected")
finally:
tracker.clear_progress(client_id)
# Endpoint to handle multiple file uploads with validation
@app.post("/uploadfiles/")
async def upload_files(
client_id: str = Form(...),
files: List[UploadFile] = Form(...),
tracker: ProgressTracker = Depends(get_progress_tracker)
):
global file_len
file_len=len(files)
print(file_len)
try:
for file in files:
file_name = file.filename
print(file_name + "---------")
tracker.set_progress(client_id, file_name, 0)
# Step 1: File validation (50% progress)
await asyncio.sleep(2) # Simulate file validation
tracker.set_progress(client_id, file_name, 50) # Validation completed
# Step 2: File upload (50% progress)
filepath = f"uploads/{file_name}"
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, "wb") as f:
chunk_size = 1024 * 1024 # 1MB chunk size
total_size = file.file._file.seek(0, os.SEEK_END)
file.file._file.seek(0)
uploaded_size = 0
# for chunk in iter(lambda: file.file.read(chunk_size), b''):
# f.write(chunk)
# uploaded_size += len(chunk)
# tracker.set_progress(client_id, file_name, 50 + int((uploaded_size / total_size) * 50))
# Read and write file in chunks
while True:
chunk = await file.read(chunk_size)
if not chunk:
break
f.write(chunk)
uploaded_size += len(chunk)
# Update progress based on the chunk size
tracker.set_progress(client_id, file_name, 50 + int((uploaded_size / total_size) * 50))
# Yield control so the WebSocket can send the update
await asyncio.sleep(0.1)
# tracker.set_progress(client_id, file_name, 100) # Ensure 100% at the end
return {"filenames": [file.filename for file in files]}
except Exception as e:
for file in files:
tracker.set_progress(client_id, file.filename, 0)
return {"error": str(e)}
if __name__ == "__main__":
host = "127.0.0.1"
port = 8000
uvicorn.run(app, host=host, port=port)
UI – upload.html // in templates folder in same dir
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>File Upload with WebSocket Progress</title>
<style>
.progress {
border: 1px solid #000;
width: 100%;
height: 24px;
background-color: #f3f3f3;
position: relative;
margin-bottom: 10px;
}
.progress-bar {
height: 100%;
background-color: #4caf50;
text-align: center;
color: white;
white-space: nowrap;
overflow: hidden;
}
</style>
</head>
<body>
<h1>Upload Files</h1>
<form id="uploadForm" enctype="multipart/form-data">
<input type="file" id="files" name="files" multiple>
<button type="button" onclick="startUpload()">Upload</button>
</form>
<div id="progressDisplay"></div>
<script>
function startUpload() {
const clientId = Math.random().toString(36).substring(7);
const ws = new WebSocket(`ws://localhost:8000/ws/${clientId}`);
console.log(ws);
const progressDisplay = document.getElementById('progressDisplay');
progressDisplay.innerHTML = '';
// Track the number of files to ensure the WebSocket only closes after all files are processed.
let totalFiles = document.getElementById('files').files.length;
let processedFiles = 0;
// Wait for the WebSocket connection to open
ws.onopen = function() {
console.log("WebSocket connection opened");
const form = document.getElementById("uploadForm");
const formData = new FormData(form);
formData.append("client_id", clientId);
// Send files via POST request to the backend
fetch("/uploadfiles/", {
method: "POST",
body: formData
})
.then(response => response.json())
.then(data => {
if (data.error) {
console.error("Error:", data.error);
} else {
console.log("Files uploaded:", data.filenames);
}
})
.catch(error => console.error("Upload failed:", error));
};
// Handle WebSocket messages to update progress for multiple files
ws.onmessage = function(event) {
const data = JSON.parse(event.data);
const fileName = data.file;
const progressValue = data.progress;
let progressElem = document.getElementById(fileName);
// If progress element doesn't exist, create a new one
if (!progressElem) {
const container = document.createElement('div');
container.id = fileName;
container.innerHTML = `
<p>${fileName}:</p>
<div class="progress">
<div class="progress-bar" style="width: 0%;">0%</div>
</div>`;
progressDisplay.appendChild(container);
progressElem = document.getElementById(fileName);
}
const progressBar = progressElem.querySelector('.progress-bar');
progressBar.style.width = `${progressValue}%`;
progressBar.innerHTML = `${progressValue}%`;
// Stop updating after reaching 100% and track the number of processed files
if (progressValue >= 100) {
processedFiles++;
setTimeout(() => {
progressElem.remove();
}, 2000);
// Close the WebSocket only after all files are processed
if (processedFiles >= totalFiles) {
ws.close();
}
}
};
ws.onclose = function() {
console.log("WebSocket connection closed");
};
ws.onerror = function(error) {
console.error("WebSocket error:", error);
};
}
</script>
</body>
</html>