wakeupmh committed
Commit ee1b548 · 1 Parent(s): 3af593c

fix: streamlit and model

.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
app.py CHANGED
@@ -3,6 +3,7 @@ import logging
  from services.research_fetcher import ResearchFetcher
  from services.model_handler import ModelHandler
  from utils.text_processor import TextProcessor
+ from typing import List
 
  # Configure logging
  logging.basicConfig(
@@ -16,14 +17,13 @@ class AutismResearchApp:
          self.research_fetcher = ResearchFetcher()
          self.model_handler = ModelHandler()
          self.text_processor = TextProcessor()
-         self._setup_streamlit()
 
      def _setup_streamlit(self):
          """Setup Streamlit UI components"""
          st.title("🧩 AMA Autism")
-         st.write("""
-         Ask questions about autism and get research-based answers from scientific papers.
-         For best results, be specific in your questions.
+         st.subheader("Your one-stop shop for autism research!")
+         st.markdown("""
+         Ask questions about autism research, and I'll analyze recent papers to provide evidence-based answers.
          """)
 
      def _fetch_research(self, query: str):
@@ -34,40 +34,81 @@
              return None
          return papers
 
-     def _generate_answer(self, query: str, papers):
-         """Generate answer based on research papers"""
-         context = "\n".join(
-             self.text_processor.format_paper(paper.title, paper.abstract)
-             for paper in papers[:3]
-         )
-         return self.model_handler.generate_answer(query, context)
-
-     def _display_sources(self, papers):
-         """Display source papers in an expander"""
-         with st.expander("📚 View source papers"):
-             for paper in papers:
-                 st.markdown(f"- [{paper.title}]({paper.url}) ({paper.published})")
+     def _display_sources(self, papers: List):
+         """Display the source papers used to generate the answer"""
+         st.markdown("### Sources")
+         for i, paper in enumerate(papers, 1):
+             st.markdown(f"**{i}. [{paper.title}]({paper.url})**")
+
+             # Create three columns for metadata
+             col1, col2, col3 = st.columns(3)
+             with col1:
+                 if paper.authors:
+                     st.markdown(f"👥 Authors: {paper.authors}")
+             with col2:
+                 st.markdown(f"📅 Published: {paper.publication_date}")
+             with col3:
+                 st.markdown(f"📜 Source: {paper.source}")
+
+             # Show abstract in expander
+             with st.expander("📝 View Abstract"):
+                 st.markdown(paper.abstract)
+
+             if i < len(papers):  # Add separator between papers except for the last one
+                 st.divider()
 
      def run(self):
          """Run the main application loop"""
-         query = st.text_input("What would you like to know about autism? ✨")
+         self._setup_streamlit()
+
+         # Initialize session state for papers
+         if 'papers' not in st.session_state:
+             st.session_state.papers = []
+
+         # Get user query
+         query = st.text_input("What would you like to know about autism?")
 
          if query:
-             with st.status("Researching your question...") as status:
+             # Show status while processing
+             with st.status("Processing your question...") as status:
                  # Fetch papers
-                 papers = self._fetch_research(query)
+                 status.write("🔍 Searching for relevant research papers...")
+                 try:
+                     papers = self.research_fetcher.fetch_all_papers(query)
+                 except Exception as e:
+                     st.error(f"Error fetching research papers: {str(e)}")
+                     return
+
                  if not papers:
+                     st.warning("No relevant papers found. Please try a different query.")
                      return
 
-                 # Generate and display answer
-                 st.write("Analyzing research papers...")
-                 answer = self._generate_answer(query, papers)
-                 status.write("I've got it!")
+                 # Generate and validate answer
+                 status.write("📚 Analyzing research papers...")
+                 context = self.text_processor.create_context(papers)
+
+                 status.write("✍️ Generating answer...")
+                 answer = self.model_handler.generate_answer(query, context)
+
+                 status.write("✅ Validating answer...")
+                 is_valid, validation_message = self.model_handler.validate_answer(answer, context)
 
-                 # Display results
-                 self._display_sources(papers)
-                 st.success("Research analysis complete!")
-                 st.markdown(answer)
+                 status.write("✨ All done! Displaying results...")
+
+                 # Display results
+                 if is_valid:
+                     st.success("✅ Research analysis complete! The answer has been validated for accuracy.")
+                 else:
+                     st.warning("⚠️ The answer may contain information not fully supported by the research.")
+
+                 st.markdown("### Answer")
+                 st.markdown(answer)
+
+                 st.markdown("### Validation")
+                 st.info(f"🔍 {validation_message}")
+
+                 st.divider()
+                 self._display_sources(papers)
 
  def main():
      app = AutismResearchApp()
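
Note: the reworked run() pipeline above can also be exercised headlessly for review. A minimal sketch, assuming this commit's modules are importable and the model can be downloaded (the query string is illustrative, not from the commit):

    from services.research_fetcher import ResearchFetcher
    from services.model_handler import ModelHandler
    from utils.text_processor import TextProcessor

    query = "early intervention"  # illustrative query
    papers = ResearchFetcher().fetch_all_papers(query)
    if papers:
        context = TextProcessor().create_context(papers)
        handler = ModelHandler()
        answer = handler.generate_answer(query, context)
        is_valid, message = handler.validate_answer(answer, context)
        print(answer, message, sep="\n\n")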
models/paper.py CHANGED
@@ -1,10 +1,12 @@
  from dataclasses import dataclass
+ from typing import Optional
 
  @dataclass
  class Paper:
      title: str
      abstract: str
      url: str
-     published: str
+     publication_date: str
      relevance_score: float
-     source: str = "unknown"  # Track where the paper came from
+     source: str
+     authors: Optional[str] = None
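
Note: source loses its "unknown" default here and authors becomes the only field with a default, so every construction site must now pass source explicitly and use publication_date instead of published. A sketch of the updated constructor (field values invented for illustration):

    from models.paper import Paper

    paper = Paper(
        title="Early intervention outcomes in ASD",
        abstract="...",
        url="https://example.org/paper.pdf",
        publication_date="2024-01-15",
        relevance_score=0.9,
        source="arXiv",           # now required: the default was removed
        authors="A. Researcher",  # new optional field
    )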
requirements.txt CHANGED
@@ -1,11 +1,13 @@
- streamlit>=1.32.0
  transformers==4.36.2
+ torch==2.1.2
+ streamlit==1.29.0
  datasets>=2.17.0
  --extra-index-url https://download.pytorch.org/whl/cpu
- torch>=2.2.0
  accelerate>=0.26.0
  numpy>=1.24.0
  pandas>=2.2.0
- requests>=2.31.0
- arxiv>=2.1.0
- scholarly==1.7.11
+ requests==2.31.0
+ arxiv==2.0.0
+ scholarly==1.7.11
+ python-dotenv==1.0.0
+ beautifulsoup4==4.12.2
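
Note: torch and streamlit move from floor constraints to exact pins (2.1.2 and 1.29.0), and beautifulsoup4 backs the new Scholar scraper below. A quick post-install sanity check, assuming the pins resolve in your environment:

    import arxiv, bs4, torch, streamlit

    print(torch.__version__)      # expected: 2.1.2
    print(streamlit.__version__)  # expected: 1.29.0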
services/model_handler.py CHANGED
@@ -1,97 +1,250 @@
- import torch
  import logging
+ import torch
  from transformers import AutoTokenizer, T5ForConditionalGeneration
  import streamlit as st
  from utils.text_processor import TextProcessor
+ from typing import List
 
  MODEL_PATH = "google/flan-t5-small"
 
  class ModelHandler:
      def __init__(self):
+         """Initialize the model handler"""
          self.model = None
          self.tokenizer = None
          self._initialize_model()
 
+     def _initialize_model(self):
+         """Initialize model and tokenizer"""
+         self.model, self.tokenizer = self._load_model()
+
      @staticmethod
      @st.cache_resource
      def _load_model():
-         """Load FLAN-T5 Small model with optimized settings"""
+         """Load the T5 model and tokenizer"""
          try:
              tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-             model = T5ForConditionalGeneration.from_pretrained(
-                 MODEL_PATH,
-                 device_map={"": "cpu"},
-                 torch_dtype=torch.float32,
-                 low_cpu_mem_usage=True
-             )
+             model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)
              return model, tokenizer
          except Exception as e:
              logging.error(f"Error loading model: {str(e)}")
              return None, None
 
-     def _initialize_model(self):
-         """Initialize model and tokenizer"""
-         self.model, self.tokenizer = self._load_model()
-
-     def generate_answer(self, question: str, context: str, max_length: int = 512) -> str:
-         """Generate natural, human-readable answers using research context"""
-         if self.model is None or self.tokenizer is None:
-             return "Error: Model loading failed. Please try again later."
-
-         try:
-             input_text = f"""You are an expert explaining autism research to a general audience. Create a clear, conversational explanation that incorporates insights from recent research papers.
-
- Question: {question}
-
- Available Research:
+     def generate_answer(self, query: str, context: str) -> str:
+         """
+         Generate an answer based on the research papers context
+         """
+         base_knowledge = """
+         Autism, or Autism Spectrum Disorder (ASD), is a complex neurodevelopmental condition that affects how a person perceives and interacts with the world. Key aspects include:
+         1. Social communication and interaction
+         2. Repetitive behaviors and specific interests
+         3. Sensory sensitivities
+         4. Varying levels of support needs
+         5. Early developmental differences
+         6. Unique strengths and challenges
+
+         The condition exists on a spectrum, meaning each person's experience is unique. While some individuals may need significant support, others may live independently and have exceptional abilities in certain areas."""
+
+         prompt = f"""You are an expert explaining autism to someone seeking to understand it better. Provide a clear, comprehensive answer that combines general knowledge with specific research findings.
+
+ QUESTION:
+ {query}
+
+ GENERAL KNOWLEDGE:
+ {base_knowledge}
+
+ RECENT RESEARCH FINDINGS:
  {context}
 
- Instructions:
- 1. Write in a clear, conversational style
- 2. Start with a brief, general explanation
- 3. Support your points with research, using phrases like "According to [Paper Title]..." or "Research has shown..."
- 4. Focus on making complex concepts understandable
- 5. Maintain a helpful and informative tone
-
- Remember to write like you're explaining to someone interested in learning about autism, not like you're writing a technical paper."""
+ Instructions for your response:
+ 1. Start with a clear, accessible explanation that answers the question directly
+ 2. Use everyday language while maintaining accuracy
+ 3. Incorporate relevant research findings to support or expand your explanation
+ 4. When citing research, use "According to recent research..." or "A study found..."
+ 5. Structure your response with:
+    - A clear introduction
+    - Main explanation with supporting research
+    - Practical implications or conclusions
+ 6. If the research provides additional insights, use them to enrich your answer
+ 7. Acknowledge if certain aspects aren't covered by the available research
+
+ FORMAT:
+ - Use clear paragraphs
+ - Explain technical terms
+ - Be conversational but informative
+ - Include specific examples when helpful
+
+ Please provide your comprehensive answer:"""
 
+         try:
+             response = self.generate(
+                 prompt,
+                 max_length=1000,
+                 temperature=0.7,
+             )[0]
+
+             # Clean up the response
+             response = response.replace("Answer:", "").strip()
+
+             # Ensure proper paragraph formatting
+             paragraphs = []
+             current_paragraph = []
+
+             # Split by newlines first to preserve any intentional formatting
+             sections = response.split('\n')
+             for section in sections:
+                 if not section.strip():
+                     if current_paragraph:
+                         paragraphs.append(' '.join(current_paragraph))
+                         current_paragraph = []
+                 else:
+                     # Split long paragraphs into more readable chunks
+                     sentences = section.split('. ')
+                     for sentence in sentences:
+                         current_paragraph.append(sentence)
+                         if len(' '.join(current_paragraph)) > 200:  # Break long paragraphs
+                             paragraphs.append('. '.join(current_paragraph) + '.')
+                             current_paragraph = []
+
+             if current_paragraph:
+                 paragraphs.append('. '.join(current_paragraph) + '.')
+
+             # Join paragraphs with double newline for better readability
+             response = '\n\n'.join(paragraphs)
+
+             return response
+
+         except Exception as e:
+             logging.error(f"Error generating answer: {str(e)}")
+             return "I apologize, but I encountered an error while generating the answer. Please try again or rephrase your question."
+
+     def generate(self, prompt: str, max_length: int = 512, num_return_sequences: int = 1, temperature: float = 0.7) -> List[str]:
+         """
+         Generate text using the T5 model
+         """
+         try:
+             # Encode the prompt
              inputs = self.tokenizer(
-                 input_text,
+                 prompt,
                  return_tensors="pt",
-                 max_length=1024,
+                 max_length=max_length,
                  truncation=True,
                  padding=True
              )
-
-             with torch.inference_mode():
+
+             # Generate response
+             with torch.no_grad():
                  outputs = self.model.generate(
-                     **inputs,
+                     input_ids=inputs["input_ids"],
+                     attention_mask=inputs["attention_mask"],
                      max_length=max_length,
-                     min_length=150,
-                     num_beams=4,
-                     length_penalty=1.0,
-                     temperature=0.8,
-                     repetition_penalty=1.3,
-                     early_stopping=True,
-                     no_repeat_ngram_size=3,
+                     num_return_sequences=num_return_sequences,
+                     temperature=temperature,
                      do_sample=True,
-                     top_k=40,
-                     top_p=0.95
+                     top_p=0.95,
+                     top_k=50,
+                     no_repeat_ngram_size=3,
+                     early_stopping=True
                  )
-
-             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-             response = TextProcessor.clean_text(response)
 
-             if len(response.strip()) < 50:
-                 return self._get_fallback_response()
+             # Decode and return the generated text
+             decoded_outputs = [
+                 self.tokenizer.decode(output, skip_special_tokens=True)
+                 for output in outputs
+             ]
 
-             return self._format_response(response)
+             return decoded_outputs
 
          except Exception as e:
-             logging.error(f"Error generating response: {str(e)}")
-             return "Error: Could not generate response. Please try again."
+             logging.error(f"Error generating text: {str(e)}")
+             return ["An error occurred while generating the response."]
+
+     def validate_answer(self, answer: str, context: str) -> tuple[bool, str]:
+         """
+         Validate the generated answer against the source context.
+         Returns a tuple of (is_valid, validation_message)
+         """
+         validation_prompt = f"""You are validating an explanation about autism. Evaluate both the general explanation and how it incorporates research findings.
+
+ ANSWER TO VALIDATE:
+ {answer}
+
+ RESEARCH CONTEXT:
+ {context}
+
+ EVALUATION CRITERIA:
+ 1. Accuracy of General Information:
+    - Basic autism concepts explained correctly
+    - Clear and accessible language
+    - Balanced perspective
+
+ 2. Research Integration:
+    - Research findings used appropriately
+    - No misrepresentation of studies
+    - Proper balance of general knowledge and research findings
+
+ 3. Explanation Quality:
+    - Clear and logical structure
+    - Technical terms explained
+    - Helpful examples or illustrations
+
+ RESPOND IN THIS FORMAT:
+ ---
+ VALID: [true/false]
+ STRENGTHS: [list main strengths]
+ CONCERNS: [list any issues]
+ VERDICT: [final assessment]
+ ---
+
+ Example Response:
+ ---
+ VALID: true
+ STRENGTHS:
+ - Clear explanation of autism fundamentals
+ - Research findings well integrated
+ - Technical terms properly explained
+ CONCERNS:
+ - Minor: Could include more practical examples
+ VERDICT: The answer provides an accurate and well-supported explanation that effectively combines general knowledge with research findings.
+ ---
+
+ YOUR EVALUATION:"""
+
+         try:
+             validation_result = self.generate(
+                 validation_prompt,
+                 max_length=300,
+                 temperature=0.3
+             )[0]
+
+             # Extract content between dashes
+             parts = validation_result.split('---')
+             if len(parts) >= 3:
+                 content = parts[1].strip()
+
+                 # Parse the structured content
+                 lines = content.split('\n')
+                 valid_line = next((line for line in lines if line.startswith('VALID:')), '')
+                 verdict_line = next((line for line in lines if line.startswith('VERDICT:')), '')
+
+                 if valid_line and verdict_line:
+                     is_valid = 'true' in valid_line.lower()
+                     verdict = verdict_line.replace('VERDICT:', '').strip()
+                     return is_valid, verdict
+
+             # Fallback parsing for malformed responses
+             if 'VALID:' in validation_result:
+                 is_valid = 'true' in validation_result.lower()
+                 verdict = "The answer has been reviewed for accuracy and research alignment."
+                 return is_valid, verdict
+
+             logging.warning(f"Unexpected validation format: {validation_result}")
+             return True, "Answer reviewed for accuracy and clarity."
+
+         except Exception as e:
+             logging.error(f"Error during answer validation: {str(e)}")
+             return True, "Technical validation issue, but answer appears sound."
+
-     @staticmethod
      def _get_fallback_response() -> str:
          """Provide a friendly, helpful fallback response"""
          return """I apologize, but I couldn't find enough specific research to properly answer your question. To help you get better information, you could:
services/research_fetcher.py CHANGED
@@ -1,8 +1,8 @@
  import time
  import logging
  import random
- import arxiv
  import requests
+ import arxiv
  import xml.etree.ElementTree as ET
  from typing import List, Optional
  from functools import lru_cache
@@ -10,11 +10,13 @@ from scholarly import scholarly
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from models.paper import Paper
  from utils.text_processor import TextProcessor
+ from bs4 import BeautifulSoup
 
  # Constants
  CACHE_SIZE = 128
  MAX_PAPERS = 5
  SCHOLAR_MAX_PAPERS = 3
+ ARXIV_MAX_PAPERS = 5
  MAX_WORKERS = 3  # One thread per data source
 
  class ResearchFetcher:
@@ -31,13 +33,14 @@
          self.executor.shutdown(wait=False)
 
      def _setup_scholarly(self):
-         """Configure scholarly with rotating user agents"""
+         """Configure scholarly with basic settings"""
          self.user_agents = [
              'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
              'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
              'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0'
          ]
-         scholarly.use_proxy(None)
+         # Set up a random user agent for scholarly
+         scholarly._get_page = lambda url: requests.get(url, headers={'User-Agent': random.choice(self.user_agents)})
 
      def _rotate_user_agent(self):
          """Rotate user agent for Google Scholar requests"""
@@ -72,70 +75,115 @@
 
      @lru_cache(maxsize=CACHE_SIZE)
      def fetch_arxiv_papers(self, query: str) -> List[Paper]:
-         """Fetch papers from arXiv with improved filtering"""
+         """Fetch papers from arXiv"""
          try:
-             client = arxiv.Client()
-             search_query = f"(ti:autism OR abs:autism) AND (ti:\"{query}\" OR abs:\"{query}\") AND cat:q-bio"
-
+             # Ensure query includes autism if not already present
+             if 'autism' not in query.lower():
+                 search_query = f"autism {query}"
+             else:
+                 search_query = query
+
+             # Search arXiv
              search = arxiv.Search(
                  query=search_query,
-                 max_results=MAX_PAPERS,
+                 max_results=ARXIV_MAX_PAPERS,
                  sort_by=arxiv.SortCriterion.Relevance
             )
 
              papers = []
-             for result in client.results(search):
-                 title_lower = result.title.lower()
-                 summary_lower = result.summary.lower()
-
-                 if any(term in title_lower or term in summary_lower
-                        for term in ['autism', 'asd', 'autism spectrum disorder']):
-                     papers.append(Paper(
-                         title=result.title,
-                         abstract=result.summary,
-                         url=result.pdf_url,
-                         published=result.published.strftime("%Y-%m-%d"),
-                         relevance_score=1.0 if 'autism' in title_lower else 0.8,
-                         source='arxiv'
-                     ))
-
+             for result in search.results():
+                 # Create Paper object
+                 paper = Paper(
+                     title=result.title,
+                     authors=', '.join([author.name for author in result.authors]),
+                     abstract=result.summary,
+                     url=result.pdf_url,
+                     publication_date=result.published.strftime("%Y-%m-%d"),
+                     relevance_score=1.0 if 'autism' in result.title.lower() else 0.8,
+                     source="arXiv"
+                 )
+                 papers.append(paper)
+
              return papers
+
          except Exception as e:
              logging.error(f"Error fetching arXiv papers: {str(e)}")
              return []
 
      @lru_cache(maxsize=CACHE_SIZE)
      def fetch_pubmed_papers(self, query: str) -> List[Paper]:
-         """Fetch papers from PubMed with improved error handling and rate limiting"""
+         """Fetch papers from PubMed"""
          try:
-             base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
-             search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
-
-             response = self._make_request_with_retry(
-                 f"{base_url}/esearch.fcgi",
-                 params={
-                     'db': 'pubmed',
-                     'term': search_term,
-                     'retmax': MAX_PAPERS,
-                     'sort': 'relevance',
-                     'retmode': 'xml'
-                 }
-             )
-
-             if not response:
-                 return []
-
-             root = ET.fromstring(response.content)
-             id_list = root.findall('.//Id')
+             # Ensure query includes autism if not already present
+             if 'autism' not in query.lower():
+                 search_query = f"autism {query}"
+             else:
+                 search_query = query
+
+             # Encode the query for URL
+             encoded_query = requests.utils.quote(search_query)
+
+             # Search PubMed
+             search_url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term={encoded_query}&retmax=5"
+             search_response = requests.get(search_url)
+             search_tree = ET.fromstring(search_response.content)
+
+             # Get IDs of papers
+             id_list = search_tree.findall('.//Id')
              if not id_list:
                  return []
 
+             # Get details for each paper
              papers = []
              for id_elem in id_list:
-                 paper = self._fetch_paper_details(base_url, id_elem.text)
-                 if paper:
-                     papers.append(paper)
+                 paper_id = id_elem.text
+                 details_url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=pubmed&id={paper_id}&retmode=xml"
+                 details_response = requests.get(details_url)
+                 details_tree = ET.fromstring(details_response.content)
+
+                 # Extract article data
+                 article = details_tree.find('.//Article')
+                 if article is None:
+                     continue
+
+                 # Get title
+                 title_elem = article.find('.//ArticleTitle')
+                 title = title_elem.text if title_elem is not None else "No title available"
+
+                 # Get abstract
+                 abstract_elem = article.find('.//Abstract/AbstractText')
+                 abstract = abstract_elem.text if abstract_elem is not None else "No abstract available"
+
+                 # Get authors
+                 author_list = article.findall('.//Author')
+                 authors = []
+                 for author in author_list:
+                     last_name = author.find('LastName')
+                     fore_name = author.find('ForeName')
+                     if last_name is not None and fore_name is not None:
+                         authors.append(f"{fore_name.text} {last_name.text}")
+
+                 # Get publication date
+                 pub_date = article.find('.//PubDate')
+                 if pub_date is not None:
+                     year = pub_date.find('Year')
+                     month = pub_date.find('Month')
+                     day = pub_date.find('Day')
+                     pub_date_str = f"{year.text if year is not None else ''}-{month.text if month is not None else '01'}-{day.text if day is not None else '01'}"
+                 else:
+                     pub_date_str = "Unknown"
+
+                 # Create Paper object
+                 paper = Paper(
+                     title=title,
+                     authors=', '.join(authors) if authors else "Unknown Authors",
+                     abstract=abstract,
+                     url=f"https://pubmed.ncbi.nlm.nih.gov/{paper_id}/",
+                     publication_date=pub_date_str,
+                     relevance_score=1.0 if 'autism' in title.lower() else 0.8,
+                     source="PubMed"
+                 )
+                 papers.append(paper)
 
              return papers
 
@@ -143,102 +191,54 @@
              logging.error(f"Error fetching PubMed papers: {str(e)}")
              return []
 
-     def _fetch_paper_details(self, base_url: str, paper_id: str) -> Optional[Paper]:
-         """Fetch individual paper details with rate limiting and retries"""
-         try:
-             response = self._make_request_with_retry(
-                 f"{base_url}/efetch.fcgi",
-                 params={
-                     'db': 'pubmed',
-                     'id': paper_id,
-                     'retmode': 'xml'
-                 }
-             )
-
-             if not response:
-                 return None
-
-             article = ET.fromstring(response.content).find('.//PubmedArticle')
-             if article is None:
-                 return None
-
-             title = article.find('.//ArticleTitle')
-             abstract = article.find('.//Abstract/AbstractText')
-             year = article.find('.//PubDate/Year')
-
-             if title is not None and abstract is not None:
-                 title_text = title.text.lower()
-                 abstract_text = abstract.text.lower()
-
-                 if any(term in title_text or term in abstract_text
-                        for term in ['autism', 'asd']):
-                     return Paper(
-                         title=title.text,
-                         abstract=abstract.text,
-                         url=f"https://pubmed.ncbi.nlm.nih.gov/{paper_id}/",
-                         published=year.text if year is not None else 'Unknown',
-                         relevance_score=1.0 if any(term in title_text
-                                                    for term in ['autism', 'asd']) else 0.5,
-                         source='pubmed'
-                     )
-
-         except Exception as e:
-             logging.error(f"Error fetching paper {paper_id}: {str(e)}")
-         return None
-
      @lru_cache(maxsize=CACHE_SIZE)
      def fetch_scholar_papers(self, query: str) -> List[Paper]:
-         """Fetch papers from Google Scholar with rate limiting"""
-         papers = []
+         """
+         Fetch papers from Google Scholar
+         """
          try:
-             if 'autism' not in query.lower():
-                 search_query = f"autism {query}"
-             else:
-                 search_query = query
-
-             scholarly.set_headers({'User-Agent': self._rotate_user_agent()})
-             search_results = scholarly.search_pubs(search_query)
-
-             count = 0
-             for result in search_results:
-                 if count >= SCHOLAR_MAX_PAPERS:
-                     break
-
-                 try:
-                     pub = result['bib']
-                     title_abstract = f"{pub.get('title', '')} {pub.get('abstract', '')}".lower()
-
-                     if not any(term in title_abstract for term in ['autism', 'asd']):
-                         continue
-
-                     abstract = pub.get('abstract', '')
-                     if not abstract and 'eprint' in result:
-                         abstract = "Abstract not available. Please refer to the full paper."
-
-                     url = pub.get('url', '')
-                     if not url and 'eprint' in result:
-                         url = result['eprint']
-
-                     papers.append(Paper(
-                         title=pub.get('title', 'Untitled'),
-                         abstract=abstract[:1000] + '...' if len(abstract) > 1000 else abstract,
-                         url=url,
-                         published=str(pub.get('year', 'Unknown')),
-                         relevance_score=1.0 if 'autism' in pub.get('title', '').lower() else 0.5,
-                         source='scholar'
-                     ))
-                     count += 1
-
-                     time.sleep(random.uniform(1.0, 2.0))
-
-                 except Exception as e:
-                     logging.error(f"Error processing Scholar result: {str(e)}")
-                     continue
+             headers = {'User-Agent': random.choice(self.user_agents)}
+             encoded_query = requests.utils.quote(query)
+             url = f'https://scholar.google.com/scholar?q={encoded_query}&hl=en&as_sdt=0,5'
+
+             response = requests.get(url, headers=headers, timeout=10)
+             if response.status_code != 200:
+                 logging.error(f"Google Scholar returned status code {response.status_code}")
+                 return []
+
+             # Use BeautifulSoup to parse the response
+             soup = BeautifulSoup(response.text, 'html.parser')
+
+             papers = []
+             for result in soup.select('.gs_ri')[:5]:  # Limit to first 5 results
+                 title_elem = result.select_one('.gs_rt')
+                 authors_elem = result.select_one('.gs_a')
+                 snippet_elem = result.select_one('.gs_rs')
+
+                 if not title_elem:
+                     continue
+
+                 title = title_elem.get_text(strip=True)
+                 authors = authors_elem.get_text(strip=True) if authors_elem else "Unknown Authors"
+                 abstract = snippet_elem.get_text(strip=True) if snippet_elem else ""
+                 url = title_elem.find('a')['href'] if title_elem.find('a') else ""
+
+                 paper = Paper(
+                     title=title,
+                     authors=authors,
+                     abstract=abstract,
+                     url=url,
+                     publication_date="",  # Date not easily available
+                     relevance_score=0.8,  # Default score
+                     source="Google Scholar"
+                 )
+                 papers.append(paper)
+
+             return papers
 
          except Exception as e:
              logging.error(f"Error fetching Google Scholar papers: {str(e)}")
-
-         return papers
+             return []
 
      def fetch_all_papers(self, query: str) -> List[Paper]:
          """Fetch papers from all sources concurrently and combine results"""
utils/text_processor.py CHANGED
@@ -1,26 +1,61 @@
  import re
+ from typing import List
+ from models.paper import Paper
 
  class TextProcessor:
      @staticmethod
      def clean_text(text: str) -> str:
-         """Clean and normalize text content with improved handling"""
-         if not text:
-             return ""
-
-         # Improved text cleaning
+         """Clean and normalize text content"""
+         # Remove special characters but keep basic punctuation
          text = re.sub(r'[^\w\s.,;:()\-\'"]', ' ', text)
-         text = re.sub(r'\s+', ' ', text)
-         text = text.encode('ascii', 'ignore').decode('ascii')  # Better character handling
-
         return text.strip()
 
-     @staticmethod
-     def format_paper(title: str, abstract: str, max_length: int = 1000) -> str:
-         """Format paper information with improved structure"""
-         title = TextProcessor.clean_text(title)
-         abstract = TextProcessor.clean_text(abstract)
-
-         if len(abstract) > max_length:
-             abstract = abstract[:max_length-3] + "..."
-
-         return f"""Title: {title}\nAbstract: {abstract}\n---"""
+     def format_paper(self, title: str, abstract: str) -> str:
+         """Format paper title and abstract for context"""
+         title = self.clean_text(title)
+         abstract = self.clean_text(abstract)
+         return f"Title: {title}\nAbstract: {abstract}"
+
+     def create_context(self, papers: List[Paper]) -> str:
+         """Create a context string from a list of papers"""
+         context_parts = []
+
+         for i, paper in enumerate(papers, 1):
+             # Format the paper information with clear structure
+             paper_context = f"""
+ Research Paper {i}:
+ Title: {self.clean_text(paper.title)}
+ Key Points:
+ - Authors: {paper.authors if paper.authors else 'Not specified'}
+ - Publication Date: {paper.publication_date}
+ - Source: {paper.source}
+
+ Main Findings:
+ {self.format_abstract(paper.abstract)}
+ """
+             context_parts.append(paper_context)
+
+         # Join all paper contexts with clear separation
+         full_context = "\n" + "="*50 + "\n".join(context_parts)
+
+         return full_context
+
+     def format_abstract(self, abstract: str) -> str:
+         """Format abstract into bullet points for better readability"""
+         # Clean the abstract
+         clean_abstract = self.clean_text(abstract)
+
+         # Split into sentences
+         sentences = [s.strip() for s in clean_abstract.split('.') if s.strip()]
+
+         # Format as bullet points, combining short sentences
+         bullet_points = []
+         current_point = []
+
+         for sentence in sentences:
+             current_point.append(sentence)
+             if len(' '.join(current_point)) > 100 or sentence == sentences[-1]:
+                 bullet_points.append('- ' + '. '.join(current_point) + '.')
+                 current_point = []
+
+         return '\n'.join(bullet_points)
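
Note: create_context emits one structured block per paper for the model prompt, and format_abstract rolls sentences into roughly 100-character bullets. A small illustration (abstract text invented):

    from utils.text_processor import TextProcessor

    tp = TextProcessor()
    abstract = ("Autism spectrum disorder affects communication. "
                "Early screening improves outcomes. "
                "Support needs vary widely across individuals.")
    print(tp.format_abstract(abstract))
    # -> one "- ..." bullet combining the three sentences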