import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch

# Fetch arXiv papers
def fetch_arxiv_papers(query, max_results=10):
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate
    )
    results = list(client.results(search))
    papers = [{"title": result.title, "text": result.summary, "id": str(i)} for i, result in enumerate(results)]
    return papers

# Build and save dataset with FAISS index
def build_faiss_index(papers, dataset_dir="rag_dataset"):
    # Create dataset
    dataset = Dataset.from_dict({
        "id": [p["id"] for p in papers],
        "title": [p["title"] for p in papers],
        "text": [p["text"] for p in papers],
    })

    # Initialize DPR context encoder (same as used by RAG)
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    
    # Create embeddings
    embeddings = []
    batch_size = 8
    
    for i in range(0, len(dataset), batch_size):
        batch = dataset[i:i + batch_size]["text"]
        inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = ctx_encoder(**inputs)
            batch_embeddings = outputs.pooler_output.cpu().numpy()
            embeddings.append(batch_embeddings)
    
    embeddings = np.vstack(embeddings)
    
    # Add embeddings to dataset
    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
    
    # Create FAISS index; DPR embeddings are trained for inner-product
    # (dot-product) similarity, so an IP index matches that objective better
    # than an L2 index
    dimension = embeddings.shape[1]  # 768 for DPR base models
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings.astype(np.float32))
    
    # Save dataset and index
    os.makedirs(dataset_dir, exist_ok=True)
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))

    return dataset_dir
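
# Minimal retrieval sketch (illustrative, not part of the build pipeline above):
# embed a question with the DPR question encoder paired with the context encoder
# used above, then search the saved FAISS index directly. The question string,
# k, and directory default are assumptions. The saved files should also be
# loadable by transformers' RagRetriever (index_name="custom", passages_path=...,
# index_path=...) when full RAG generation is needed; querying FAISS directly
# keeps this demo light.
def query_faiss_index(question, dataset_dir="rag_dataset", k=3):
    from datasets import load_from_disk
    from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

    # Question encoder matching facebook/dpr-ctx_encoder-single-nq-base
    q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    inputs = q_tokenizer(question, max_length=512, truncation=True, return_tensors="pt")
    with torch.no_grad():
        # Question embeddings share the 768-d space of the context embeddings
        q_emb = q_encoder(**inputs).pooler_output.cpu().numpy().astype(np.float32)

    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))

    # search returns (scores, ids), each of shape (num_queries, k)
    scores, ids = index.search(q_emb, k)
    return [(float(score), dataset[int(idx)]["title"]) for score, idx in zip(scores[0], ids[0])]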

# Example usage
if __name__ == "__main__":
    query = "quantum computing"
    papers = fetch_arxiv_papers(query)
    build_faiss_index(papers)
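
    # Illustrative sanity check against the freshly built index; the question
    # is an arbitrary example and uses the default "rag_dataset" directory
    for score, title in query_faiss_index("What is quantum computing?"):
        print(f"{score:.3f}  {title}")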