import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
import torch
import logging
import pandas as pd
import arxiv
import requests
import xml.etree.ElementTree as ET

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths and constants (the dataset paths are reserved for a
# persisted RAG dataset and are not used in the current flow)
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
MODEL_PATH = "facebook/bart-large-cnn"  # BART model fine-tuned for summarization

@st.cache_resource
def load_local_model():
    """Load the local Hugging Face model"""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float32,
        device_map="auto"
    )
    return model, tokenizer
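
# Note: device_map="auto" relies on the `accelerate` package and places the
# model on a GPU when one is available, otherwise on the CPU. A quick sanity
# check (hypothetical interactive use, not part of the app flow):
#
#   model, tokenizer = load_local_model()
#   print(model.device)  # e.g. device(type='cpu') or device(type='cuda', index=0)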

def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers from arXiv"""
    client = arxiv.Client()
    
    # Ensure query includes autism-related terms
    if 'autism' not in query.lower():
        search_query = f"(ti:{query} OR abs:{query}) AND (ti:autism OR abs:autism) AND cat:q-bio"
    else:
        search_query = f"(ti:{query} OR abs:{query}) AND cat:q-bio"
    
    # Search arXiv
    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )
    
    papers = []
    for result in client.results(search):
        papers.append({
            'title': result.title,
            'abstract': result.summary,
            'url': result.pdf_url,
            'published': result.published.strftime("%Y-%m-%d")
        })
    
    return papers
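
# Illustrative call and return shape (hypothetical query and values):
#
#   fetch_arxiv_papers("sensory processing", max_results=2)
#   # -> [{'title': '...', 'abstract': '...',
#   #      'url': 'https://arxiv.org/pdf/...', 'published': '2021-03-15'}, ...]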

def fetch_pubmed_papers(query, max_results=5):
    """Fetch papers from PubMed"""
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    
    # Ensure query includes autism-related terms
    if 'autism' not in query.lower():
        search_term = f"({query}) AND (autism[Title/Abstract] OR ASD[Title/Abstract])"
    else:
        search_term = query
    
    # Search for papers
    search_url = f"{base_url}/esearch.fcgi"
    search_params = {
        'db': 'pubmed',
        'term': search_term,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml'
    }
    
    papers = []
    try:
        # Get the matching paper IDs
        response = requests.get(search_url, params=search_params, timeout=30)
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
        
        if not id_list:
            return papers
        
        # Fetch the full records for those IDs
        fetch_url = f"{base_url}/efetch.fcgi"
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml'
        }
        
        response = requests.get(fetch_url, params=fetch_params, timeout=30)
        response.raise_for_status()
        articles = ET.fromstring(response.content)
        
        for article in articles.findall('.//PubmedArticle'):
            title = article.find('.//ArticleTitle')
            # Structured abstracts are split across several AbstractText
            # elements; itertext() also flattens any inline markup
            abstract_parts = article.findall('.//Abstract/AbstractText')
            year = article.find('.//PubDate/Year')
            pmid = article.find('.//PMID')
            
            if title is not None and abstract_parts and pmid is not None:
                abstract = ' '.join(
                    ''.join(part.itertext()).strip() for part in abstract_parts
                )
                papers.append({
                    'title': ''.join(title.itertext()),
                    'abstract': abstract,
                    'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.text}/",
                    'published': year.text if year is not None else 'Unknown'
                })
            
    except Exception as e:
        st.error(f"Error fetching PubMed papers: {str(e)}")
    
    return papers
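
# The two requests above follow NCBI's standard E-utilities pattern:
# esearch.fcgi returns matching PMIDs, then efetch.fcgi returns the full
# records. Illustrative call (hypothetical query and values):
#
#   fetch_pubmed_papers("early intervention", max_results=2)
#   # -> [{'title': '...', 'abstract': '...',
#   #      'url': 'https://pubmed.ncbi.nlm.nih.gov/12345678/',
#   #      'published': '2019'}, ...]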

def search_research_papers(query):
    """Search both arXiv and PubMed for papers"""
    arxiv_papers = fetch_arxiv_papers(query)
    pubmed_papers = fetch_pubmed_papers(query)
    
    # Combine and format papers
    all_papers = []
    for paper in arxiv_papers + pubmed_papers:
        if paper['abstract'] and paper['abstract'].strip():
            all_papers.append({
                'title': paper['title'],
                'text': f"Title: {paper['title']}\n\nAbstract: {paper['abstract']}",
                'url': paper['url'],
                'published': paper['published']
            })
    
    # Specify the columns explicitly so df['text'] exists even when the
    # search returned no papers
    return pd.DataFrame(all_papers, columns=['title', 'text', 'url', 'published'])
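
# The returned DataFrame has one row per paper; the 'text' column
# concatenates title and abstract and is what gets passed to the model
# as context.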

def generate_answer(question, context, max_length=512):
    """Generate a comprehensive answer using the local model"""
    model, tokenizer = load_local_model()
    
    # Format the context as a structured query
    prompt = f"""Summarize the following research about autism and answer the question.

Research Context:
{context}

Question: {question}

Provide a detailed answer that includes:
1. Main findings from the research
2. Research methods used
3. Clinical implications
4. Limitations of the studies

If the research doesn't address the question directly, explain what information is missing."""
    
    # Generate response
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    # Move the inputs to the same device the model was loaded on
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            min_length=200,  # Ensure longer responses
            num_beams=5,
            length_penalty=2.0,  # Encourage even longer responses
            no_repeat_ngram_size=3,
            repetition_penalty=1.3,
            early_stopping=True
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # If the response is too short or empty, provide a fallback message
    # (built without indentation, which st.markdown would render as code)
    if len(response.strip()) < 100:
        return (
            "I apologize, but I couldn't generate a specific answer from the "
            "research papers provided. This might be because:\n"
            "1. The research papers don't directly address your question\n"
            "2. The context needs more specific information\n"
            "3. The question might need to be more specific\n\n"
            "Please try rephrasing your question or ask about a more specific "
            "aspect of autism."
        )
    
    # Format the response for better readability
    formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
    
    return formatted_response
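
# Illustrative call (hypothetical context; the tokenizer truncates input at
# 1024 tokens, so only the start of a long context is actually used):
#
#   generate_answer("What therapies are effective?", context="Title: ...")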

# Streamlit App
st.title("🧩 AMA Autism")

st.write("""
This app searches through scientific papers to answer your questions about autism.
For best results, be specific in your questions.
""")

query = st.text_input("Please ask me anything about autism >")

if query:
    with st.status("Searching for answers...") as status:
        # Search for papers (announce each step before it runs)
        st.write("Searching PubMed and arXiv...")
        df = search_research_papers(query)
        st.write(f"Found {len(df)} relevant papers!")
        
        if df.empty:
            status.update(label="No papers found", state="error")
            st.warning("No matching papers were found. Try rephrasing your question.")
            st.stop()
        
        # Build the context from the top three papers
        context = "\n".join(text[:1000] for text in df['text'].head(3))
        
        # Generate answer
        st.write("Generating answer...")
        answer = generate_answer(query, context)
        status.update(label="Answer found!", state="complete")
    
    st.success("Answer found!")
    st.markdown(answer)
    
    # Display paper sources
    with st.expander("View source papers"):
        for _, paper in df.iterrows():
            st.markdown(f"- [{paper['title']}]({paper['url']}) ({paper['published']})")
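
# To run the app locally (assuming this file is saved as app.py):
#
#   pip install streamlit transformers accelerate torch pandas arxiv requests
#   streamlit run app.py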