wakeupmh committed on
Commit
97889da
·
1 Parent(s): 6782472

fix: response

Browse files
Files changed (1) hide show
  1. app.py +62 -36
app.py CHANGED
@@ -16,7 +16,7 @@ logging.basicConfig(level=logging.INFO)
16
  DATA_DIR = "/data" if os.path.exists("/data") else "."
17
  DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
18
  DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
19
- MODEL_PATH = "google/flan-t5-base" # Lighter model
20
 
21
  @st.cache_resource
22
  def load_local_model():
@@ -24,7 +24,7 @@ def load_local_model():
24
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
25
  model = AutoModelForSeq2SeqLM.from_pretrained(
26
  MODEL_PATH,
27
- torch_dtype=torch.float32, # Using float32 for CPU compatibility
28
  device_map="auto"
29
  )
30
  return model, tokenizer
@@ -33,8 +33,11 @@ def fetch_arxiv_papers(query, max_results=5):
33
  """Fetch papers from arXiv"""
34
  client = arxiv.Client()
35
 
36
- # Clean and prepare the search query
37
- search_query = f"ti:{query} OR abs:{query} AND cat:q-bio"
 
 
 
38
 
39
  # Search arXiv
40
  search = arxiv.Search(
@@ -49,7 +52,7 @@ def fetch_arxiv_papers(query, max_results=5):
49
  'title': result.title,
50
  'abstract': result.summary,
51
  'url': result.pdf_url,
52
- 'published': result.published
53
  })
54
 
55
  return papers
@@ -58,11 +61,17 @@ def fetch_pubmed_papers(query, max_results=5):
58
  """Fetch papers from PubMed"""
59
  base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
60
 
 
 
 
 
 
 
61
  # Search for papers
62
  search_url = f"{base_url}/esearch.fcgi"
63
  search_params = {
64
  'db': 'pubmed',
65
- 'term': query,
66
  'retmax': max_results,
67
  'sort': 'relevance',
68
  'retmode': 'xml'
@@ -92,13 +101,16 @@ def fetch_pubmed_papers(query, max_results=5):
92
  for article in articles.findall('.//PubmedArticle'):
93
  title = article.find('.//ArticleTitle')
94
  abstract = article.find('.//Abstract/AbstractText')
 
 
95
 
96
- papers.append({
97
- 'title': title.text if title is not None else 'No title available',
98
- 'abstract': abstract.text if abstract is not None else 'No abstract available',
99
- 'url': f"https://pubmed.ncbi.nlm.nih.gov/{article.find('.//PMID').text}/",
100
- 'published': article.find('.//PubDate/Year').text if article.find('.//PubDate/Year') is not None else 'Unknown'
101
- })
 
102
 
103
  except Exception as e:
104
  st.error(f"Error fetching PubMed papers: {str(e)}")
@@ -113,12 +125,13 @@ def search_research_papers(query):
113
  # Combine and format papers
114
  all_papers = []
115
  for paper in arxiv_papers + pubmed_papers:
116
- all_papers.append({
117
- 'title': paper['title'],
118
- 'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
119
- 'url': paper['url'],
120
- 'published': paper['published']
121
- })
 
122
 
123
  return pd.DataFrame(all_papers)
124
 
@@ -127,38 +140,50 @@ def generate_answer(question, context, max_length=512):
127
  model, tokenizer = load_local_model()
128
 
129
  # Format the context as a structured query
130
- prompt = f"""Based on the following research papers about autism, provide a detailed answer:
131
-
132
- Question: {question}
133
 
134
  Research Context:
135
  {context}
136
 
137
- Please analyze:
138
- 1. Main findings
139
- 2. Research methods
 
 
140
  3. Clinical implications
141
- 4. Limitations
142
 
143
- Answer:"""
144
 
145
  # Generate response
146
- inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
147
 
148
  with torch.inference_mode():
149
  outputs = model.generate(
150
  **inputs,
151
  max_length=max_length,
152
- num_beams=4,
 
 
153
  temperature=0.7,
154
- top_p=0.9,
155
- repetition_penalty=1.2,
156
  early_stopping=True
157
  )
158
 
159
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
160
 
161
- # Format the response
 
 
 
 
 
 
 
 
 
 
162
  formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
163
 
164
  return formatted_response
@@ -181,10 +206,6 @@ if query:
181
  st.write("Searching for data in PubMed and arXiv...")
182
  st.write(f"Found {len(df)} relevant papers!")
183
 
184
- # Display paper sources
185
- for _, paper in df.iterrows():
186
- st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
187
-
188
  # Get relevant context
189
  context = "\n".join([
190
  f"{text[:1000]}" for text in df['text'].head(3)
@@ -193,4 +214,9 @@ if query:
193
  # Generate answer
194
  st.write("Generating answer...")
195
  answer = generate_answer(query, context)
196
- st.markdown(answer)
 
 
 
 
 
 
16
  DATA_DIR = "/data" if os.path.exists("/data") else "."
17
  DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
18
  DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
19
+ MODEL_PATH = "facebook/bart-large-cnn" # Changed to BART model which is better for summarization
20
 
21
  @st.cache_resource
22
  def load_local_model():
 
24
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
25
  model = AutoModelForSeq2SeqLM.from_pretrained(
26
  MODEL_PATH,
27
+ torch_dtype=torch.float32,
28
  device_map="auto"
29
  )
30
  return model, tokenizer
 
33
  """Fetch papers from arXiv"""
34
  client = arxiv.Client()
35
 
36
+ # Ensure query includes autism-related terms
37
+ if 'autism' not in query.lower():
38
+ search_query = f"(ti:{query} OR abs:{query}) AND (ti:autism OR abs:autism) AND cat:q-bio"
39
+ else:
40
+ search_query = f"(ti:{query} OR abs:{query}) AND cat:q-bio"
41
 
42
  # Search arXiv
43
  search = arxiv.Search(
 
52
  'title': result.title,
53
  'abstract': result.summary,
54
  'url': result.pdf_url,
55
+ 'published': result.published.strftime("%Y-%m-%d")
56
  })
57
 
58
  return papers
 
61
  """Fetch papers from PubMed"""
62
  base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
63
 
64
+ # Ensure query includes autism-related terms
65
+ if 'autism' not in query.lower():
66
+ search_term = f"({query}) AND (autism[Title/Abstract] OR ASD[Title/Abstract])"
67
+ else:
68
+ search_term = query
69
+
70
  # Search for papers
71
  search_url = f"{base_url}/esearch.fcgi"
72
  search_params = {
73
  'db': 'pubmed',
74
+ 'term': search_term,
75
  'retmax': max_results,
76
  'sort': 'relevance',
77
  'retmode': 'xml'
 
101
  for article in articles.findall('.//PubmedArticle'):
102
  title = article.find('.//ArticleTitle')
103
  abstract = article.find('.//Abstract/AbstractText')
104
+ year = article.find('.//PubDate/Year')
105
+ pmid = article.find('.//PMID')
106
 
107
+ if title is not None and abstract is not None:
108
+ papers.append({
109
+ 'title': title.text,
110
+ 'abstract': abstract.text,
111
+ 'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/",
112
+ 'published': year.text if year is not None else 'Unknown'
113
+ })
114
 
115
  except Exception as e:
116
  st.error(f"Error fetching PubMed papers: {str(e)}")
 
125
  # Combine and format papers
126
  all_papers = []
127
  for paper in arxiv_papers + pubmed_papers:
128
+ if paper['abstract'] and len(paper['abstract'].strip()) > 0:
129
+ all_papers.append({
130
+ 'title': paper['title'],
131
+ 'text': f"Title: {paper['title']}\n\nAbstract: {paper['abstract']}",
132
+ 'url': paper['url'],
133
+ 'published': paper['published']
134
+ })
135
 
136
  return pd.DataFrame(all_papers)
137
 
 
140
  model, tokenizer = load_local_model()
141
 
142
  # Format the context as a structured query
143
+ prompt = f"""Summarize the following research about autism and answer the question.
 
 
144
 
145
  Research Context:
146
  {context}
147
 
148
+ Question: {question}
149
+
150
+ Provide a detailed answer that includes:
151
+ 1. Main findings from the research
152
+ 2. Research methods used
153
  3. Clinical implications
154
+ 4. Limitations of the studies
155
 
156
+ If the research doesn't address the question directly, explain what information is missing."""
157
 
158
  # Generate response
159
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
160
 
161
  with torch.inference_mode():
162
  outputs = model.generate(
163
  **inputs,
164
  max_length=max_length,
165
+ min_length=200, # Ensure longer responses
166
+ num_beams=5,
167
+ length_penalty=2.0, # Encourage even longer responses
168
  temperature=0.7,
169
+ no_repeat_ngram_size=3,
170
+ repetition_penalty=1.3,
171
  early_stopping=True
172
  )
173
 
174
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
175
 
176
+ # If response is too short or empty, provide a fallback message
177
+ if len(response.strip()) < 100:
178
+ return """I apologize, but I couldn't generate a specific answer from the research papers provided.
179
+ This might be because:
180
+ 1. The research papers don't directly address your question
181
+ 2. The context needs more specific information
182
+ 3. The question might need to be more specific
183
+
184
+ Please try rephrasing your question or ask about a more specific aspect of autism."""
185
+
186
+ # Format the response for better readability
187
  formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
188
 
189
  return formatted_response
 
206
  st.write("Searching for data in PubMed and arXiv...")
207
  st.write(f"Found {len(df)} relevant papers!")
208
 
 
 
 
 
209
  # Get relevant context
210
  context = "\n".join([
211
  f"{text[:1000]}" for text in df['text'].head(3)
 
214
  # Generate answer
215
  st.write("Generating answer...")
216
  answer = generate_answer(query, context)
217
+ # Display paper sources
218
+ with st.expander("View source papers"):
219
+ for _, paper in df.iterrows():
220
+ st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
221
+ st.success("Answer found!")
222
+ st.markdown(answer)