ama-autism / app.py
wakeupmh's picture
fix: try to lightweight it
d3e32db
raw
history blame
2.44 kB
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk
import torch
import logging
# Set up root logging so INFO-level records (and above) are emitted.
logging.basicConfig(level=logging.INFO)

# Storage locations: use "/data" when that directory exists, otherwise fall
# back to the current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
# Directory holding the RAG artifacts, and the saved dataset inside it.
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATA_DIR, "rag_dataset", "dataset")
# Cache models and dataset
@st.cache_resource
def load_models():
    """Load the pretrained "t5-base" tokenizer and seq2seq model.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the loaded objects instead of re-creating them on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "t5-base"
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
def generate_answer(question, context, max_length=200):
    """Generate an answer to *question* from *context* with the cached model.

    Args:
        question: The user's question text.
        context: Text the model should draw on when answering.
        max_length: Upper bound on the length of the generated token
            sequence, forwarded to ``model.generate``.

    Returns:
        The decoded answer string, or a fixed fallback sentence when decoding
        yields an empty or whitespace-only string.
    """
    tokenizer, model = load_models()

    # Encode question and context as one prompt, truncated to 512 tokens.
    inputs = tokenizer(
        f"question: {question} context: {context}",
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # A seq2seq model needs decoder inputs: a bare forward pass supplies none
    # and an argmax over its logits is not autoregressive decoding. Use
    # generate(), which also honours the caller's max_length.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=max_length)

    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    if answer and not answer.isspace():
        return answer
    return "I cannot find a specific answer to this question in the provided context."
# Streamlit App
@st.cache_resource
def load_dataset():
    """Load the dataset saved at ``DATASET_PATH`` via ``load_from_disk``.

    Wrapped in ``st.cache_resource`` so the dataset is not re-read from disk
    on every script rerun.
    """
    return load_from_disk(DATASET_PATH)


# Number of papers used as context and listed as sources.
_MAX_SOURCES = 3

st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers..."):
        # Load dataset
        dataset = load_dataset()

        # Fetch rows one by one: slicing a `datasets.Dataset` (dataset[:3])
        # returns a dict of columns rather than a list of row dicts.
        papers = [dataset[i] for i in range(min(_MAX_SOURCES, len(dataset)))]

        # Get relevant context (first 1000 characters of each paper)
        context = "\n".join(paper["text"][:1000] for paper in papers)

        # Generate answer
        answer = generate_answer(query, context)

        if answer and not answer.isspace():
            st.success("Answer found!")
            st.write(answer)
            st.write("### Sources Used:")
            for paper in papers:
                st.write(f"**Title:** {paper['title']}")
                st.write(f"**Summary:** {paper['text'][:200]}...")
                st.write("---")
        else:
            # Safeguard only: generate_answer substitutes a fallback sentence
            # for empty output, so an empty answer is not expected here.
            st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")