import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
import logging
import requests
from datetime import datetime
import xml.etree.ElementTree as ET
from time import sleep

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")


def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG."""
    client = arxiv.Client()

    # Clean the query: drop the connector word "and" and split into individual terms
    terms = [term.strip() for term in query.split() if term.strip() and term.strip().lower() != 'and']

    # Always include autism in the search
    if 'autism' not in [t.lower() for t in terms]:
        terms.insert(0, 'autism')

    # Build the search query, always requiring the autism term
    term_queries = []
    for term in terms:
        if term.lower() != "autism":
            term_queries.append(f'abs:"{term}" OR ti:"{term}"')

    search_query = '(abs:"autism" OR ti:"autism")'
    if term_queries:
        search_query += f' AND ({" OR ".join(term_queries)})'
    search_query = f'({search_query}) AND (cat:q-bio* OR cat:med*)'

    logging.info(f"Searching arXiv with query: {search_query}")

    search = arxiv.Search(
        query=search_query,
        max_results=max_results * 2,  # Over-fetch so filtering still yields enough papers
        sort_by=arxiv.SortCriterion.Relevance
    )

    try:
        results = list(client.results(search))
        papers = []
        for i, result in enumerate(results):
            # Only include papers that mention autism
            text = (result.title + " " + result.summary).lower()
            if 'autism' in text:
                papers.append({
                    "id": f"arxiv_{i}",
                    "text": result.summary,
                    "title": result.title,
                    "url": result.entry_id,
                    "published": result.published.strftime("%Y-%m-%d")
                })
            if len(papers) >= max_results:
                break
        logging.info(f"Found {len(papers)} relevant papers about autism from arXiv")
        return papers
    except Exception as e:
        logging.error(f"Error fetching papers from arXiv: {str(e)}")
        return []


def fetch_pubmed_papers(query, max_results=10):
    """Fetch papers from PubMed using the NCBI E-utilities API."""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

    # Search for paper IDs matching the query
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': f"{query} AND autism",
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }

    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params)
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]

        if not id_list:
            return []

        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(fetch_url, params=fetch_params)
        root = ET.fromstring(response.content)

        papers = []
        for article in root.findall('.//PubmedArticle'):
            try:
                # Extract article information
                title = article.find('.//ArticleTitle').text or ""
                abstract = article.find('.//Abstract/AbstractText')
                abstract = abstract.text if abstract is not None and abstract.text else ""

                if 'autism' in (title + abstract).lower():
                    pmid = article.find('.//PMID').text
                    date = article.find('.//PubDate')
                    year = date.find('Year').text if date is not None and date.find('Year') is not None else "Unknown"
                    papers.append({
                        "id": f"pubmed_{pmid}",
                        "text": abstract,
                        "title": title,
                        "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                        "published": year
                    })
            except Exception as e:
                logging.warning(f"Error processing PubMed article: {str(e)}")
                continue

        logging.info(f"Found {len(papers)} relevant papers from PubMed")
        return papers
    except Exception as e:
        logging.error(f"Error fetching papers from PubMed: {str(e)}")
        return []


def fetch_papers(query, max_results=10):
    """Fetch papers from both arXiv and PubMed."""
    arxiv_papers = fetch_arxiv_papers(query, max_results=max_results)
    sleep(1)  # Respect rate limits between API calls
    pubmed_papers = fetch_pubmed_papers(query, max_results=max_results)

    # Combine and deduplicate papers based on title similarity
    all_papers = arxiv_papers + pubmed_papers
    unique_papers = []
    seen_titles = set()
    for paper in all_papers:
        title_lower = paper['title'].lower()
        if not any(title_lower in seen_title or seen_title in title_lower for seen_title in seen_titles):
            unique_papers.append(paper)
            seen_titles.add(title_lower)

    # Sort by relevance (papers with 'autism' in the title first)
    unique_papers.sort(key=lambda x: 'autism' in x['title'].lower(), reverse=True)
    return unique_papers[:max_results]


def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save a dataset with DPR embeddings and a FAISS index for RAG."""
    if not papers:
        logging.warning("No papers found. Creating empty dataset.")
        # Create an empty dataset with the expected structure
        dataset = Dataset.from_dict({
            "text": [],
            "embeddings": [],
            "title": []
        })
        os.makedirs(dataset_dir, exist_ok=True)
        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
        return dataset_dir

    # Initialize the DPR context encoder in half precision to reduce memory use
    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    # Create embeddings in small batches to keep memory usage low
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4  # Small batch size for memory-constrained hardware

    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            inputs = ctx_tokenizer(
                batch_texts,
                max_length=256,  # Reduced from the default to limit memory use
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            outputs = ctx_encoder(**inputs)
            embeddings.extend(outputs.pooler_output.cpu().numpy())

            # Clear memory
            del outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Convert to a float32 array (FAISS requires float32)
    embeddings = np.array(embeddings, dtype=np.float32)
    dimension = embeddings.shape[1]

    # L2-normalize the vectors so inner-product search behaves like cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / norms

    # Create FAISS index
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)

    # Create and save the dataset
    dataset = Dataset.from_dict({
        "text": texts,
        "embeddings": embeddings.tolist(),  # Convert to list for storage
        "title": [p["title"] for p in papers]
    })

    # Create the directory if it doesn't exist
    os.makedirs(dataset_dir, exist_ok=True)

    # Save dataset
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    logging.info(f"Dataset saved to {dataset_dir}")
    return dataset_dir
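

# --- Usage sketch (not part of the original module) ---
# A minimal example, assuming this script is run directly; the query string
# "autism early intervention" is hypothetical and chosen only for illustration.
# It fetches papers from both sources and writes the RAG dataset to disk.
if __name__ == "__main__":
    example_query = "autism early intervention"  # hypothetical example query
    papers = fetch_papers(example_query, max_results=5)
    for paper in papers:
        logging.info(f"{paper['published']} - {paper['title']} ({paper['url']})")
    output_dir = build_faiss_index(papers)
    logging.info(f"RAG dataset written to {output_dir}")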