File size: 6,637 Bytes
f1586e3
4660a83
f91cc3b
8903db2
0f8445a
5a09d5c
8903db2
4660a83
 
 
 
 
5a09d5c
 
8108db5
f1586e3
4660a83
0452175
 
 
4660a83
0452175
d3e32db
4660a83
 
 
5f3cb01
4660a83
 
 
5f3cb01
4660a83
f68ac31
4660a83
 
 
62b3157
4660a83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8903db2
4660a83
f944585
4660a83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3e32db
4660a83
 
 
 
 
 
cb9a068
4660a83
 
 
 
 
 
 
 
 
cc0b0d6
4660a83
 
 
 
 
62b3157
4660a83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f99a008
4660a83
 
d3e32db
cc0b0d6
8903db2
42d1dd5
8903db2
4660a83
 
42d1dd5
4660a83
8903db2
 
42d1dd5
4660a83
8903db2
4660a83
 
cc0b0d6
4660a83
f1586e3
f68ac31
 
cc41495
f68ac31
f1586e3
 
b45a04f
4660a83
 
b45a04f
4660a83
 
62b3157
 
 
 
4660a83
62b3157
 
b45a04f
 
62b3157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
import arxiv
import requests
import xml.etree.ElementTree as ET
from agno.embedder.huggingface import HuggingfaceCustomEmbedder
from agno.vectordb.lancedb import LanceDb, SearchType

# Root logging configuration: emit INFO and above.
logging.basicConfig(level=logging.INFO)

# Base data directory: use /data when that path exists, otherwise the
# current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."

# NOTE(review): DATASET_DIR / DATASET_PATH are not referenced in the visible
# part of this file — confirm whether they are still needed.
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Hugging Face model id loaded by load_local_model (a lighter model).
MODEL_PATH = "google/flan-t5-base"

@st.cache_resource
def load_local_model():
    """Load and return the ``(model, tokenizer)`` pair for ``MODEL_PATH``.

    ``st.cache_resource`` has Streamlit cache the returned objects so the
    model is not reloaded on every script rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    load_options = dict(
        torch_dtype=torch.float32,  # float32 for CPU compatibility
        device_map="auto",          # let transformers decide device placement
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, **load_options)

    return model, tokenizer

def fetch_arxiv_papers(query, max_results=5):
    """Fetch papers matching *query* from arXiv's ``q-bio`` category.

    Args:
        query: Free-text search string typed by the user.
        max_results: Maximum number of papers to request.

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``url`` (the PDF
        link) and ``published`` (``arxiv.Result.published`` as returned by the
        library). An empty or whitespace-only query returns ``[]``. If the
        request fails, the error is shown with ``st.error`` and the papers
        collected so far are returned, matching ``fetch_pubmed_papers``.
    """
    query = (query or "").strip()
    if not query:
        return []

    # Group the title/abstract alternatives explicitly so the category filter
    # applies to both, instead of relying on the API's operator precedence.
    # NOTE(review): for multi-word queries, presumably only the first word is
    # bound to the ti:/abs: prefix — confirm against the arXiv query syntax.
    search_query = f"(ti:{query} OR abs:{query}) AND cat:q-bio"

    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )

    papers = []
    try:
        for result in arxiv.Client().results(search):
            papers.append({
                'title': result.title,
                'abstract': result.summary,
                'url': result.pdf_url,
                'published': result.published,
            })
    except Exception as e:
        # Best effort, like the PubMed fetcher: report the failure in the UI
        # and keep whatever was retrieved instead of crashing the app.
        logging.exception("arXiv search failed")
        st.error(f"Error fetching arXiv papers: {str(e)}")

    return papers

def _element_text(element, default):
    """Return the text of *element* including nested markup, or *default*.

    ``Element.text`` stops at the first child element, so titles/abstracts
    that may contain inline tags (e.g. italics) would be cut short;
    ``itertext()`` walks all descendant text. Returns *default* when
    *element* is ``None`` or holds only whitespace.
    """
    if element is None:
        return default
    text = "".join(element.itertext()).strip()
    return text or default


def fetch_pubmed_papers(query, max_results=5, timeout=10):
    """Fetch papers matching *query* from PubMed via NCBI E-utilities.

    Performs an ``esearch`` request to collect PubMed IDs, then a single
    ``efetch`` request for the full records.

    Args:
        query: Free-text search term passed to PubMed.
        max_results: Maximum number of IDs requested from ``esearch``.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``url`` and
        ``published`` (the ``PubDate/Year`` text, else the raw
        ``PubDate/MedlineDate`` text, else ``'Unknown'``). Records without a
        PMID are skipped. On any failure the error is shown with ``st.error``
        and the papers parsed so far are returned.
    """
    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    search_params = {
        'db': 'pubmed',
        'term': query,
        'retmax': max_results,
        'sort': 'relevance',
        'retmode': 'xml',
    }

    papers = []
    try:
        # Step 1: resolve the query to a list of PubMed IDs.
        response = requests.get(
            f"{base_url}/esearch.fcgi", params=search_params, timeout=timeout
        )
        # Surface HTTP errors (e.g. rate limiting) directly instead of trying
        # to parse an error page as XML.
        response.raise_for_status()
        root = ET.fromstring(response.content)
        id_list = [id_elem.text for id_elem in root.findall('.//Id')]

        if not id_list:
            return papers

        # Step 2: fetch the records for all IDs in one request.
        fetch_params = {
            'db': 'pubmed',
            'id': ','.join(id_list),
            'retmode': 'xml',
        }
        response = requests.get(
            f"{base_url}/efetch.fcgi", params=fetch_params, timeout=timeout
        )
        response.raise_for_status()
        articles = ET.fromstring(response.content)

        for article in articles.findall('.//PubmedArticle'):
            pmid = article.findtext('.//PMID')
            if not pmid:
                # No ID means no citable URL; skip this record rather than
                # aborting the whole batch.
                continue

            # Structured abstracts are split into several (optionally
            # labelled) AbstractText sections; keep all of them.
            sections = []
            for part in article.findall('.//Abstract/AbstractText'):
                text = _element_text(part, '')
                if text:
                    label = part.get('Label')
                    sections.append(f"{label}: {text}" if label else text)

            published = (
                article.findtext('.//PubDate/Year')
                or article.findtext('.//PubDate/MedlineDate')
                or 'Unknown'
            )

            papers.append({
                'title': _element_text(
                    article.find('.//ArticleTitle'), 'No title available'
                ),
                'abstract': ' '.join(sections) or 'No abstract available',
                'url': f"https://pubmed.ncbi.nlm.nih.gov/{pmid.strip()}/",
                'published': published,
            })

    except Exception as e:
        # Best effort: report the problem in the UI and return what we have.
        logging.exception("PubMed search failed")
        st.error(f"Error fetching PubMed papers: {str(e)}")

    return papers

def search_research_papers(query):
    """Search arXiv and PubMed for *query* and return the hits as a DataFrame.

    Returns:
        A DataFrame with columns ``title``, ``text`` (title and abstract
        combined into one string), ``url`` and ``published``; arXiv results
        come first, then PubMed. The columns are present even when nothing
        was found, so callers can safely index e.g. ``df['text']``.
    """
    papers = fetch_arxiv_papers(query) + fetch_pubmed_papers(query)

    rows = [
        {
            'title': paper['title'],
            'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
            'url': paper['url'],
            'published': paper['published'],
        }
        for paper in papers
    ]

    # Explicit columns: pd.DataFrame([]) alone would have no columns at all,
    # making column access on an empty result raise KeyError.
    return pd.DataFrame(rows, columns=['title', 'text', 'url', 'published'])

def generate_answer(question, context, max_length=512):
    """Generate an answer to *question* based on *context* with the local model.

    Args:
        question: The user's question.
        context: Research text (paper titles/abstracts) to base the answer on.
        max_length: Token limit applied both to the encoded prompt and to the
            generated output.

    Returns:
        The decoded answer, with a newline inserted after each ". " and
        before each "• " for display.
    """
    model, tokenizer = load_local_model()

    prompt_head = (
        "Based on the following research papers about autism, provide a detailed answer:\n"
        "\n"
        f"Question: {question}\n"
        "\n"
        "Research Context:\n"
    )
    prompt_tail = (
        "\n"
        "\n"
        "Please analyze:\n"
        "1. Main findings\n"
        "2. Research methods\n"
        "3. Clinical implications\n"
        "4. Limitations\n"
        "\n"
        "Answer:"
    )

    # The encoded prompt is truncated to max_length tokens. Truncating the
    # assembled prompt from the right would drop the analysis instructions and
    # the trailing "Answer:" cue whenever the context is long, so trim only
    # the context to the token budget left over by the fixed parts.
    overhead = len(tokenizer(prompt_head + prompt_tail)["input_ids"])
    budget = max_length - overhead
    if budget > 0:
        context_ids = tokenizer(
            context,
            add_special_tokens=False,
            truncation=True,
            max_length=budget,
        )["input_ids"]
        context = tokenizer.decode(context_ids, skip_special_tokens=True)
    else:
        context = ""

    prompt = prompt_head + context + prompt_tail

    # Final truncation stays as a safety net (e.g. for a very long question).
    inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
    # device_map="auto" may place the model on a GPU; inputs must be on the
    # same device as the model for generate() to work.
    inputs = inputs.to(model.device)

    with torch.inference_mode():
        # Deterministic beam search. temperature/top_p were dropped: they only
        # take effect with do_sample=True and were otherwise ignored (with a
        # warning) by generate().
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            repetition_penalty=1.2,
            early_stopping=True,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Put each sentence and bullet on its own line for display.
    formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")

    return formatted_response

# Streamlit App
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    answer = ""
    with st.status("Searching for answers...") as status:
        # Announce each step before running it so the status panel reflects
        # what is currently happening.
        st.write("Searching for data in PubMed and arXiv...")
        df = search_research_papers(query)

        if df.empty:
            # Nothing to ground an answer in: skip generation instead of
            # letting the model answer from an empty context. Checking
            # df.empty first also avoids column access on a column-less frame.
            status.update(
                label="No research papers found", state="error", expanded=False
            )
        else:
            st.write("Data found!")
            # Use at most the first 1000 characters of each of the top 3 papers.
            context = "\n".join(text[:1000] for text in df['text'].head(3))

            st.write("Generating answer...")
            answer = generate_answer(query, context)
            status.update(
                label="Search complete!", state="complete", expanded=False
            )

    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    elif answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)

        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")