File size: 10,120 Bytes
f1586e3
67b3686
f91cc3b
8903db2
0f8445a
5a09d5c
8903db2
4660a83
 
 
5a09d5c
 
8108db5
f1586e3
4660a83
0452175
 
 
a47c92e
0452175
d3e32db
4660a83
 
a47c92e
 
 
 
 
 
 
 
 
 
 
f68ac31
4660a83
 
 
62b3157
7842508
 
4660a83
 
 
 
 
 
 
 
 
 
7842508
 
 
 
 
 
 
 
 
 
 
 
8903db2
4660a83
f944585
4660a83
 
 
 
7842508
 
97889da
4660a83
 
 
 
97889da
4660a83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97889da
 
4660a83
97889da
7842508
 
 
 
 
 
 
 
 
 
 
 
 
4660a83
 
 
d3e32db
4660a83
 
 
 
 
 
cb9a068
4660a83
 
 
97889da
7842508
 
 
 
 
 
 
 
 
 
 
 
cc0b0d6
7842508
 
 
 
 
 
 
 
 
4660a83
 
 
 
62b3157
a47c92e
 
 
4660a83
218a8a7
4660a83
97889da
 
218a8a7
 
 
 
 
 
 
 
4660a83
218a8a7
f99a008
a47c92e
 
 
97889da
a47c92e
 
 
 
218a8a7
a47c92e
 
 
 
 
 
 
 
 
218a8a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a47c92e
 
 
 
 
cc0b0d6
a47c92e
 
 
f1586e3
f68ac31
 
67b3686
 
 
 
 
 
218a8a7
f1586e3
 
b45a04f
4660a83
 
67b3686
b45a04f
67b3686
 
62b3157
 
 
 
4660a83
62b3157
b45a04f
67b3686
97889da
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
import arxiv
import requests
import xml.etree.ElementTree as ET

# Emit log records at INFO level and above through the root logger.
logging.basicConfig(level=logging.INFO)

# Storage locations: use /data when that path exists, otherwise fall back
# to the current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Hugging Face checkpoint id; T5-small was chosen for CPU compatibility.
MODEL_PATH = "t5-small"

@st.cache_resource
def _load_model_and_tokenizer():
    """Load and cache the seq2seq model and tokenizer for ``MODEL_PATH``.

    Raises on failure instead of returning a sentinel, so that
    ``st.cache_resource`` (which only stores successful return values)
    does not memoize a failed load.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    # No ``device_map``: weights load on CPU by default, and passing a
    # device_map would require the optional ``accelerate`` package.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,
    )
    return model, tokenizer


def load_local_model():
    """Load the local Hugging Face model.

    Returns:
        ``(model, tokenizer)`` on success. On failure an error is shown in
        the Streamlit UI and ``(None, None)`` is returned; the failure is
        not cached, so the next call retries the load.
    """
    try:
        return _load_model_and_tokenizer()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None

def fetch_arxiv_papers(query, max_results=5):
    """Fetch autism-related papers from arXiv.

    Args:
        query: Free-text topic. It is combined with an autism filter on
            title/abstract and restricted to ``cat:q-bio``.
        max_results: Maximum number of results requested from arXiv.

    Returns:
        A list of dicts with ``title``, ``abstract``, ``url`` (PDF link),
        ``published`` (``YYYY-MM-DD``) and ``relevance_score`` keys. If the
        arXiv request fails, an error is shown in the UI and the papers
        collected so far are returned.
    """
    client = arxiv.Client()

    # A double quote inside the user's text would close the quoted phrase
    # early and corrupt the query syntax, so drop embedded quotes.
    safe_query = query.replace('"', " ").strip()

    # Always include autism in the search query
    search_query = (
        f'(ti:autism OR abs:autism) AND (ti:"{safe_query}" OR abs:"{safe_query}") '
        "AND cat:q-bio"
    )

    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )

    papers = []
    try:
        for result in client.results(search):
            title_text = result.title.lower()
            abstract_text = result.summary.lower()
            title_mentions = 'autism' in title_text or 'asd' in title_text

            # Only include papers that mention autism in title or abstract
            if not (title_mentions or 'autism' in abstract_text or 'asd' in abstract_text):
                continue

            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published.strftime("%Y-%m-%d"),
                # Same rule as fetch_pubmed_papers: full score when the title
                # itself mentions autism or ASD.
                'relevance_score': 1 if title_mentions else 0.5,
            })
    except Exception as e:
        # Mirror fetch_pubmed_papers: report the failure instead of
        # crashing the app, and keep whatever was already fetched.
        st.error(f"Error fetching arXiv papers: {str(e)}")

    return papers

def fetch_pubmed_papers(query, max_results=5):
    """Fetch autism-related papers from PubMed via the NCBI E-utilities API.

    Runs an ESearch for PubMed IDs that match ``query`` together with an
    autism/ASD Title/Abstract filter, then an EFetch for the records.

    Args:
        query: Free-text topic, matched against Title/Abstract.
        max_results: Maximum number of PubMed IDs requested.

    Returns:
        A list of dicts with ``title``, ``abstract``, ``url``, ``published``
        and ``relevance_score`` keys. On a network, HTTP or XML error, an
        error is shown in the UI and the papers collected so far are returned.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Seconds per HTTP request; without a timeout a stalled call hangs the app.
    timeout = 15

    def _full_text(elem):
        # ``elem.text`` stops at the first child tag (e.g. <i>, <sup>) and is
        # None when the element starts with one; itertext() collects all text.
        if elem is None:
            return ""
        return "".join(elem.itertext()).strip()

    # Always include autism in the search term
    search_term = f"(autism[Title/Abstract] OR ASD[Title/Abstract]) AND ({query}[Title/Abstract])"

    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }

    papers = []
    try:
        # Get paper IDs
        response = requests.get(search_url, params=search_params, timeout=timeout)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id') if id_elem.text]

        if not id_list:
            return papers

        # Fetch paper details
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }

        response = requests.get(fetch_url, params=fetch_params, timeout=timeout)
        response.raise_for_status()
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            title = _full_text(article.find('.//ArticleTitle'))
            # Structured abstracts are split over several AbstractText
            # sections; join all of them instead of keeping only the first.
            sections = (_full_text(e) for e in article.findall('.//Abstract/AbstractText'))
            abstract = " ".join(s for s in sections if s)
            pmid = _full_text(article.find('.//PMID'))

            # Skip records we cannot display or link to, rather than letting
            # one malformed record abort the whole loop.
            if not (title and abstract and pmid):
                continue

            title_text = title.lower()
            abstract_text = abstract.lower()
            title_mentions = 'autism' in title_text or 'asd' in title_text

            # Only include papers that mention autism
            if not (title_mentions or 'autism' in abstract_text or 'asd' in abstract_text):
                continue

            # Some records carry a free-form MedlineDate instead of a Year.
            published = (
                _full_text(article.find('.//PubDate/Year'))
                or _full_text(article.find('.//PubDate/MedlineDate'))
                or 'Unknown'
            )

            papers.append({
                'title': title,
                'abstract': abstract,
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                'published': published,
                'relevance_score': 1 if title_mentions else 0.5
            })

    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")

    return papers

def search_research_papers(query):
    """Query arXiv and PubMed and collect autism-related papers in a DataFrame.

    Papers whose abstract is missing or blank are dropped, as are papers
    whose title and abstract never contain "autism" or "asd". Surviving rows
    are ordered by ``relevance_score``, highest first; ties keep their fetch
    order, with arXiv results before PubMed results. When nothing survives,
    a warning is shown and an empty DataFrame with the expected columns is
    returned.
    """
    expected_columns = ['title', 'text', 'url', 'published', 'relevance_score']

    def _mentions_autism(record):
        # Checks title before abstract, "autism" before "asd" in each.
        fields = (record['title'].lower(), record['abstract'].lower())
        return any(term in field for field in fields for term in ('autism', 'asd'))

    candidates = fetch_arxiv_papers(query) + fetch_pubmed_papers(query)

    rows = [
        {
            'title': record['title'],
            'text': f"Title: {record['title']}\n\nAbstract: {record['abstract']}",
            'url': record['url'],
            'published': record['published'],
            'relevance_score': record.get('relevance_score', 0.5),
        }
        for record in candidates
        if record['abstract'] and record['abstract'].strip() and _mentions_autism(record)
    ]

    ranked = sorted(rows, key=lambda row: row['relevance_score'], reverse=True)
    results = pd.DataFrame(ranked)

    if not results.empty:
        return results

    st.warning("No autism-related papers found. Please try a different search term.")
    return pd.DataFrame(columns=expected_columns)

def generate_answer(question, context, max_length=512):
    """Generate a comprehensive answer using the local model.

    Args:
        question: The user's question, inserted into the prompt.
        context: Research-paper excerpts inserted into the prompt.
        max_length: Maximum length, in tokens, of the generated answer.

    Returns:
        The decoded model output, with a line break after each ". " and
        before each bullet. If that output is under 100 characters, a canned
        general overview is returned instead. If the model cannot be loaded
        or generation raises, an "Error: ..." message string is returned.
    """
    model, tokenizer = load_local_model()

    if model is None or tokenizer is None:
        return "Error: Could not load the model. Please try again later."

    # Format the context as a structured query
    prompt = f"""You are an expert in autism research. Provide a comprehensive answer about autism, incorporating both general knowledge and specific research findings when available.

Question: {question}

Recent Research Context:
{context}

Instructions: Provide a detailed response that:
1. Starts with a general overview of the topic as it relates to autism
2. Incorporates specific findings from the provided research papers when relevant
3. Discusses practical implications for individuals with autism and their families
4. Mentions any limitations in current understanding

If the research papers don't directly address the question, focus on providing general, well-established information about autism while noting what specific research would be helpful."""

    try:
        # The prompt is truncated to at most 1024 tokens.
        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                # Longer minimum for more comprehensive answers, but never
                # above max_length (a contradictory length constraint).
                min_length=min(150, max_length),
                num_beams=4,
                length_penalty=1.5,
                # No ``temperature``: without do_sample=True beam search
                # ignores it, and transformers warns about the unused flag.
                repetition_penalty=1.2,
                early_stopping=True,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If response is too short or empty, provide a general overview
        if len(response.strip()) < 100:
            return f"""Here's what we know about autism in relation to your question about {question}:

1. General Understanding:
- Autism Spectrum Disorder (ASD) is a complex developmental condition
- It affects how a person communicates, learns, and interacts with others
- Each person with autism has unique strengths and challenges

2. Key Aspects:
- Communication and social interaction
- Repetitive behaviors and specific interests
- Sensory sensitivities
- Early intervention is important

3. Current Research:
While the provided research papers don't directly address your specific question, researchers are actively studying various aspects of autism to better understand its causes, characteristics, and effective interventions.

For more specific information, try asking about:
- Specific symptoms or characteristics
- Diagnostic processes
- Treatment approaches
- Current research in specific areas"""

        # Format the response for better readability
        formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")

        return formatted_response

    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return "Error: Could not generate response. Please try again with a different question."

# Streamlit App
st.title("🧩 AMA Autism")

st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")

query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers...") as status:
        # Announce the search before running it, not after it has finished.
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")

        # Use the first 1000 characters of the top three papers as context.
        context = "\n".join(text[:1000] for text in df['text'].head(3))

        # Generate answer
        st.write("Generating answer...")
        answer = generate_answer(query, context)

        # generate_answer signals failure with an "Error: ..." message string.
        answer_failed = answer.startswith("Error:")
        status.update(
            label="Search complete",
            state="error" if answer_failed else "complete",
            expanded=False,
        )

    # Display paper sources (only when there is something to show)
    if not df.empty:
        with st.expander("View source papers"):
            for _, paper in df.iterrows():
                st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")

    if not answer_failed:
        st.success("Answer found!")
    st.markdown(answer)