File size: 2,436 Bytes
f1586e3
d3e32db
f91cc3b
f68ac31
0f8445a
5a09d5c
 
 
8108db5
f1586e3
0452175
 
 
 
 
f68ac31
d3e32db
f68ac31
d3e32db
 
 
 
f68ac31
d3e32db
 
 
 
 
 
 
 
 
 
 
 
 
 
f68ac31
d3e32db
 
 
 
 
 
 
f1586e3
f68ac31
 
 
f1586e3
 
f68ac31
d3e32db
f68ac31
d3e32db
 
 
 
 
 
 
 
 
 
 
f68ac31
 
d3e32db
 
 
 
 
 
f68ac31
d3e32db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk
import torch
import logging

# --- Module setup: logging + on-disk dataset locations ---------------------
logging.basicConfig(level=logging.INFO)

# Prefer a mounted /data volume (e.g. on Spaces); fall back to the CWD.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")   # root of the saved RAG dataset
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")   # actual datasets.save_to_disk dir

# Cache models and dataset
@st.cache_resource
def load_models():
    """Load and cache the T5 tokenizer/model pair used for answer generation.

    Returns:
        tuple: (tokenizer, model) for the "t5-base" checkpoint.
    """
    checkpoint = "t5-base"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return tok, seq2seq

def generate_answer(question, context, max_length=200):
    """Generate an answer to `question` grounded in `context` using T5.

    Args:
        question: The user's question (plain text).
        context: Supporting text the answer should be drawn from.
        max_length: Maximum length (in tokens) of the generated answer.

    Returns:
        str: The decoded answer, or a fixed fallback message when the model
        produces empty/whitespace-only output.
    """
    tokenizer, model = load_models()

    # Encode the question and context in the T5 text-to-text format.
    inputs = tokenizer(
        f"question: {question} context: {context}",
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True
    )

    # BUGFIX: T5 is an encoder-decoder model.  The previous code called
    # model(**inputs) with no decoder inputs (which raises at runtime) and
    # took an argmax over a single forward pass, which is not how seq2seq
    # decoding works.  Use generate() to decode autoregressively; this also
    # makes the `max_length` parameter actually take effect (it was unused).
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=max_length)

    # Convert generated token ids back to text.
    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."

# Streamlit App
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")


@st.cache_resource
def load_dataset():
    """Load the prebuilt RAG dataset saved under DATASET_PATH.

    NOTE(review): load_dataset() was called below but never defined, while
    load_from_disk was imported and unused — this helper was presumably
    dropped.  Assumes the dataset was written with datasets.save_to_disk
    to DATASET_PATH; confirm against the ingestion script.
    """
    return load_from_disk(DATASET_PATH)


if query:
    with st.status("Searching for answers..."):
        # Load dataset (cached across reruns).
        dataset = load_dataset()

        # BUGFIX: slicing a HF Dataset (dataset[:3]) returns a dict of
        # columns, so iterating it yields column-name strings, not rows.
        # Select explicit rows instead so each `paper` is a row dict.
        top_rows = dataset.select(range(min(3, len(dataset))))

        # Get relevant context: first 1000 chars of each selected paper.
        context = "\n".join(paper["text"][:1000] for paper in top_rows)

        # Generate answer
        answer = generate_answer(query, context)

        if answer and not answer.isspace():
            st.success("Answer found!")
            st.write(answer)

            st.write("### Sources Used:")
            for paper in top_rows:
                st.write(f"**Title:** {paper['title']}")
                st.write(f"**Summary:** {paper['text'][:200]}...")
                st.write("---")
        else:
            st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")