demo / backend /routes /download.py
tfrere's picture
add url importer | improve yourbench error handling | refactor
c750639
raw
history blame
6.42 kB
import io
import json
import logging
import os
import shutil
import tempfile
import zipfile

from datasets import load_dataset
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
from huggingface_hub import hf_hub_download, snapshot_download
# Router for the download endpoints below (dataset ZIP and questions JSON),
# grouped under the "download" tag.
router = APIRouter(tags=["download"])
def _zip_dataset_snapshot(repo_id: str) -> io.BytesIO:
    """Download a dataset repo snapshot and pack its files into an in-memory ZIP.

    This performs blocking network and disk I/O, so the route runs it in a
    worker thread rather than on the event loop.

    Args:
        repo_id: HuggingFace dataset repository id.

    Returns:
        A ``BytesIO`` holding the ZIP archive, positioned at offset 0.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        logging.info(f"Téléchargement du dataset {repo_id}")
        snapshot_path = snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=temp_dir,
            token=os.environ.get("HF_TOKEN"),
        )
        logging.info(f"Dataset téléchargé dans {snapshot_path}")

        zip_io = io.BytesIO()
        with zipfile.ZipFile(zip_io, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for root, dirs, files in os.walk(snapshot_path):
                if root == snapshot_path:
                    # With local_dir, recent huggingface_hub versions keep their
                    # download metadata in <local_dir>/.cache/huggingface/.
                    # It is not part of the dataset, so don't ship it.
                    dirs[:] = [d for d in dirs if d != ".cache"]
                dirs.sort()  # deterministic archive order
                for file in sorted(files):
                    file_path = os.path.join(root, file)
                    arc_name = os.path.relpath(file_path, snapshot_path)
                    zip_file.write(file_path, arcname=arc_name)

        # The archive lives in memory, so it outlives the temp dir cleanup.
        zip_io.seek(0)
        return zip_io


@router.get("/download-dataset/{session_id}")
async def download_dataset(session_id: str):
    """Download the HuggingFace dataset tied to a session and send it as a ZIP.

    Args:
        session_id: Session identifier; the dataset is read from the
            ``yourbench/yourbench_{session_id}`` dataset repository.

    Returns:
        A ``StreamingResponse`` serving ``yourbench_{session_id}_dataset.zip``.

    Raises:
        HTTPException: 500 if downloading or archiving the dataset fails.
    """
    repo_id = f"yourbench/yourbench_{session_id}"
    try:
        # snapshot_download and zip writing block; keep them off the event loop.
        zip_io = await run_in_threadpool(_zip_dataset_snapshot, repo_id)
    except Exception as e:
        logging.exception(f"Erreur lors du téléchargement du dataset: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du téléchargement du dataset: {str(e)}",
        ) from e

    filename = f"yourbench_{session_id}_dataset.zip"
    return StreamingResponse(
        zip_io,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
def _collect_session_questions(dataset_repo_id: str) -> list:
    """Load the single-shot and multi-hop questions of a dataset repository.

    Each config is loaded best-effort. Any failure (missing config, missing
    ``train`` split, network error) is logged and that config is skipped.
    This performs blocking I/O, so the route runs it in a worker thread.

    Args:
        dataset_repo_id: HuggingFace dataset repository id.

    Returns:
        A list of ``{"id", "question", "answer", "type"}`` dicts. Ids are
        sequential strings across both configs, single-shot first.
    """
    questions = []
    # (dataset config name, "type" value in the output, label for log messages)
    for config_name, question_type, label in (
        ("single_shot_questions", "single_shot", "single-shot"),
        ("multi_hop_questions", "multi_hop", "multi-hop"),
    ):
        start_idx = len(questions)
        try:
            rows = load_dataset(dataset_repo_id, config_name)["train"]
            # Iterate rows once instead of re-indexing the split per field.
            loaded = [
                {
                    "id": str(start_idx + offset),
                    "question": row.get("question", ""),
                    "answer": row.get("self_answer", "No answer available"),
                    "type": question_type,
                }
                for offset, row in enumerate(rows)
            ]
        except Exception as e:
            logging.error(f"Error loading {label} questions: {str(e)}")
            continue
        if loaded:
            questions.extend(loaded)
            logging.info(f"Loaded {len(loaded)} {label} questions")
    return questions


@router.get("/download-questions/{session_id}")
async def download_questions(session_id: str):
    """Download the questions generated for a session as a JSON file.

    Args:
        session_id: Session identifier; questions are read from the
            ``yourbench/yourbench_{session_id}`` dataset repository.

    Returns:
        A ``StreamingResponse`` serving ``yourbench_{session_id}_questions.json``
        with ``{"session_id": ..., "questions": [...]}``.

    Raises:
        HTTPException: 404 if no question could be loaded, 500 on any other
            failure.
    """
    dataset_repo_id = f"yourbench/yourbench_{session_id}"
    try:
        # load_dataset blocks; keep it off the event loop.
        all_questions = await run_in_threadpool(
            _collect_session_questions, dataset_repo_id
        )

        # No questions at all most likely means the dataset does not exist.
        if not all_questions:
            raise HTTPException(
                status_code=404, detail="Aucune question trouvée pour cette session"
            )

        questions_json = json.dumps(
            {"session_id": session_id, "questions": all_questions},
            ensure_ascii=False,
            indent=2,
        )
        filename = f"yourbench_{session_id}_questions.json"
        return StreamingResponse(
            io.BytesIO(questions_json.encode("utf-8")),
            media_type="application/json",
            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
        )
    except HTTPException:
        # Already carries the intended status code; don't re-wrap it.
        raise
    except Exception as e:
        logging.exception(f"Erreur lors de la récupération des questions: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du téléchargement des questions: {str(e)}",
        ) from e