import os
import logging

import numpy as np
import faiss
import arxiv
import torch
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")

def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG."""
    client = arxiv.Client()

    # Clean the search query: drop the standalone word "and",
    # since arXiv treats it as the AND operator.
    terms = [term.strip() for term in query.split() if term.strip() and term.lower() != "and"]

    # Always include autism in the search
    if "autism" not in [t.lower() for t in terms]:
        terms.insert(0, "autism")

    # Build the search query with the required autism term
    other_terms = [t for t in terms if t.lower() != "autism"]
    term_clause = " OR ".join(f'abs:"{t}" OR ti:"{t}"' for t in other_terms)
    search_query = '(abs:"autism" OR ti:"autism")'
    if term_clause:
        search_query = f"{search_query} AND ({term_clause})"
    search_query = f"({search_query}) AND (cat:q-bio* OR cat:med*)"
    logging.info(f"Searching arXiv with query: {search_query}")

    search = arxiv.Search(
        query=search_query,
        max_results=max_results * 2,  # Get more results to filter
        sort_by=arxiv.SortCriterion.Relevance
    )

    try:
        results = list(client.results(search))
        papers = []
        for i, result in enumerate(results):
            # Only include papers that mention autism in the title or abstract
            text = (result.title + " " + result.summary).lower()
            if "autism" in text:
                papers.append({
                    "id": str(i),
                    "text": result.summary,
                    "title": result.title,
                    "url": result.entry_id,  # Paper URL
                    "published": result.published.strftime("%Y-%m-%d")  # Publication date
                })
            if len(papers) >= max_results:
                break
        logging.info(f"Found {len(papers)} relevant papers about autism from arXiv")
        return papers
    except Exception as e:
        logging.error(f"Error fetching papers from arXiv: {str(e)}")
        return []
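
# Example usage (a minimal sketch; the query string and max_results value are illustrative):
#   papers = fetch_arxiv_papers("autism genetics", max_results=5)
#   for paper in papers:
#       print(paper["published"], paper["title"], paper["url"])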

def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save the dataset with a FAISS index for RAG."""
    if not papers:
        logging.warning("No papers found. Creating empty dataset.")
        # Create an empty dataset with the expected structure
        dataset = Dataset.from_dict({
            "text": [],
            "embeddings": [],
            "title": []
        })
        os.makedirs(dataset_dir, exist_ok=True)
        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
        return dataset_dir

    # Initialize the DPR context encoder in half precision to reduce memory usage
    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    # Create embeddings in small batches to keep memory usage low
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4  # Small batch size

    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            inputs = ctx_tokenizer(
                batch_texts,
                max_length=256,  # Reduced from the default to save memory
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            outputs = ctx_encoder(**inputs)
            embeddings.extend(outputs.pooler_output.cpu().numpy())

            # Clear memory
            del outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Convert to a float32 array, as FAISS expects
    embeddings = np.array(embeddings, dtype=np.float32)
    dimension = embeddings.shape[1]

    # L2-normalize the vectors so that inner product equals cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / norms

    # Build a flat inner-product FAISS index over the normalized embeddings
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)

    # Create the dataset
    dataset = Dataset.from_dict({
        "text": texts,
        "embeddings": embeddings.tolist(),  # Convert to lists for storage
        "title": [p["title"] for p in papers]
    })

    # Create the output directory if it doesn't exist
    os.makedirs(dataset_dir, exist_ok=True)

    # Persist the FAISS index alongside the dataset so it does not have to be
    # rebuilt from scratch (the "index.faiss" filename is a choice made here)
    faiss.write_index(index, os.path.join(dataset_dir, "index.faiss"))

    # Save the dataset
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    logging.info(f"Dataset saved to {dataset_dir}")
    return dataset_dir
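
if __name__ == "__main__":
    # Minimal end-to-end sketch. The query string, the number of results, and the
    # DPR question-encoder checkpoint below are illustrative choices, not values
    # required anywhere else in this script.
    papers = fetch_arxiv_papers("autism genetics biomarkers", max_results=10)
    saved_dir = build_faiss_index(papers)

    # Reload the saved dataset, attach a FAISS index built from the stored
    # embeddings column, and run one sample retrieval with a DPR question encoder.
    from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

    dataset = Dataset.load_from_disk(os.path.join(saved_dir, "dataset"))
    if len(dataset) > 0:
        # Use an inner-product index to match the normalized document embeddings
        dataset.add_faiss_index(column="embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)

        q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

        question = "What genetic factors are associated with autism?"
        with torch.inference_mode():
            q_inputs = q_tokenizer(question, return_tensors="pt")
            q_emb = q_encoder(**q_inputs).pooler_output.cpu().numpy()[0].astype(np.float32)
        # Normalize the query vector so scores are cosine similarities
        q_emb = q_emb / np.linalg.norm(q_emb)

        scores, retrieved = dataset.get_nearest_examples("embeddings", q_emb, k=3)
        for score, title in zip(scores, retrieved["title"]):
            logging.info(f"score={score:.3f}  {title}")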