fix: response
app.py
CHANGED
@@ -16,7 +16,7 @@ logging.basicConfig(level=logging.INFO)
 DATA_DIR = "/data" if os.path.exists("/data") else "."
 DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
 DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
-MODEL_PATH = "
+MODEL_PATH = "facebook/bart-large-cnn"  # Changed to BART model which is better for summarization

 @st.cache_resource
 def load_local_model():
@@ -24,7 +24,7 @@ def load_local_model():
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
     model = AutoModelForSeq2SeqLM.from_pretrained(
         MODEL_PATH,
-        torch_dtype=torch.float32,
+        torch_dtype=torch.float32,
         device_map="auto"
     )
     return model, tokenizer
@@ -33,8 +33,11 @@ def fetch_arxiv_papers(query, max_results=5):
     """Fetch papers from arXiv"""
     client = arxiv.Client()

-    #
-
+    # Ensure query includes autism-related terms
+    if 'autism' not in query.lower():
+        search_query = f"(ti:{query} OR abs:{query}) AND (ti:autism OR abs:autism) AND cat:q-bio"
+    else:
+        search_query = f"(ti:{query} OR abs:{query}) AND cat:q-bio"

     # Search arXiv
     search = arxiv.Search(
@@ -49,7 +52,7 @@ def fetch_arxiv_papers(query, max_results=5):
             'title': result.title,
             'abstract': result.summary,
             'url': result.pdf_url,
-            'published': result.published
+            'published': result.published.strftime("%Y-%m-%d")
         })

     return papers
@@ -58,11 +61,17 @@ def fetch_pubmed_papers(query, max_results=5):
     """Fetch papers from PubMed"""
     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

+    # Ensure query includes autism-related terms
+    if 'autism' not in query.lower():
+        search_term = f"({query}) AND (autism[Title/Abstract] OR ASD[Title/Abstract])"
+    else:
+        search_term = query
+
     # Search for papers
     search_url = f"{base_url}/esearch.fcgi"
     search_params = {
         'db': 'pubmed',
-        'term':
+        'term': search_term,
         'retmax': max_results,
         'sort': 'relevance',
         'retmode': 'xml'
@@ -92,13 +101,16 @@ def fetch_pubmed_papers(query, max_results=5):
         for article in articles.findall('.//PubmedArticle'):
             title = article.find('.//ArticleTitle')
             abstract = article.find('.//Abstract/AbstractText')
+            year = article.find('.//PubDate/Year')
+            pmid = article.find('.//PMID')

-
-
-
-
-
-
+            if title is not None and abstract is not None:
+                papers.append({
+                    'title': title.text,
+                    'abstract': abstract.text,
+                    'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/",
+                    'published': year.text if year is not None else 'Unknown'
+                })

     except Exception as e:
         st.error(f"Error fetching PubMed papers: {str(e)}")
@@ -113,12 +125,13 @@ def search_research_papers(query):
     # Combine and format papers
     all_papers = []
     for paper in arxiv_papers + pubmed_papers:
-
-
-
-
-
-
+        if paper['abstract'] and len(paper['abstract'].strip()) > 0:
+            all_papers.append({
+                'title': paper['title'],
+                'text': f"Title: {paper['title']}\n\nAbstract: {paper['abstract']}",
+                'url': paper['url'],
+                'published': paper['published']
+            })

     return pd.DataFrame(all_papers)

@@ -127,38 +140,50 @@ def generate_answer(question, context, max_length=512):
     model, tokenizer = load_local_model()

     # Format the context as a structured query
-    prompt = f"""
-
-    Question: {question}
+    prompt = f"""Summarize the following research about autism and answer the question.

     Research Context:
     {context}

-
-
-
+    Question: {question}
+
+    Provide a detailed answer that includes:
+    1. Main findings from the research
+    2. Research methods used
     3. Clinical implications
-    4. Limitations
+    4. Limitations of the studies

-
+    If the research doesn't address the question directly, explain what information is missing."""

     # Generate response
-    inputs = tokenizer(prompt, return_tensors="pt", max_length=
+    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

     with torch.inference_mode():
         outputs = model.generate(
             **inputs,
             max_length=max_length,
-
+            min_length=200,  # Ensure longer responses
+            num_beams=5,
+            length_penalty=2.0,  # Encourage even longer responses
             temperature=0.7,
-
-            repetition_penalty=1.
+            no_repeat_ngram_size=3,
+            repetition_penalty=1.3,
             early_stopping=True
         )

     response = tokenizer.decode(outputs[0], skip_special_tokens=True)

-    #
+    # If response is too short or empty, provide a fallback message
+    if len(response.strip()) < 100:
+        return """I apologize, but I couldn't generate a specific answer from the research papers provided.
+        This might be because:
+        1. The research papers don't directly address your question
+        2. The context needs more specific information
+        3. The question might need to be more specific
+
+        Please try rephrasing your question or ask about a more specific aspect of autism."""
+
+    # Format the response for better readability
     formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")

     return formatted_response
@@ -181,10 +206,6 @@ if query:
     st.write("Searching for data in PubMed and arXiv...")
     st.write(f"Found {len(df)} relevant papers!")

-    # Display paper sources
-    for _, paper in df.iterrows():
-        st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
-
     # Get relevant context
     context = "\n".join([
         f"{text[:1000]}" for text in df['text'].head(3)
@@ -193,4 +214,9 @@ if query:
     # Generate answer
     st.write("Generating answer...")
     answer = generate_answer(query, context)
-
+    # Display paper sources
+    with st.expander("View source papers"):
+        for _, paper in df.iterrows():
+            st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
+    st.success("Answer found!")
+    st.markdown(answer)