# NOTE(review): the lines below were HF Spaces page chrome ("Spaces: Running on
# CPU Upgrade") captured by the scraper — not part of the module source.
# Standard library
import io
import json
import logging
import os
import shutil
import tempfile
import zipfile

# Third-party
from datasets import load_dataset
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download, snapshot_download

# Router carrying the download endpoints for generated datasets/questions.
router = APIRouter(tags=["download"])
async def download_dataset(session_id: str):
    """
    Download the HuggingFace dataset associated with a session and return it
    to the client as a ZIP archive.

    Args:
        session_id: Session identifier; maps to the HF dataset repo
            ``yourbench/yourbench_{session_id}``.

    Returns:
        StreamingResponse: a ZIP file containing the full dataset snapshot.

    Raises:
        HTTPException: 500 when the snapshot download or archiving fails.
    """
    try:
        # Temporary directory for the snapshot; removed automatically on exit.
        with tempfile.TemporaryDirectory() as temp_dir:
            # HuggingFace repo identifier for this session.
            repo_id = f"yourbench/yourbench_{session_id}"
            try:
                # Pull the full dataset snapshot from the Hub.
                logging.info(f"Téléchargement du dataset {repo_id}")
                snapshot_path = snapshot_download(
                    repo_id=repo_id,
                    repo_type="dataset",
                    local_dir=temp_dir,
                    token=os.environ.get("HF_TOKEN"),
                )
                logging.info(f"Dataset téléchargé dans {snapshot_path}")

                # Build the ZIP entirely in memory (no second temp file).
                zip_io = io.BytesIO()
                with zipfile.ZipFile(zip_io, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                    # Walk the snapshot and add every file to the archive.
                    for root, _, files in os.walk(snapshot_path):
                        for file in files:
                            file_path = os.path.join(root, file)
                            # Archive names are relative to the snapshot root
                            # so the ZIP extracts without absolute paths.
                            arc_name = os.path.relpath(file_path, snapshot_path)
                            zip_file.write(file_path, arcname=arc_name)

                # Rewind so StreamingResponse reads from the start.
                zip_io.seek(0)

                filename = f"yourbench_{session_id}_dataset.zip"
                # Fix: interpolate the computed filename into the header
                # (it was previously dropped, leaving a broken literal).
                return StreamingResponse(
                    zip_io,
                    media_type="application/zip",
                    headers={"Content-Disposition": f'attachment; filename="{filename}"'},
                )
            except Exception as e:
                logging.error(f"Erreur lors du téléchargement du dataset: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Erreur lors du téléchargement du dataset: {str(e)}"
                )
    except HTTPException:
        # Fix: let the inner HTTPException propagate unchanged instead of
        # being re-wrapped (and double-prefixed) by the generic handler below.
        raise
    except Exception as e:
        logging.error(f"Erreur générale: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du téléchargement: {str(e)}"
        )
def _load_question_split(dataset_repo_id: str, config_name: str,
                         question_type: str, start_idx: int) -> list:
    """
    Load one question config ("split") from the HF dataset.

    Args:
        dataset_repo_id: HuggingFace dataset repo id.
        config_name: Dataset config to load (e.g. 'single_shot_questions').
        question_type: Value stored in each question's "type" field.
        start_idx: Offset for the sequential string ids.

    Returns:
        A list of question dicts; empty when the config is missing or fails
        to load (best-effort — the error is logged, not raised).
    """
    questions = []
    try:
        dataset = load_dataset(dataset_repo_id, config_name)
        if dataset and len(dataset['train']) > 0:
            for idx in range(len(dataset['train'])):
                row = dataset['train'][idx]
                questions.append({
                    "id": str(start_idx + idx),
                    "question": row.get("question", ""),
                    "answer": row.get("self_answer", "No answer available"),
                    "type": question_type,
                })
            logging.info(f"Loaded {len(questions)} {question_type} questions")
    except Exception as e:
        logging.error(f"Error loading {question_type} questions: {str(e)}")
    return questions


async def download_questions(session_id: str):
    """
    Download the questions generated for a session as a JSON file.

    Args:
        session_id: Session identifier; maps to the HF dataset repo
            ``yourbench/yourbench_{session_id}``.

    Returns:
        StreamingResponse: a JSON attachment with the session's questions.

    Raises:
        HTTPException: 404 when no questions exist, 500 on other failures.
    """
    try:
        # HuggingFace repo identifier for this session.
        dataset_repo_id = f"yourbench/yourbench_{session_id}"

        # Best-effort load of both splits; each returns [] on failure.
        all_questions = _load_question_split(
            dataset_repo_id, 'single_shot_questions', 'single_shot', 0
        )
        all_questions += _load_question_split(
            dataset_repo_id, 'multi_hop_questions', 'multi_hop', len(all_questions)
        )

        # Neither split loaded — the dataset likely does not exist.
        if len(all_questions) == 0:
            raise HTTPException(status_code=404, detail="Aucune question trouvée pour cette session")

        # Serialize; ensure_ascii=False keeps accented text readable.
        questions_json = json.dumps({
            "session_id": session_id,
            "questions": all_questions
        }, ensure_ascii=False, indent=2)

        json_bytes = io.BytesIO(questions_json.encode('utf-8'))
        json_bytes.seek(0)

        filename = f"yourbench_{session_id}_questions.json"
        # Fix: interpolate the computed filename into the header
        # (it was previously dropped, leaving a broken literal).
        return StreamingResponse(
            json_bytes,
            media_type="application/json",
            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
        )
    except HTTPException:
        # Re-raise HTTP exceptions (e.g. the 404 above) untouched.
        raise
    except Exception as e:
        logging.error(f"Erreur lors de la récupération des questions: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors du téléchargement des questions: {str(e)}"
        )