File size: 2,877 Bytes
f1586e3
d3e32db
f91cc3b
f68ac31
0f8445a
5a09d5c
 
 
8108db5
f1586e3
0452175
 
 
 
 
f68ac31
d3e32db
f68ac31
d3e32db
 
 
 
f68ac31
f944585
 
 
 
 
 
 
 
 
 
d3e32db
 
 
 
 
 
 
 
 
 
 
 
 
 
f68ac31
d3e32db
 
 
 
 
 
 
f1586e3
f68ac31
 
 
f1586e3
 
f68ac31
d3e32db
f68ac31
d3e32db
 
 
 
 
 
 
 
 
 
 
f68ac31
 
d3e32db
 
 
 
 
 
f68ac31
d3e32db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk
import torch
import logging

# Configure logging
# Root-logger setup at INFO so library and app messages surface in the console.
logging.basicConfig(level=logging.INFO)

# Define data paths
# Prefer /data when it exists (presumably a persistent volume mount, e.g. a
# hosted Space or container — TODO confirm); otherwise fall back to the CWD.
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")   # root of the RAG artifacts
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")   # saved HF dataset (load_from_disk)

# Cache models and dataset
@st.cache_resource
def load_models():
    """Load the T5 tokenizer/model pair once per session and cache it.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``t5-base`` checkpoint.
    """
    checkpoint = "t5-base"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )

@st.cache_data
def load_dataset():
    """Return the FAISS-indexed paper dataset, building it on first run.

    When no dataset exists at ``DATASET_PATH``, fetches arXiv papers and
    builds the index before loading from disk.
    """
    if os.path.exists(DATASET_PATH):
        return load_from_disk(DATASET_PATH)
    # First run: fetch source papers and build the index, then load it.
    with st.spinner("Building initial dataset from autism research papers..."):
        import faiss_index.index as idx
        papers = idx.fetch_arxiv_papers("autism research", max_results=100)
        idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
    return load_from_disk(DATASET_PATH)

def generate_answer(question, context, max_length=200):
    """Generate an answer to *question* grounded in *context* using cached T5.

    Args:
        question: The user's question string.
        context: Concatenated passage text used as grounding evidence.
        max_length: Maximum number of tokens in the generated answer.

    Returns:
        The decoded answer string, or a fixed fallback message when the
        model produces empty/whitespace-only output.
    """
    tokenizer, model = load_models()

    # T5 expects a single task-prefixed input sequence.
    inputs = tokenizer(
        f"question: {question} context: {context}",
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Bug fix: the previous code ran a bare forward pass (model(**inputs)) and
    # argmax'd the logits. A seq2seq forward pass requires decoder inputs or
    # labels (it raises without them), and a single argmax is not autoregressive
    # decoding. Use generate(), and honor the previously-unused max_length arg.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
        )

    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."

# Streamlit App
# Entry flow: take a free-text question, retrieve the top rows of the paper
# dataset as context, generate an answer, and show the sources used.
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers..."):
        # Load dataset
        dataset = load_dataset()

        # Bug fix: `dataset[:3]` on a Hugging Face Dataset returns a dict of
        # columns, so iterating it yields column *names* (strings) and
        # `paper['text']` fails. Select rows explicitly by index instead.
        top_k = min(3, len(dataset))
        context = "\n".join(
            dataset[i]["text"][:1000]  # Use more context for better answers
            for i in range(top_k)
        )

        # Generate answer
        answer = generate_answer(query, context)

        if answer and not answer.isspace():
            st.success("Answer found!")
            st.write(answer)

            st.write("### Sources Used:")
            for i in range(top_k):
                st.write(f"**Title:** {dataset[i]['title']}")
                st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
                st.write("---")
        else:
            st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")