demo / backend /routes /download.py
tfrere's picture
update question download format
e64aebd
raw
history blame
6.29 kB
import io
import json
import logging
import os
import shutil
import tempfile
import zipfile

from datasets import load_dataset
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
from huggingface_hub import hf_hub_download, snapshot_download
# Router holding the download endpoints defined below; every route registered
# on it is tagged "download" (presumably included into the app elsewhere — verify).
router = APIRouter(tags=["download"])
def _build_dataset_zip(repo_id):
    """Download the dataset repo ``repo_id`` and return its files as an in-memory ZIP.

    This is synchronous (network + disk I/O), so callers in async code should
    run it in a worker thread. The temporary download directory is deleted
    before returning; only the returned buffer survives.

    Args:
        repo_id: HuggingFace dataset repository identifier.

    Returns:
        io.BytesIO positioned at offset 0, containing the ZIP archive.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        logging.info(f"Downloading dataset {repo_id}")
        snapshot_path = snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=temp_dir,
            token=os.environ.get("HF_TOKEN"),
        )
        logging.info(f"Dataset downloaded to {snapshot_path}")

        zip_io = io.BytesIO()
        with zipfile.ZipFile(zip_io, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            for root, dirs, files in os.walk(snapshot_path):
                if root == snapshot_path:
                    # snapshot_download(local_dir=...) keeps its download
                    # metadata in <local_dir>/.cache/huggingface; it is not
                    # part of the dataset, so prune it from the walk.
                    dirs[:] = [d for d in dirs if d != ".cache"]
                dirs.sort()  # deterministic archive order
                for file in sorted(files):
                    file_path = os.path.join(root, file)
                    arc_name = os.path.relpath(file_path, snapshot_path)
                    zip_file.write(file_path, arcname=arc_name)

    # Rewind so the response streams the archive from the start.
    zip_io.seek(0)
    return zip_io


@router.get("/download-dataset/{session_id}")
async def download_dataset(session_id: str):
    """
    Downloads the HuggingFace dataset associated with a session and returns it to the client

    Args:
        session_id: Session identifier

    Returns:
        ZIP file containing the dataset

    Raises:
        HTTPException: 500 if downloading or archiving the dataset fails.
    """
    # HuggingFace repo identifier
    repo_id = f"yourbench/yourbench_{session_id}"
    try:
        # Download + zip are blocking; keep them off the event loop.
        zip_io = await run_in_threadpool(_build_dataset_zip, repo_id)
    except Exception as e:
        logging.exception(f"Error while downloading the dataset: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error while downloading the dataset: {str(e)}",
        ) from e

    filename = f"yourbench_{session_id}_dataset.zip"
    return StreamingResponse(
        zip_io,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
# (config name in the session dataset, "type" value in the output, label used in logs)
_QUESTION_CONFIGS = (
    ("single_shot_questions", "single_shot", "single-shot"),
    ("multi_hop_questions", "multi_hop", "multi-hop"),
)


def _load_question_rows(dataset_repo_id, config_name, question_type, start_idx):
    """Load the 'train' split of one question config and convert rows to output dicts.

    Ids are consecutive string integers starting at ``start_idx``.
    Raises whatever ``load_dataset`` / split lookup raises (e.g. missing config).
    """
    train_split = load_dataset(dataset_repo_id, config_name)['train']
    return [
        {
            "id": str(start_idx + offset),
            "question": row.get("question", ""),
            "answer": row.get("self_answer", "No answer available"),
            "type": question_type,
        }
        for offset, row in enumerate(train_split)
    ]


def _collect_session_questions(dataset_repo_id):
    """Gather single-shot then multi-hop questions for a session (blocking I/O).

    Each config is loaded best-effort: a config that fails to load is logged
    and skipped, so a session with only one question type still works.
    Ids continue across configs in load order.
    """
    all_questions = []
    for config_name, question_type, label in _QUESTION_CONFIGS:
        try:
            rows = _load_question_rows(
                dataset_repo_id, config_name, question_type, len(all_questions)
            )
        except Exception as e:
            # Deliberate best-effort: the config may simply not exist.
            logging.error(f"Error loading {label} questions: {str(e)}")
            continue
        if rows:
            all_questions.extend(rows)
            logging.info(f"Loaded {len(rows)} {label} questions")
    return all_questions


@router.get("/download-questions/{session_id}")
async def download_questions(session_id: str):
    """
    Downloads the questions generated for a session in JSON format

    Args:
        session_id: Session identifier

    Returns:
        JSON file containing only the list of generated questions; each item has
        "id", "question", "answer" and "type" ("single_shot" or "multi_hop").

    Raises:
        HTTPException: 404 if no questions could be loaded, 500 on other errors.
    """
    try:
        # HuggingFace repo identifier
        dataset_repo_id = f"yourbench/yourbench_{session_id}"

        # load_dataset is blocking; keep it off the event loop.
        all_questions = await run_in_threadpool(
            _collect_session_questions, dataset_repo_id
        )

        # If we couldn't load any questions, the dataset might not exist
        if not all_questions:
            raise HTTPException(status_code=404, detail="No questions found for this session")

        # Serialize only the bare list of questions (no wrapping object).
        questions_json = json.dumps(all_questions, ensure_ascii=False, indent=2)

        filename = f"yourbench_{session_id}_questions.json"
        return StreamingResponse(
            io.BytesIO(questions_json.encode('utf-8')),
            media_type="application/json",
            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
        )
    except HTTPException:
        # Re-raise HTTP exceptions (e.g. the 404 above) unchanged
        raise
    except Exception as e:
        logging.exception(f"Error retrieving questions: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error downloading questions: {str(e)}",
        ) from e