fix: faiss

Files changed:
- app.py (+2 -2)
- faiss_index/index.py (+103 -6)
- requirements.txt (+2 -1)
app.py
CHANGED
@@ -31,7 +31,7 @@ def load_models():
 @st.cache_data(ttl=3600)  # Cache for 1 hour
 def load_dataset(query):
     # Always fetch fresh results for the specific query
-    with st.spinner("Searching
+    with st.spinner("Searching research papers from arXiv and PubMed..."):
     import faiss_index.index as idx
     # Ensure both autism and the query terms are included
     if 'autism' not in query.lower():
@@ -39,7 +39,7 @@ def load_dataset(query):
     else:
         search_query = query

-    papers = idx.
+    papers = idx.fetch_papers(search_query, max_results=25)  # This now fetches from both sources

     if not papers:
         st.warning("No relevant papers found. Please try rephrasing your question.")
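For context, a minimal sketch of the call contract app.py now relies on (the field names come from the paper dicts built in faiss_index/index.py below; the sample question is hypothetical):

    import faiss_index.index as idx

    # fetch_papers merges arXiv and PubMed results into one list of dicts,
    # each carrying "id", "text", "title", "url", and "published" keys.
    papers = idx.fetch_papers("autism sleep interventions", max_results=25)
    for paper in papers[:3]:
        print(paper["published"], paper["title"], paper["url"])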
faiss_index/index.py
CHANGED
@@ -6,6 +6,10 @@ import os
 from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
 import torch
 import logging
+import requests
+from datetime import datetime
+import xml.etree.ElementTree as ET
+from time import sleep

 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -19,7 +23,7 @@ def fetch_arxiv_papers(query, max_results=10):
     client = arxiv.Client()

     # Clean and prepare the search query
-    query = query.replace('and', '').strip()
+    query = query.replace('and', '').strip()
     terms = [term.strip() for term in query.split() if term.strip()]

     # Always include autism in the search
@@ -27,14 +31,21 @@ def fetch_arxiv_papers(query, max_results=10):
     terms.insert(0, 'autism')

     # Create search query with required autism term
-
+    term_queries = []
+    for term in terms:
+        if term.lower() != "autism":
+            term_queries.append(f'abs:"{term}" OR ti:"{term}"')
+
+    search_query = '(abs:"autism" OR ti:"autism")'
+    if term_queries:
+        search_query += f' AND ({" OR ".join(term_queries)})'
     search_query = f'({search_query}) AND (cat:q-bio* OR cat:med*)'

     logging.info(f"Searching arXiv with query: {search_query}")

     search = arxiv.Search(
         query=search_query,
-        max_results=max_results * 2,
+        max_results=max_results * 2,
         sort_by=arxiv.SortCriterion.Relevance
     )

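Traced by hand, the new query builder produces the following for a hypothetical input of "sleep and intervention" (note that str.replace('and', '') removes the substring anywhere it occurs, not just the standalone word):

    query = "sleep and intervention".replace('and', '').strip()  # "sleep  intervention"
    # terms        -> ['autism', 'sleep', 'intervention']
    # term_queries -> ['abs:"sleep" OR ti:"sleep"', 'abs:"intervention" OR ti:"intervention"']
    # final search_query:
    #   ((abs:"autism" OR ti:"autism") AND (abs:"sleep" OR ti:"sleep"
    #     OR abs:"intervention" OR ti:"intervention")) AND (cat:q-bio* OR cat:med*)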
@@ -47,11 +58,11 @@ def fetch_arxiv_papers(query, max_results=10):
     text = (result.title + " " + result.summary).lower()
     if 'autism' in text:
         papers.append({
-            "id":
+            "id": f"arxiv_{i}",
             "text": result.summary,
             "title": result.title,
-            "url": result.entry_id,
-            "published": result.published.strftime("%Y-%m-%d")
+            "url": result.entry_id,
+            "published": result.published.strftime("%Y-%m-%d")
         })
     if len(papers) >= max_results:
         break
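The new "id": f"arxiv_{i}" value refers to a loop index i defined outside this hunk; presumably the enclosing loop looks something like the sketch below (an assumption, since that line is not part of the diff; client.results() is the arxiv package's documented result iterator):

    # hypothetical reconstruction of the enclosing loop
    for i, result in enumerate(client.results(search)):
        text = (result.title + " " + result.summary).lower()
        ...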
@@ -62,6 +73,92 @@ def fetch_arxiv_papers(query, max_results=10):
         logging.error(f"Error fetching papers from arXiv: {str(e)}")
         return []

+def fetch_pubmed_papers(query, max_results=10):
+    """Fetch papers from PubMed using E-utilities"""
+    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
+
+    # Search for papers
+    search_url = f"{base_url}/esearch.fcgi"
+    search_params = {
+        'db': 'pubmed',
+        'term': f"{query} AND autism",
+        'retmax': max_results,
+        'sort': 'relevance',
+        'retmode': 'xml'
+    }
+
+    try:
+        # Get paper IDs
+        response = requests.get(search_url, params=search_params)
+        root = ET.fromstring(response.content)
+        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
+
+        if not id_list:
+            return []
+
+        # Fetch paper details
+        fetch_url = f"{base_url}/efetch.fcgi"
+        fetch_params = {
+            'db': 'pubmed',
+            'id': ','.join(id_list),
+            'retmode': 'xml'
+        }
+
+        response = requests.get(fetch_url, params=fetch_params)
+        root = ET.fromstring(response.content)
+        papers = []
+
+        for article in root.findall('.//PubmedArticle'):
+            try:
+                # Extract article information
+                title = article.find('.//ArticleTitle').text
+                abstract = article.find('.//Abstract/AbstractText')
+                abstract = abstract.text if abstract is not None else ""
+
+                if 'autism' in (title + abstract).lower():
+                    pmid = article.find('.//PMID').text
+                    date = article.find('.//PubDate')
+                    year = date.find('Year').text if date.find('Year') is not None else "Unknown"
+
+                    papers.append({
+                        "id": f"pubmed_{pmid}",
+                        "text": abstract,
+                        "title": title,
+                        "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
+                        "published": year
+                    })
+            except Exception as e:
+                logging.warning(f"Error processing PubMed article: {str(e)}")
+                continue
+
+        logging.info(f"Found {len(papers)} relevant papers from PubMed")
+        return papers
+
+    except Exception as e:
+        logging.error(f"Error fetching papers from PubMed: {str(e)}")
+        return []
+
+def fetch_papers(query, max_results=10):
+    """Fetch papers from both arXiv and PubMed"""
+    arxiv_papers = fetch_arxiv_papers(query, max_results=max_results)
+    sleep(1)  # Respect rate limits
+    pubmed_papers = fetch_pubmed_papers(query, max_results=max_results)
+
+    # Combine and deduplicate papers based on title similarity
+    all_papers = arxiv_papers + pubmed_papers
+    unique_papers = []
+    seen_titles = set()
+
+    for paper in all_papers:
+        title_lower = paper['title'].lower()
+        if not any(title_lower in seen_title or seen_title in title_lower for seen_title in seen_titles):
+            unique_papers.append(paper)
+            seen_titles.add(title_lower)
+
+    # Sort by relevance (papers with 'autism' in title first)
+    unique_papers.sort(key=lambda x: 'autism' in x['title'].lower(), reverse=True)
+    return unique_papers[:max_results]
+
 def build_faiss_index(papers, dataset_dir=DATASET_DIR):
     """Build and save dataset with FAISS index for RAG"""
     if not papers:
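Two notes on the additions above. fetch_pubmed_papers follows the standard two-step E-utilities flow: esearch.fcgi returns matching PMIDs, then efetch.fcgi returns the article XML for those IDs; the sleep(1) matches the "Respect rate limits" comment, and NCBI's documented ceiling for unauthenticated E-utilities clients is about three requests per second. The title-containment dedup in fetch_papers is easiest to see on hypothetical titles:

    seen_titles = {"autism and sleep quality"}
    title_lower = "autism and sleep quality: a review"
    # The stored title is a substring of the new one, so the containment test
    # flags the pair as duplicates and the new paper is skipped.
    any(title_lower in t or t in title_lower for t in seen_titles)  # True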
requirements.txt
CHANGED
@@ -9,4 +9,5 @@ torch>=2.2.0
 accelerate>=0.26.0
 bitsandbytes>=0.41.1
 numpy>=1.24.0
-pandas>=2.2.0
+pandas>=2.2.0
+requests>=2.31.0
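With requests now pinned alongside the existing dependencies, a rebuilt environment picks everything up through the usual install step:

    pip install -r requirements.txt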