Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import os | |
from datasets import load_from_disk | |
import torch | |
import logging | |
# Send INFO-and-above log records to the root logger's default handler.
logging.basicConfig(level=logging.INFO)

# Resolve where the RAG data lives: use "/data" when that path exists
# (presumably a mounted storage volume — confirm), else the working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
# <DATA_DIR>/rag_dataset
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
# <DATA_DIR>/rag_dataset/dataset
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
# Cache models and dataset across Streamlit reruns and sessions.
@st.cache_resource
def load_models():
    """Load the tokenizer and seq2seq model used to answer questions.

    Wrapped in ``st.cache_resource`` so the ``t5-base`` weights are loaded
    once and reused, instead of being re-loaded on every query (Streamlit
    re-executes the whole script on each interaction).

    Returns:
        tuple: ``(tokenizer, model)`` for the ``"t5-base"`` checkpoint.
    """
    model_name = "t5-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


@st.cache_resource
def load_dataset():
    """Load the saved RAG paper dataset from ``DATASET_PATH``.

    The app script calls ``load_dataset()``; this provides it using the
    ``load_from_disk`` import and ``DATASET_PATH`` constant defined above.
    Cached with ``st.cache_resource`` so the on-disk dataset is opened once.

    NOTE(review): rows are later read as ``row["title"]`` / ``row["text"]``,
    so this assumes a single ``Dataset`` (not a ``DatasetDict``) with those
    columns was saved at this path — confirm against the indexing script.

    Returns:
        The object returned by ``datasets.load_from_disk(DATASET_PATH)``;
        any error it raises (e.g. a missing directory) propagates.
    """
    return load_from_disk(DATASET_PATH)
def generate_answer(question, context, max_length=200):
    """Generate an answer to *question* grounded in *context* with T5.

    Args:
        question: The user's question text.
        context: Passage text the model should draw the answer from.
        max_length: Maximum length (in tokens) of the generated answer;
            forwarded to ``model.generate``.

    Returns:
        str: The decoded, whitespace-stripped answer, or a fixed fallback
        message when the model produces empty/whitespace-only output.
    """
    tokenizer, model = load_models()
    # Encode the question and context as a single "question: ... context: ..."
    # prompt, capped at 512 tokens (longer context is truncated).
    inputs = tokenizer(
        f"question: {question} context: {context}",
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )
    # A seq2seq model must be decoded autoregressively. A plain forward pass
    # (model(**inputs)) without decoder inputs raises for T5, and an argmax
    # over its logits would not be a generated sequence anyway.
    with torch.no_grad():
        answer_ids = model.generate(**inputs, max_length=max_length)
    # Convert generated token ids to text.
    answer = tokenizer.decode(answer_ids[0], skip_special_tokens=True).strip()
    return answer if answer else "I cannot find a specific answer to this question in the provided context."
# Streamlit App
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")
if query:
    with st.status("Searching for answers..."):
        # Load dataset
        dataset = load_dataset()
        # Slicing a datasets.Dataset (dataset[:3]) returns a dict of columns,
        # so iterating it would yield column names rather than rows. Index the
        # rows one by one instead and reuse them for context and sources.
        # NOTE(review): this always uses the first three papers, not the ones
        # most relevant to the query — there is no retrieval step yet.
        papers = [dataset[i] for i in range(min(3, len(dataset)))]
        # Up to 1000 chars per paper; generate_answer truncates the combined
        # prompt to 512 tokens, so later papers may be cut off.
        context = "\n".join(paper["text"][:1000] for paper in papers)
        # Generate answer
        answer = generate_answer(query, context)
    # Render results outside the status container: st.status is collapsed by
    # default, which would hide the answer until the user expands it.
    # NOTE: generate_answer returns a non-empty fallback message instead of an
    # empty string, so the warning branch only fires if that ever changes.
    if answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources Used:")
        for paper in papers:
            st.write(f"**Title:** {paper['title']}")
            st.write(f"**Summary:** {paper['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")