import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")

def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG"""
    client = arxiv.Client()
    
    # Clean and prepare the search query: drop standalone 'and' tokens,
    # since arXiv treats 'and' as the AND operator rather than a search term
    terms = [term.strip() for term in query.split() if term.strip() and term.lower() != 'and']
    
    # Always include autism in the search
    if 'autism' not in [t.lower() for t in terms]:
        terms.insert(0, 'autism')
    
    # Create search query with required autism term
    autism_clause = '(abs:"autism" OR ti:"autism")'
    other_terms = " OR ".join(
        f'abs:"{term}" OR ti:"{term}"' for term in terms if term.lower() != "autism"
    )
    search_query = f'{autism_clause} AND ({other_terms})' if other_terms else autism_clause
    search_query = f'({search_query}) AND (cat:q-bio* OR cat:med*)'
    
    logging.info(f"Searching arXiv with query: {search_query}")
    
    search = arxiv.Search(
        query=search_query,
        max_results=max_results * 2,  # Get more results to filter
        sort_by=arxiv.SortCriterion.Relevance
    )
    
    try:
        results = list(client.results(search))
        papers = []
        
        for i, result in enumerate(results):
            # Only include papers that mention autism
            text = (result.title + " " + result.summary).lower()
            if 'autism' in text:
                papers.append({
                    "id": str(i),
                    "text": result.summary,
                    "title": result.title,
                    "url": result.entry_id,  # Add the paper URL
                    "published": result.published.strftime("%Y-%m-%d")  # Add publication date
                })
                if len(papers) >= max_results:
                    break
        
        logging.info(f"Found {len(papers)} relevant papers about autism from arXiv")
        return papers
    except Exception as e:
        logging.error(f"Error fetching papers from arXiv: {str(e)}")
        return []
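
# Example (sketch, not part of the original module): a call such as
#     papers = fetch_arxiv_papers("early intervention outcomes", max_results=5)
# returns a list of dicts with keys 'id', 'text', 'title', 'url', 'published';
# the query string above is purely illustrative.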

def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save dataset with FAISS index for RAG"""
    if not papers:
        logging.warning("No papers found. Creating empty dataset.")
        # Create an empty dataset with the expected structure
        dataset = Dataset.from_dict({
            "text": [],
            "embeddings": [],
            "title": []
        })
        os.makedirs(dataset_dir, exist_ok=True)
        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
        return dataset_dir

    # Initialize the DPR context encoder in float16 with low_cpu_mem_usage to reduce the memory footprint
    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    
    # Encode passages in small batches to keep peak memory low
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4  # Small batch size to limit memory use during encoding
    
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            inputs = ctx_tokenizer(
                batch_texts,
                max_length=256,  # Reduced from default
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            outputs = ctx_encoder(**inputs)
            embeddings.extend(outputs.pooler_output.cpu().numpy())
            
            # Clear memory
            del outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Convert to numpy array and build FAISS index
    embeddings = np.array(embeddings, dtype=np.float32)  # Ensure float32 type
    dimension = embeddings.shape[1]
    
    # Normalize the vectors so inner-product search is equivalent to cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / norms
    
    # Create FAISS inner-product index over the normalized embeddings
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    
    # Create and save the dataset
    dataset = Dataset.from_dict({
        "text": texts,
        "embeddings": embeddings.tolist(),  # Convert to list for storage
        "title": [p["title"] for p in papers]
    })
    
    # Create directory if it doesn't exist
    os.makedirs(dataset_dir, exist_ok=True)
    
    # Save dataset and persist the FAISS index alongside it
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    faiss.write_index(index, os.path.join(dataset_dir, "index.faiss"))
    logging.info(f"Dataset and FAISS index saved to {dataset_dir}")
    return dataset_dir
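
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): wires the two
# functions together end to end. The query string is an illustrative
# assumption; the real caller supplies whatever topic it needs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    papers = fetch_arxiv_papers("early intervention outcomes", max_results=5)
    output_dir = build_faiss_index(papers)

    # Reload the saved dataset to sanity-check what was written to disk
    reloaded = Dataset.load_from_disk(os.path.join(output_dir, "dataset"))
    logging.info(f"Reloaded dataset with {len(reloaded)} passages")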