import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
# Fetch arXiv papers
def fetch_arxiv_papers(query, max_results=10):
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    results = list(client.results(search))
    papers = [{"title": result.title, "text": result.summary, "id": str(i)} for i, result in enumerate(results)]
    return papers
# Build and save dataset with FAISS index
def build_faiss_index(papers, dataset_dir="rag_dataset"):
    # Create dataset
    dataset = Dataset.from_dict({
        "id": [p["id"] for p in papers],
        "title": [p["title"] for p in papers],
        "text": [p["text"] for p in papers],
    })

    # Initialize DPR context encoder (same as used by RAG)
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    # Create embeddings in small batches
    embeddings = []
    batch_size = 8
    for i in range(0, len(dataset), batch_size):
        batch = dataset[i:i + batch_size]["text"]
        inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = ctx_encoder(**inputs)
        batch_embeddings = outputs.pooler_output.cpu().numpy()
        embeddings.append(batch_embeddings)
    embeddings = np.vstack(embeddings)

    # Add embeddings to dataset
    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])

    # Create FAISS index
    dimension = embeddings.shape[1]  # Should be 768 for DPR
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings.astype(np.float32))

    # Save dataset and index
    os.makedirs(dataset_dir, exist_ok=True)
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))

    return dataset_dir
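
# Retrieval sketch (not part of the original script): a minimal example of
# querying the saved index, assuming the paths written by build_faiss_index
# above. It uses the DPR *question* encoder that pairs with the context
# encoder; the helper name search_papers is illustrative.
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

def search_papers(query, dataset_dir="rag_dataset", k=3):
    # Reload the passages and attach the FAISS index to the "embeddings" column
    dataset = Dataset.load_from_disk(os.path.join(dataset_dir, "dataset"))
    dataset.load_faiss_index("embeddings", os.path.join(dataset_dir, "embeddings.faiss"))
    # Embed the query with the matching DPR question encoder
    q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    inputs = q_tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        q_emb = q_encoder(**inputs).pooler_output[0].cpu().numpy().astype(np.float32)
    # Nearest-neighbor lookup over the indexed embeddings
    scores, examples = dataset.get_nearest_examples("embeddings", q_emb, k=k)
    return list(zip(scores, examples["title"]))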
# Example usage
if __name__ == "__main__":
    query = "quantum computing"
    papers = fetch_arxiv_papers(query)
    build_faiss_index(papers)
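
# Hedged sketch: the saved dataset and index can plausibly be plugged into
# transformers' RagRetriever via its custom-index mode (index_name="custom").
# The paths match what build_faiss_index writes; whether an index file written
# by faiss.write_index is accepted as index_path should be verified against
# your transformers version.
def load_rag_retriever(dataset_dir="rag_dataset"):
    from transformers import RagRetriever
    return RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        passages_path=os.path.join(dataset_dir, "dataset"),
        index_path=os.path.join(dataset_dir, "embeddings.faiss"),
    )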