File size: 10,828 Bytes
f1586e3
e348a54
0f8445a
5a09d5c
e348a54
 
4660a83
 
 
58be7e5
5a09d5c
 
8108db5
f1586e3
4660a83
0452175
 
 
e348a54
0452175
d3e32db
4660a83
 
a47c92e
 
e348a54
a47c92e
 
 
 
 
 
 
 
f68ac31
58be7e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4660a83
 
 
62b3157
7842508
 
4660a83
 
 
 
 
 
 
 
 
 
7842508
 
 
 
 
 
 
 
 
 
 
 
8903db2
4660a83
f944585
4660a83
 
 
 
7842508
 
97889da
4660a83
 
 
 
97889da
4660a83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97889da
 
4660a83
97889da
7842508
 
 
 
 
 
 
 
 
 
 
 
 
4660a83
 
 
d3e32db
4660a83
 
 
 
 
 
cb9a068
4660a83
 
 
97889da
58be7e5
 
 
 
7842508
58be7e5
 
 
 
 
 
 
7842508
58be7e5
 
7842508
 
 
 
cc0b0d6
7842508
 
 
 
 
 
 
 
 
4660a83
 
 
 
62b3157
a47c92e
 
 
58be7e5
 
e348a54
58be7e5
e348a54
67c9f69
4660a83
e348a54
97889da
67c9f69
58be7e5
218a8a7
67c9f69
 
 
 
 
4660a83
67c9f69
f99a008
a47c92e
e348a54
 
 
 
 
 
97889da
a47c92e
 
 
 
8741970
e348a54
a47c92e
 
 
e348a54
 
 
 
 
a47c92e
 
 
58be7e5
a47c92e
218a8a7
 
67c9f69
218a8a7
67c9f69
 
 
 
 
218a8a7
67c9f69
a47c92e
 
 
 
 
cc0b0d6
a47c92e
 
 
f1586e3
f68ac31
 
67b3686
 
 
 
 
 
218a8a7
f1586e3
 
b45a04f
4660a83
 
67b3686
b45a04f
67b3686
 
62b3157
 
 
 
4660a83
62b3157
b45a04f
67b3686
97889da
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import logging
import os
import re
import unicodedata
import xml.etree.ElementTree as ET

import arxiv
import pandas as pd
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Root logger: emit INFO and above.
logging.basicConfig(level=logging.INFO)

# Storage root: use /data when that directory exists, else the working dir.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."

# Dataset locations derived from the storage root.
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Hugging Face Hub id of the seq2seq model used to generate answers.
MODEL_PATH = "google/flan-t5-small"

@st.cache_resource
def _load_model_cached():
    """Load the tokenizer and model once per server process.

    Exceptions are deliberately allowed to propagate: st.cache_resource only
    stores values from calls that return normally, so a failed load (for
    example a transient download error) is retried on the next call instead
    of a (None, None) failure being memoized for the lifetime of the process.

    Returns:
        (model, tokenizer) tuple for MODEL_PATH.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = T5ForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        device_map={"": "cpu"},  # Force CPU
        torch_dtype=torch.float32,
    )
    return model, tokenizer


def load_local_model():
    """Load the local Hugging Face model, reporting failures in the UI.

    Returns:
        (model, tokenizer) on success, or (None, None) after showing an
        error message if loading failed. Failures are not cached, so a later
        call will attempt the load again.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None

# Typographic quotes and their common UTF-8-decoded-as-cp1252 mojibake forms,
# mapped to ASCII. The longer mojibake sequences come first so the bare "â€"
# prefix only catches whatever is left over.
_QUOTE_REPLACEMENTS = (
    ("â€™", "'"),
    ("â€˜", "'"),
    ("â€œ", '"'),
    ("â€", '"'),
    ("’", "'"),
    ("‘", "'"),
    ("“", '"'),
    ("”", '"'),
)
# Anything that is not a word character, whitespace, or basic punctuation.
_DISALLOWED_CHARS_RE = re.compile(r'[^\w\s.,;:()\-\'"]')
_WHITESPACE_RE = re.compile(r'\s+')


def clean_text(text):
    """Clean and normalize text content to single-spaced ASCII.

    Steps, in order:
      1. typographic quotes (and their mojibake) become ASCII quotes;
      2. accented letters lose their diacritics ("naïve" -> "naive");
      3. symbols outside the allow-list become spaces;
      4. any remaining non-ASCII characters are dropped;
      5. runs of whitespace collapse to one space and the ends are stripped.

    Args:
        text: Input string; falsy values (None, "") yield "".

    Returns:
        The cleaned string.
    """
    if not text:
        return ""

    # Quotes must be normalized before the character filter, which would
    # otherwise turn them into spaces ("don’t" -> "don t").
    for source, replacement in _QUOTE_REPLACEMENTS:
        text = text.replace(source, replacement)

    # Decompose accented letters into base letter + combining mark, then drop
    # the marks so the base letter survives the ASCII filter below.
    text = unicodedata.normalize("NFKD", text)
    text = "".join(char for char in text if not unicodedata.combining(char))

    # Replace punctuation/symbols outside the allow-list with spaces.
    text = _DISALLOWED_CHARS_RE.sub(" ", text)

    # Remove any remaining non-ASCII characters (e.g. Greek or CJK letters).
    text = "".join(char for char in text if ord(char) < 128)

    # Collapse whitespace last so gaps left by the previous steps are merged.
    return _WHITESPACE_RE.sub(" ", text).strip()

def format_paper(title, abstract):
    """Render one paper as a plain-text block for the model prompt.

    Both fields are passed through clean_text(). An abstract longer than
    1000 characters is cut to its first 997 characters followed by "...",
    so the abstract part never exceeds 1000 characters.

    Returns:
        "Title: <title>\\n\\nAbstract: <abstract>\\n\\n---"
    """
    limit = 1000
    cleaned_title = clean_text(title)
    cleaned_abstract = clean_text(abstract)

    if len(cleaned_abstract) > limit:
        cleaned_abstract = cleaned_abstract[: limit - 3] + "..."

    sections = (
        f"Title: {cleaned_title}",
        f"Abstract: {cleaned_abstract}",
        "---",
    )
    return "\n\n".join(sections)

def fetch_arxiv_papers(query, max_results=5):
    """Fetch autism-related papers from arXiv.

    Args:
        query: Free-text topic; searched as a phrase in title/abstract and
            combined with an autism filter.
        max_results: Maximum number of results requested from arXiv.

    Returns:
        List of dicts with keys title, abstract, url, published,
        relevance_score (1 if the title mentions autism/ASD, else 0.5).
        On a request error, a message is shown in the UI and the papers
        collected so far are returned (mirrors fetch_pubmed_papers).
    """
    # A double quote in the user's text would close the phrase search early
    # and corrupt the query syntax, so replace it with a space.
    safe_query = query.replace('"', ' ').strip()

    # Always include autism in the search query.
    # NOTE(review): confirm that archive-level `cat:q-bio` matches
    # subcategories such as q-bio.NC in the arXiv API.
    search_query = f"(ti:autism OR abs:autism) AND (ti:\"{safe_query}\" OR abs:\"{safe_query}\") AND cat:q-bio"

    papers = []
    try:
        client = arxiv.Client()
        search = arxiv.Search(
            query=search_query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.Relevance
        )

        for result in client.results(search):
            title_lower = result.title.lower()
            summary_lower = result.summary.lower()
            in_title = 'autism' in title_lower or 'asd' in title_lower
            in_summary = 'autism' in summary_lower or 'asd' in summary_lower

            # Only include papers that mention autism in title or abstract.
            if not (in_title or in_summary):
                continue

            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published.strftime("%Y-%m-%d"),
                # Same scoring rule as fetch_pubmed_papers.
                'relevance_score': 1 if in_title else 0.5
            })
    except Exception as e:
        st.error(f"Error fetching arXiv papers: {str(e)}")

    return papers

def _xml_text(element):
    """Return the full text content of *element*, including nested markup.

    PubMed titles and abstracts may contain inline tags (e.g. <i>, <sup>).
    In that case ``element.text`` holds only the text before the first child
    and can be None. Returns "" when *element* itself is None.
    """
    if element is None:
        return ""
    return "".join(element.itertext()).strip()


def fetch_pubmed_papers(query, max_results=5, timeout=15):
    """Fetch autism-related papers from PubMed via NCBI E-utilities.

    Args:
        query: Free-text topic; combined with an autism/ASD filter and
            restricted to the Title/Abstract field.
        max_results: Maximum number of PubMed IDs requested from ESearch.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        List of dicts with keys title, abstract, url, published,
        relevance_score (1 if the title mentions autism/ASD, else 0.5).
        On any request or parse error, a message is shown in the UI and the
        papers collected so far are returned.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

    # Always include autism in the search term
    search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }

    papers = []
    try:
        # Step 1: ESearch returns the IDs of matching records.
        response = requests.get(f"{base_url}/esearch.fcgi", params=search_params, timeout=timeout)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id') if id_elem.text]

        if not id_list:
            return papers

        # Step 2: EFetch returns the full records for those IDs.
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        response = requests.get(f"{base_url}/efetch.fcgi", params=fetch_params, timeout=timeout)
        response.raise_for_status()
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            title_text = _xml_text(article.find('.//ArticleTitle'))
            # Structured abstracts are split into several AbstractText
            # sections (background, methods, ...); keep all of them.
            sections = [_xml_text(part) for part in article.findall('.//Abstract/AbstractText')]
            abstract_text = " ".join(section for section in sections if section)

            if not title_text or not abstract_text:
                continue

            title_lower = title_text.lower()
            abstract_lower = abstract_text.lower()
            in_title = 'autism' in title_lower or 'asd' in title_lower
            in_abstract = 'autism' in abstract_lower or 'asd' in abstract_lower

            # Only include papers that mention autism
            if not (in_title or in_abstract):
                continue

            pmid = _xml_text(article.find('.//PMID'))
            year = _xml_text(article.find('.//PubDate/Year'))
            if not year:
                # Some records only carry a free-form date like "2019 Jan-Feb".
                year = _xml_text(article.find('.//PubDate/MedlineDate'))[:4]

            papers.append({
                'title': title_text,
                'abstract': abstract_text,
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else "https://pubmed.ncbi.nlm.nih.gov/",
                'published': year or 'Unknown',
                'relevance_score': 1 if in_title else 0.5
            })

    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")

    return papers

def search_research_papers(query):
    """Search arXiv and PubMed and return autism-related papers as a DataFrame.

    Args:
        query: Free-text topic passed to both fetchers.

    Returns:
        DataFrame with columns title, text, url, published, relevance_score,
        sorted by relevance_score (highest first) with duplicate titles
        removed. When nothing matches, a warning is shown and an empty
        DataFrame with those columns is returned.
    """
    columns = ['title', 'text', 'url', 'published', 'relevance_score']
    candidates = fetch_arxiv_papers(query) + fetch_pubmed_papers(query)

    all_papers = []
    for paper in candidates:
        raw_abstract = paper['abstract']
        if not raw_abstract or not raw_abstract.strip():
            continue

        # Clean the paper content
        clean_title = clean_text(paper['title'])
        clean_abstract = clean_text(raw_abstract)
        title_lower = clean_title.lower()
        abstract_lower = clean_abstract.lower()

        # Check if the paper is actually about autism
        if not any(term in text
                   for text in (title_lower, abstract_lower)
                   for term in ('autism', 'asd')):
            continue

        all_papers.append({
            'title': clean_title,
            'text': format_paper(clean_title, clean_abstract),
            'url': paper['url'],
            'published': paper['published'],
            'relevance_score': paper.get('relevance_score', 0.5)
        })

    # Highest relevance first; the sort is stable, so ties keep source order.
    all_papers.sort(key=lambda p: p['relevance_score'], reverse=True)

    # The same work can come back from both sources. Keep only its first
    # (highest-ranked) occurrence, keyed on the case-insensitive title, so
    # duplicates do not take up the limited context slots.
    seen_titles = set()
    unique_papers = []
    for paper in all_papers:
        key = paper['title'].lower()
        if key and key in seen_titles:
            continue
        seen_titles.add(key)
        unique_papers.append(paper)

    if not unique_papers:
        st.warning("No autism-related papers found. Please try a different search term.")
        return pd.DataFrame(columns=columns)

    return pd.DataFrame(unique_papers, columns=columns)

def generate_answer(question, context, max_length=512):
    """Generate an answer to *question* grounded in *context* with the local model.

    Args:
        question: The user's question.
        context: Concatenated formatted papers to cite.
        max_length: Maximum number of generated tokens.

    Returns:
        The generated answer with one sentence per line. If the output is
        shorter than 100 characters, a canned overview is returned instead.
        If the model cannot be loaded or generation fails, a string starting
        with "Error:" is returned.
    """
    model, tokenizer = load_local_model()

    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."

    # Clean and format the context
    clean_context = clean_text(context)
    clean_question = clean_text(question)

    # Format the input for T5 (it expects a specific format)
    input_text = f"""Generate a comprehensive answer about autism, using the research papers as references to support your explanations.

Question: {clean_question}

Research Papers:
{clean_context}

Instructions: 
1. Provide a general explanation about the topic
2. Use the research papers as references, citing them in the format "According to [PAPER TITLE], ..."
3. Integrate research findings naturally into the explanation
4. Keep the focus on being informative and helpful

Answer in a clear, informative way, using the papers as references."""

    # The minimum length must not exceed the maximum; otherwise the two
    # generation bounds contradict each other for small max_length values.
    min_length = min(200, max_length)

    try:
        inputs = tokenizer(input_text,
                           return_tensors="pt",
                           max_length=1024,
                           truncation=True,
                           padding=True)

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                min_length=min_length,
                num_beams=5,
                length_penalty=1.5,
                temperature=0.7,
                repetition_penalty=1.2,
                early_stopping=True,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )

        response = clean_text(tokenizer.decode(outputs[0], skip_special_tokens=True))

        # If response is too short or empty, provide a general overview
        if len(response.strip()) < 100:
            return """Autism Spectrum Disorder (ASD) is a complex neurodevelopmental condition. Unfortunately, the provided papers don't contain specific information about this aspect of autism.

To get research-based information, try asking more specific questions about:
- Genetics and environmental factors
- Early intervention
- Treatments and therapies
- Neurological development

This will allow us to provide accurate information supported by recent scientific research."""

        # Put each sentence on its own line. clean_text() has already removed
        # bullet characters, so sentence breaks are the only formatting left.
        return response.replace(". ", ".\n")

    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "Error: Could not generate response. Please try again with a different question."

# Streamlit App
st.title("🧩 AMA Autism")

st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")

query = st.text_input("Please ask me anything about autism ✨")

if query:
    answer = None
    with st.status("Searching for answers...") as status:
        # Announce the search before running it so the status log is in order.
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")

        if df.empty:
            # Without papers the model would answer with no grounding at all;
            # search_research_papers has already shown a warning.
            status.update(label="No relevant papers found", state="error")
        else:
            # Use the top three papers (already sorted by relevance) as context.
            context = "\n".join(text[:1000] for text in df['text'].head(3))

            st.write("Generating answer...")
            answer = generate_answer(query, context)
            status.update(label="Search complete", state="complete")

    if answer is not None:
        # Display paper sources
        with st.expander("View source papers"):
            for _, paper in df.iterrows():
                st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
        # generate_answer signals failure with an "Error:"-prefixed message.
        if not answer.startswith("Error:"):
            st.success("Answer found!")
        st.markdown(answer)