import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")

def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG"""
    client = arxiv.Client()
    search = arxiv.Search(
        query=f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)",  # Focus on biology and medical categories
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance  # Rank results by relevance to the query
    )
    results = list(client.results(search))
    papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]
    logging.info(f"Fetched {len(papers)} papers from arXiv")
    return papers

def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save dataset with FAISS index for RAG"""
    # Initialize DPR context encoder (loaded in 8-bit to reduce memory usage)
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", device_map="auto", load_in_8bit=True)
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    
    # Create embeddings in small batches to limit peak memory usage
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4  # Small batches keep memory usage low

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = ctx_tokenizer(batch, max_length=256, padding=True, truncation=True, return_tensors="pt")  # Cap sequence length to save memory
        inputs = inputs.to(ctx_encoder.device)  # Move the batch to the encoder's device
        with torch.no_grad():
            outputs = ctx_encoder(**inputs)
            batch_embeddings = outputs.pooler_output.cpu().numpy()
            embeddings.append(batch_embeddings)
            del outputs  # Explicit cleanup
            torch.cuda.empty_cache()  # Release cached GPU memory (no-op on CPU)
    
    embeddings = np.vstack(embeddings)
    logging.info(f"Created embeddings with shape {embeddings.shape}")
    
    # Create dataset
    dataset = Dataset.from_dict({
        "id": [p["id"] for p in papers],
        "text": [p["text"] for p in papers],
        "title": [p["title"] for p in papers],
        "embeddings": [emb.tolist() for emb in embeddings],
    })
    
    # Create FAISS index (exact L2 search; with only a handful of papers,
    # no trained/quantized index is needed)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings.astype(np.float32))
    
    # Save dataset and index
    os.makedirs(dataset_dir, exist_ok=True)
    dataset_path = os.path.join(dataset_dir, "dataset")
    index_path = os.path.join(dataset_dir, "embeddings.faiss")
    dataset.save_to_disk(dataset_path)
    faiss.write_index(index, index_path)
    logging.info(f"Saved dataset to {dataset_path}")
    logging.info(f"Saved index to {index_path}")
    return dataset_dir
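
# Example usage (a minimal sketch; the query string below is illustrative and
# not part of the original script):
if __name__ == "__main__":
    papers = fetch_arxiv_papers("protein folding", max_results=10)
    dataset_dir = build_faiss_index(papers)

    # The saved artifacts can later be reloaded for retrieval, e.g.:
    #   dataset = Dataset.load_from_disk(os.path.join(dataset_dir, "dataset"))
    #   dataset.load_faiss_index("embeddings", os.path.join(dataset_dir, "embeddings.faiss"))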