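"""Streamlit app that answers autism-related questions with a small RAG pipeline.

Fetches arXiv papers into a local FAISS-indexed dataset, retrieves relevant
passages, and generates answers with google/flan-t5-small.
"""
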
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk
import torch
import logging
import pandas as pd

# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Cache models and dataset
@st.cache_resource
def load_models():
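    """Load and cache the flan-t5-small tokenizer and model across Streamlit reruns."""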
    model_name = "google/flan-t5-small"  # Lighter model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        # float16 is unreliable on CPU; only use it when a GPU is available
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
        device_map='auto',
        max_memory={'cpu': '1GB'}
    )
    return tokenizer, model

@st.cache_data(ttl=3600)  # Cache for 1 hour
def load_dataset(query):
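    """Load the local arXiv dataset, building it from the query if it does not exist yet."""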
    # Create initial dataset if it doesn't exist
    if not os.path.exists(DATASET_PATH):
        with st.spinner("Building initial dataset from autism research papers..."):
            import faiss_index.index as idx
            papers = idx.fetch_arxiv_papers(f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", max_results=25)  # Reduced max results
            idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
    
    # Load and convert to pandas for easier handling
    dataset = load_from_disk(DATASET_PATH)
    df = pd.DataFrame({
        'title': dataset['title'],
        'text': dataset['text']
    })
    return df

def generate_answer(question, context, max_length=150):  # Reduced max length
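    """Generate an answer to `question` grounded in the retrieved `context`."""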
    tokenizer, model = load_models()
    
    # Add context about medical information
    prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
    
    # Optimize input processing
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}  # keep tensors on the model's device
    
    with torch.inference_mode():  # More efficient than no_grad
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=2,  # Reduced beam search
            do_sample=True,  # required for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            early_stopping=True
        )
    
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Clear GPU memory if possible
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."

# Streamlit App
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers..."):
        # Load dataset
        df = load_dataset(query)
        
        # Get relevant context
        context = "\n".join(
            text[:1000] for text in df['text'].head(3)
        )
        
        # Generate answer
        answer = generate_answer(query, context)
        
        if answer and not answer.isspace():
            st.success("Answer found!")
            st.write(answer)
            
            st.write("### Sources Used:")
            for _, row in df.head(3).iterrows():
                st.write(f"**Title:** {row['title']}")
                st.write(f"**Summary:** {row['text'][:200]}...")
                st.write("---")
        else:
            st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")