Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from fastapi import APIRouter, HTTPException | |
from fastapi.responses import StreamingResponse | |
from huggingface_hub import hf_hub_download, snapshot_download | |
import os | |
import tempfile | |
import shutil | |
import zipfile | |
import io | |
import logging | |
router = APIRouter(tags=["download"]) | |
async def download_dataset(session_id: str): | |
""" | |
Télécharge le dataset HuggingFace associé à une session et le renvoie au client | |
Args: | |
session_id: Identifiant de la session | |
Returns: | |
Fichier ZIP contenant le dataset | |
""" | |
try: | |
# Créer un répertoire temporaire pour stocker les fichiers du dataset | |
with tempfile.TemporaryDirectory() as temp_dir: | |
# Identifiant du repo HuggingFace | |
repo_id = f"yourbench/yourbench_{session_id}" | |
try: | |
# Télécharger le snapshot du dataset depuis HuggingFace | |
logging.info(f"Téléchargement du dataset {repo_id}") | |
snapshot_path = snapshot_download( | |
repo_id=repo_id, | |
repo_type="dataset", | |
local_dir=temp_dir, | |
token=os.environ.get("HF_TOKEN") | |
) | |
logging.info(f"Dataset téléchargé dans {snapshot_path}") | |
# Créer un fichier ZIP en mémoire | |
zip_io = io.BytesIO() | |
with zipfile.ZipFile(zip_io, 'w', zipfile.ZIP_DEFLATED) as zip_file: | |
# Parcourir tous les fichiers du dataset et les ajouter au ZIP | |
for root, _, files in os.walk(snapshot_path): | |
for file in files: | |
file_path = os.path.join(root, file) | |
arc_name = os.path.relpath(file_path, snapshot_path) | |
zip_file.write(file_path, arcname=arc_name) | |
# Remettre le curseur au début du stream | |
zip_io.seek(0) | |
# Renvoyer le ZIP au client | |
filename = f"yourbench_{session_id}_dataset.zip" | |
return StreamingResponse( | |
zip_io, | |
media_type="application/zip", | |
headers={"Content-Disposition": f"attachment; filename={filename}"} | |
) | |
except Exception as e: | |
logging.error(f"Erreur lors du téléchargement du dataset: {str(e)}") | |
raise HTTPException( | |
status_code=500, | |
detail=f"Erreur lors du téléchargement du dataset: {str(e)}" | |
) | |
except Exception as e: | |
logging.error(f"Erreur générale: {str(e)}") | |
raise HTTPException( | |
status_code=500, | |
detail=f"Erreur lors du téléchargement: {str(e)}" | |
) |