import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
import logging
import requests
import xml.etree.ElementTree as ET
from time import sleep

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")


def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG"""
    client = arxiv.Client()

    # Clean the query: keep individual terms, dropping the connector word "and"
    terms = [term.strip() for term in query.split() if term.strip() and term.strip().lower() != "and"]

    # Always include autism in the search
    if 'autism' not in [t.lower() for t in terms]:
        terms.insert(0, 'autism')

    # Build the search query with "autism" as a required term
    term_queries = []
    for term in terms:
        if term.lower() != "autism":
            term_queries.append(f'abs:"{term}" OR ti:"{term}"')

    search_query = '(abs:"autism" OR ti:"autism")'
    if term_queries:
        search_query += f' AND ({" OR ".join(term_queries)})'
    search_query = f'({search_query}) AND (cat:q-bio* OR cat:med*)'

    logging.info(f"Searching arXiv with query: {search_query}")

    search = arxiv.Search(
        query=search_query,
        max_results=max_results * 2,  # Over-fetch, then filter down below
        sort_by=arxiv.SortCriterion.Relevance
    )

    try:
        results = list(client.results(search))
        papers = []
        for i, result in enumerate(results):
            # Only include papers that mention autism in the title or abstract
            text = (result.title + " " + result.summary).lower()
            if 'autism' in text:
                papers.append({
                    "id": f"arxiv_{i}",
                    "text": result.summary,
                    "title": result.title,
                    "url": result.entry_id,
                    "published": result.published.strftime("%Y-%m-%d")
                })
            if len(papers) >= max_results:
                break

        logging.info(f"Found {len(papers)} relevant papers about autism from arXiv")
        return papers
    except Exception as e:
        logging.error(f"Error fetching papers from arXiv: {str(e)}")
        return []


def fetch_pubmed_papers(query, max_results=10):
    """Fetch papers from PubMed using the NCBI E-utilities API"""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': f"{query} AND autism",
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }

    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params, timeout=30)
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]

        if not id_list:
            return []

        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(fetch_url, params=fetch_params, timeout=30)
        root = ET.fromstring(response.content)

        papers = []
        for article in root.findall('.//PubmedArticle'):
            try:
                # Extract article information
                title = article.find('.//ArticleTitle').text or ""
                abstract = article.find('.//Abstract/AbstractText')
                abstract = (abstract.text or "") if abstract is not None else ""

                if 'autism' in (title + abstract).lower():
                    pmid = article.find('.//PMID').text
                    date = article.find('.//PubDate')
                    year = date.find('Year').text if date.find('Year') is not None else "Unknown"

                    papers.append({
                        "id": f"pubmed_{pmid}",
                        "text": abstract,
                        "title": title,
                        "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                        "published": year
                    })
            except Exception as e:
                logging.warning(f"Error processing PubMed article: {str(e)}")
                continue

        logging.info(f"Found {len(papers)} relevant papers from PubMed")
        return papers

    except Exception as e:
        logging.error(f"Error fetching papers from PubMed: {str(e)}")
        return []


def fetch_papers(query, max_results=10):
    """Fetch papers from both arXiv and PubMed"""
    arxiv_papers = fetch_arxiv_papers(query, max_results=max_results)
    sleep(1)  # Respect rate limits
    pubmed_papers = fetch_pubmed_papers(query, max_results=max_results)

    # Combine and deduplicate papers whose titles overlap (substring containment)
    all_papers = arxiv_papers + pubmed_papers
    unique_papers = []
    seen_titles = set()

    for paper in all_papers:
        title_lower = paper['title'].lower()
        if not any(title_lower in seen_title or seen_title in title_lower
                   for seen_title in seen_titles):
            unique_papers.append(paper)
            seen_titles.add(title_lower)

    # Sort by relevance (papers with 'autism' in the title first)
    unique_papers.sort(key=lambda x: 'autism' in x['title'].lower(), reverse=True)
    return unique_papers[:max_results]


def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save a dataset with a FAISS index for RAG"""
    if not papers:
        logging.warning("No papers found. Creating empty dataset.")
        # Create an empty dataset with the expected structure
        dataset = Dataset.from_dict({
            "text": [],
            "embeddings": [],
            "title": []
        })
        os.makedirs(dataset_dir, exist_ok=True)
        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
        return dataset_dir

    # Initialize the DPR context encoder (fp16 on GPU; fp32 on CPU, where
    # half-precision inference is poorly supported)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True
    ).to(device)
    ctx_encoder.eval()
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    # Encode abstracts in small batches to keep memory usage low
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4  # Small batch size

    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            inputs = ctx_tokenizer(
                batch_texts,
                max_length=256,  # Shorter than the model default to save memory
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)
            outputs = ctx_encoder(**inputs)
            embeddings.extend(outputs.pooler_output.cpu().numpy())

            # Clear memory
            del outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Convert to a float32 numpy array (FAISS requires float32)
    embeddings = np.array(embeddings, dtype=np.float32)
    dimension = embeddings.shape[1]

    # L2-normalize the vectors so inner product equals cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / norms

    # Create the FAISS index
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)

    # Create the dataset (embeddings stored as lists for serialization)
    dataset = Dataset.from_dict({
        "text": texts,
        "embeddings": embeddings.tolist(),
        "title": [p["title"] for p in papers]
    })

    # Create the output directory if it doesn't exist
    os.makedirs(dataset_dir, exist_ok=True)

    # Save the FAISS index and the dataset
    faiss.write_index(index, os.path.join(dataset_dir, "index.faiss"))
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    logging.info(f"Dataset saved to {dataset_dir}")

    return dataset_dir
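

# Example usage: a minimal sketch showing how fetch_papers() and
# build_faiss_index() compose. The query string below is an assumed sample,
# not something defined by this module; swap in whatever topic the calling
# application needs.
if __name__ == "__main__":
    example_query = "autism early intervention"  # hypothetical example query
    papers = fetch_papers(example_query, max_results=10)
    for paper in papers:
        logging.info(f"{paper['published']}  {paper['title']}  ({paper['url']})")
    if papers:
        output_dir = build_faiss_index(papers)
        logging.info(f"RAG dataset and FAISS index written to {output_dir}")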