import logging
import os

import arxiv
import faiss
import numpy as np
import torch
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG."""
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    results = list(client.results(search))
    papers = [
        {"id": str(i), "text": result.summary, "title": result.title}
        for i, result in enumerate(results)
    ]
    logging.info(f"Fetched {len(papers)} papers from arXiv")
    return papers
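
# Usage sketch (illustrative; the query string below is an arbitrary example,
# not part of the pipeline):
#
#   papers = fetch_arxiv_papers("retrieval augmented generation", max_results=5)
#   papers[0]  # -> {"id": "0", "text": "<abstract>", "title": "<paper title>"}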
def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save a dataset with a FAISS index for RAG."""
    # Initialize the DPR context encoder and its tokenizer
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    # Embed the abstracts in small batches to bound memory use
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 8
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = ctx_encoder(**inputs)
        batch_embeddings = outputs.pooler_output.cpu().numpy()
        embeddings.append(batch_embeddings)
    embeddings = np.vstack(embeddings)
    logging.info(f"Created embeddings with shape {embeddings.shape}")

    # Create a datasets.Dataset holding the passages plus their embeddings
    dataset = Dataset.from_dict({
        "id": [p["id"] for p in papers],
        "text": [p["text"] for p in papers],
        "title": [p["title"] for p in papers],
        "embeddings": [emb.tolist() for emb in embeddings],
    })

    # Build an exact FAISS index over the embeddings. Note that DPR is trained
    # for inner-product similarity, so IndexFlatIP is the closer match; exact
    # L2 search also works for retrieval in practice.
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings.astype(np.float32))

    # Save the dataset and the index side by side
    os.makedirs(dataset_dir, exist_ok=True)
    dataset_path = os.path.join(dataset_dir, "dataset")
    index_path = os.path.join(dataset_dir, "embeddings.faiss")
    dataset.save_to_disk(dataset_path)
    faiss.write_index(index, index_path)
    logging.info(f"Saved dataset to {dataset_path}")
    logging.info(f"Saved index to {index_path}")
    return dataset_dir
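
# The code below is a minimal usage sketch, not part of the indexing pipeline
# above. It assumes the question encoder paired with the context encoder used
# here (facebook/dpr-question_encoder-single-nq-base); the query strings are
# illustrative placeholders.
from datasets import load_from_disk
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

def search_index(question, k=3, dataset_dir=DATASET_DIR):
    """Load the saved dataset and FAISS index and return the top-k passages."""
    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
    q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    inputs = q_tokenizer(question, max_length=512, truncation=True, return_tensors="pt")
    with torch.no_grad():
        question_embedding = q_encoder(**inputs).pooler_output.cpu().numpy().astype(np.float32)
    _, indices = index.search(question_embedding, k)
    return [dataset[int(i)] for i in indices[0]]

if __name__ == "__main__":
    papers = fetch_arxiv_papers("retrieval augmented generation", max_results=10)
    build_faiss_index(papers)
    for hit in search_index("What is retrieval augmented generation?"):
        logging.info(f"Retrieved: {hit['title']}")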