import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
import logging
import requests
from datetime import datetime
import xml.etree.ElementTree as ET
from time import sleep

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")

def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG"""
    client = arxiv.Client()
    
    # Clean the search query: drop the standalone connector 'and' and split into terms
    terms = [term.strip() for term in query.split() if term.strip() and term.strip().lower() != 'and']
    
    # Always include autism in the search
    if 'autism' not in [t.lower() for t in terms]:
        terms.insert(0, 'autism')
    
    # Create search query with required autism term
    term_queries = []
    for term in terms:
        if term.lower() != "autism":
            term_queries.append(f'abs:"{term}" OR ti:"{term}"')
    
    search_query = '(abs:"autism" OR ti:"autism")'
    if term_queries:
        search_query += f' AND ({" OR ".join(term_queries)})'
    search_query = f'({search_query}) AND (cat:q-bio* OR cat:med*)'
    
    logging.info(f"Searching arXiv with query: {search_query}")
    
    search = arxiv.Search(
        query=search_query,
        max_results=max_results * 2,
        sort_by=arxiv.SortCriterion.Relevance
    )
    
    try:
        results = list(client.results(search))
        papers = []
        
        for i, result in enumerate(results):
            # Only include papers that mention autism
            text = (result.title + " " + result.summary).lower()
            if 'autism' in text:
                papers.append({
                    "id": f"arxiv_{i}",
                    "text": result.summary,
                    "title": result.title,
                    "url": result.entry_id,
                    "published": result.published.strftime("%Y-%m-%d")
                })
                if len(papers) >= max_results:
                    break
        
        logging.info(f"Found {len(papers)} relevant papers about autism from arXiv")
        return papers
    except Exception as e:
        logging.error(f"Error fetching papers from arXiv: {str(e)}")
        return []

def fetch_pubmed_papers(query, max_results=10):
    """Fetch papers from PubMed using E-utilities"""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    
    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': f"{query} AND autism",
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }
    
    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params, timeout=30)
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
        
        if not id_list:
            return []
        
        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        
        response = requests.get(fetch_url, params=fetch_params, timeout=30)
        root = ET.fromstring(response.content)
        papers = []
        
        for article in root.findall('.//PubmedArticle'):
            try:
                # Extract article information, guarding against missing elements
                title_elem = article.find('.//ArticleTitle')
                title = title_elem.text if title_elem is not None and title_elem.text else ""
                abstract_elem = article.find('.//Abstract/AbstractText')
                abstract = abstract_elem.text if abstract_elem is not None and abstract_elem.text else ""

                if 'autism' in (title + abstract).lower():
                    pmid = article.find('.//PMID').text
                    date = article.find('.//PubDate')
                    year = date.find('Year').text if date is not None and date.find('Year') is not None else "Unknown"
                    
                    papers.append({
                        "id": f"pubmed_{pmid}",
                        "text": abstract,
                        "title": title,
                        "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                        "published": year
                    })
            except Exception as e:
                logging.warning(f"Error processing PubMed article: {str(e)}")
                continue
        
        logging.info(f"Found {len(papers)} relevant papers from PubMed")
        return papers
        
    except Exception as e:
        logging.error(f"Error fetching papers from PubMed: {str(e)}")
        return []

def fetch_papers(query, max_results=10):
    """Fetch papers from both arXiv and PubMed"""
    arxiv_papers = fetch_arxiv_papers(query, max_results=max_results)
    sleep(1)  # Respect rate limits
    pubmed_papers = fetch_pubmed_papers(query, max_results=max_results)
    
    # Combine and deduplicate papers based on title similarity
    all_papers = arxiv_papers + pubmed_papers
    unique_papers = []
    seen_titles = set()
    
    for paper in all_papers:
        title_lower = paper['title'].lower()
        if not any(title_lower in seen_title or seen_title in title_lower for seen_title in seen_titles):
            unique_papers.append(paper)
            seen_titles.add(title_lower)
    
    # Sort by relevance (papers with 'autism' in title first)
    unique_papers.sort(key=lambda x: 'autism' in x['title'].lower(), reverse=True)
    return unique_papers[:max_results]

def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save dataset with FAISS index for RAG"""
    if not papers:
        logging.warning("No papers found. Creating empty dataset.")
        # Create an empty dataset with the expected structure
        dataset = Dataset.from_dict({
            "text": [],
            "embeddings": [],
            "title": []
        })
        os.makedirs(dataset_dir, exist_ok=True)
        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
        return dataset_dir

    # Load the DPR context encoder in half precision to keep memory usage low
    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    
    # Create embeddings with smaller batches and memory optimization
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4  # Smaller batch size
    
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            inputs = ctx_tokenizer(
                batch_texts,
                max_length=256,  # Reduced from default
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            outputs = ctx_encoder(**inputs)
            embeddings.extend(outputs.pooler_output.cpu().numpy())
            
            # Clear memory
            del outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Convert to numpy array and build FAISS index
    embeddings = np.array(embeddings, dtype=np.float32)  # Ensure float32 type
    dimension = embeddings.shape[1]
    
    # Normalize the vectors manually
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / norms
    
    # Create FAISS index
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    
    # Create and save the dataset
    dataset = Dataset.from_dict({
        "text": texts,
        "embeddings": embeddings.tolist(),  # Convert to list for storage
        "title": [p["title"] for p in papers]
    })
    
    # Create directory if it doesn't exist
    os.makedirs(dataset_dir, exist_ok=True)
    
    # Save the dataset and persist the FAISS index alongside it
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    faiss.write_index(index, os.path.join(dataset_dir, "index.faiss"))
    logging.info(f"Dataset and FAISS index saved to {dataset_dir}")
    return dataset_dir
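

if __name__ == "__main__":
    # Illustrative usage sketch (an assumption, not part of the original pipeline):
    # the query string and result count below are placeholder values.
    example_query = "sensory processing"
    papers = fetch_papers(example_query, max_results=5)
    for paper in papers:
        logging.info(f"{paper['published']}  {paper['title']}  ({paper['url']})")
    if papers:
        output_dir = build_faiss_index(papers)
        logging.info(f"RAG dataset written to {output_dir}")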