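"""AMA Autism: a Streamlit app that retrieves autism research papers from
arXiv and PubMed, indexes them with FAISS, and answers questions with
google/flan-t5-small."""
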
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk
import torch
import logging
import pandas as pd

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
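# "/data" is typically a persistent volume mount (e.g. on Hugging Face Spaces);
# fall back to the working directory when it does not exist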
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Cache the model and tokenizer across Streamlit reruns
@st.cache_resource
def load_models():
    model_name = "google/flan-t5-small"  # Lightweight model that fits in modest CPU memory
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # float16 halves memory but is only reliably supported on GPU; use float32 on CPU
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map='auto',
        max_memory={'cpu': '1GB'}
    )
    return tokenizer, model

@st.cache_data(ttl=3600)  # Cache results per query string for up to an hour
def load_dataset(query):
    # Fetch results for this specific query (re-fetched once the cache expires)
    with st.spinner("Searching research papers from arXiv and PubMed..."):
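        # faiss_index is a local helper module in this repo; it exposes
        # fetch_papers() and build_faiss_index() (names taken from their use below)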
        import faiss_index.index as idx
        # Scope the search to autism research if the user didn't mention it
        if 'autism' not in query.lower():
            search_query = f"autism {query}"
        else:
            search_query = query

        papers = idx.fetch_papers(search_query, max_results=25)  # Queries both arXiv and PubMed
    
    if not papers:
        st.warning("No relevant papers found. Please try rephrasing your question.")
        return pd.DataFrame(columns=['title', 'text', 'url', 'published'])
            
    idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
    
    # Load the saved dataset and convert it to pandas for easier handling.
    # Note: this assumes build_faiss_index preserves the order of `papers`,
    # so the url/published columns line up with the dataset rows.
    dataset = load_from_disk(DATASET_PATH)
    df = pd.DataFrame({
        'title': dataset['title'],
        'text': dataset['text'],
        'url': [p['url'] for p in papers],
        'published': [p['published'] for p in papers]
    })
    return df

def generate_answer(question, context, max_length=150):
    tokenizer, model = load_models()

    # Prompt the model for a concise, grounded summary. Built without
    # indentation so no stray whitespace leaks into the prompt.
    prompt = (
        "Based on scientific research about autism, provide a brief, clear summary "
        "answering the following question. Focus only on the most important findings "
        "and be concise. If the context doesn't contain relevant information about "
        "autism, respond with 'I cannot find specific information about this topic "
        "in the autism research papers.'\n\n"
        f"Question: {question}\n"
        f"Context: {context}\n\n"
        "Provide a concise summary:"
    )
    
    # Truncate the prompt to the model's 512-token input window
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=2,
            do_sample=True,  # temperature/top_p only take effect when sampling is enabled
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            early_stopping=True
        )
    
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Free cached GPU memory when a GPU was used
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Additional validation of the answer
    if not answer or answer.isspace() or "cannot find" in answer.lower():
        return "I cannot find specific information about this topic in the autism research papers."
    
    return answer

# Streamlit App
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers...") as status:
        st.write("Searching for data in PubMed and arXiv...")
        df = load_dataset(query)
        if df.empty:
            status.update(label="No papers found.", state="error", expanded=False)
        else:
            st.write("Data found!")
            # Use the first 1,000 characters of the top three papers as context
            context = "\n".join(text[:1000] for text in df['text'].head(3))
            st.write("Generating answer...")
            answer = generate_answer(query, context)
            status.update(label="Search complete!", state="complete", expanded=False)

    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    elif answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)

        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")