# Build a DPR-embedded, FAISS-indexed dataset of recent arXiv abstracts.
import os

import arxiv
import faiss
import numpy as np
import torch
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
# Fetch arXiv papers
def fetch_arxiv_papers(query, max_results=10):
    """Fetch recent arXiv papers matching *query*.

    Args:
        query: arXiv search query string (e.g. "quantum computing").
        max_results: maximum number of results to request from the API.

    Returns:
        A list of dicts with keys "title", "text" (the paper abstract),
        and "id" (a stringified sequential index).
    """
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    # Materialize the lazy result iterator so it can be enumerated.
    results = list(client.results(search))
    return [
        {"title": result.title, "text": result.summary, "id": str(i)}
        for i, result in enumerate(results)
    ]
# Build and save dataset with FAISS index
def build_faiss_index(papers, dataset_dir="rag_dataset"):
    """Embed papers with DPR and persist a dataset plus FAISS index.

    Args:
        papers: non-empty list of dicts with "id", "title", and "text"
            keys (as produced by ``fetch_arxiv_papers``).
        dataset_dir: output directory; created if it does not exist.

    Returns:
        ``dataset_dir``, which then contains ``dataset/`` (the saved
        ``datasets.Dataset``) and ``embeddings.faiss`` (the index).

    Raises:
        ValueError: if ``papers`` is empty (np.vstack would otherwise
            fail with an opaque error).
    """
    if not papers:
        raise ValueError("papers must be a non-empty list")
    # Create dataset
    dataset = Dataset.from_dict({
        "id": [p["id"] for p in papers],
        "title": [p["title"] for p in papers],
        "text": [p["text"] for p in papers],
    })
    # Initialize DPR context encoder (same as used by RAG)
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    # Encode abstracts in small batches to bound peak memory.
    embeddings = []
    batch_size = 8
    for i in range(0, len(dataset), batch_size):
        batch = dataset[i:i + batch_size]["text"]
        inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            outputs = ctx_encoder(**inputs)
        embeddings.append(outputs.pooler_output.cpu().numpy())
    embeddings = np.vstack(embeddings)
    # Add embeddings to dataset
    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
    # DPR is trained with dot-product similarity, so an inner-product
    # index matches the retriever's scoring; IndexFlatL2 would rank by
    # Euclidean distance instead and return different neighbors.
    dimension = embeddings.shape[1]  # 768 for the DPR base encoders
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings.astype(np.float32))
    # Save dataset and index
    os.makedirs(dataset_dir, exist_ok=True)
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))
    return dataset_dir
# Example usage: fetch recent papers on a topic and index them locally.
if __name__ == "__main__":
    query = "quantum computing"
    papers = fetch_arxiv_papers(query)
    build_faiss_index(papers)