import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")

def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG"""
    client = arxiv.Client()
    search = arxiv.Search(
        query=f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)",  # Focus on biology and medical categories
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance  # Sort results by relevance
    )
    results = list(client.results(search))
    papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]
    logging.info(f"Fetched {len(papers)} papers from arXiv")
    return papers

def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save dataset with FAISS index for RAG"""
    # Initialize DPR context encoder in fp16 to reduce memory usage
    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    
    # Create embeddings with smaller batches and memory optimization
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4  # Small batch size to limit peak memory usage
    
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            inputs = ctx_tokenizer(
                batch_texts,
                max_length=256,  # Cap sequence length to save memory
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            outputs = ctx_encoder(**inputs)
            embeddings.extend(outputs.pooler_output.cpu().numpy())
            
            # Clear memory
            del outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Convert to numpy array and build FAISS index
    embeddings = np.array(embeddings, dtype=np.float32)  # Ensure float32 type
    dimension = embeddings.shape[1]
    
    # Normalize the vectors manually
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / norms
    
    # Create FAISS index
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    
    # Create and save the dataset
    dataset = Dataset.from_dict({
        "text": texts,
        "embeddings": embeddings.tolist(),  # Convert to list for storage
        "title": [p["title"] for p in papers]
    })
    
    # Create directory if it doesn't exist
    os.makedirs(dataset_dir, exist_ok=True)
    
    # Save the FAISS index alongside the dataset so both can be reloaded later
    faiss.write_index(index, os.path.join(dataset_dir, "index.faiss"))
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    logging.info(f"Dataset saved to {dataset_dir}")
    return dataset_dir
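
# Minimal usage sketch (an assumption, not part of the original script):
# fetch a handful of papers and build the index from them. The query string
# and result count below are illustrative only.
if __name__ == "__main__":
    papers = fetch_arxiv_papers("CRISPR gene editing", max_results=10)
    if papers:
        build_faiss_index(papers)
    else:
        logging.warning("No papers fetched; skipping index build.")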