Spaces:
Sleeping
Sleeping
File size: 2,436 Bytes
f1586e3 d3e32db f91cc3b f68ac31 0f8445a 5a09d5c 8108db5 f1586e3 0452175 f68ac31 d3e32db f68ac31 d3e32db f68ac31 d3e32db f68ac31 d3e32db f1586e3 f68ac31 f1586e3 f68ac31 d3e32db f68ac31 d3e32db f68ac31 d3e32db f68ac31 d3e32db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk
import torch
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
# Define data paths.
# "/data" is used when it exists (presumably the persistent-storage mount on
# Hugging Face Spaces — TODO confirm); otherwise fall back to the working dir.
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
# Cache models and dataset
@st.cache_resource
def load_models():
    """Load the T5 tokenizer/model pair, cached across Streamlit reruns.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``t5-base`` checkpoint.
    """
    checkpoint = "t5-base"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
def generate_answer(question, context, max_length=200):
    """Generate an answer to ``question`` grounded in ``context`` using T5.

    Args:
        question: Natural-language question to answer.
        context: Supporting text the answer should be drawn from.
        max_length: Maximum number of tokens to generate.

    Returns:
        The decoded answer string, or a fixed fallback message when the
        model produces empty or whitespace-only output.
    """
    tokenizer, model = load_models()
    # T5 expects the "question: ... context: ..." prompt format for QA.
    inputs = tokenizer(
        f"question: {question} context: {context}",
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )
    # BUG FIX: the original ran model(**inputs) and arg-maxed the logits.
    # A seq2seq forward pass requires decoder inputs (it raises without
    # them) and a single forward is not autoregressive generation anyway.
    # It also ignored the max_length parameter. Use generate() instead.
    with torch.no_grad():
        answer_ids = model.generate(**inputs, max_length=max_length)
    # Decode the generated token ids back to text.
    answer = tokenizer.decode(answer_ids[0], skip_special_tokens=True)
    return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."
# Streamlit App
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")
if query:
    with st.status("Searching for answers..."):
        # BUG FIX: the original called load_dataset(), which is not defined
        # anywhere in this file (NameError at runtime). Load the on-disk
        # Hugging Face dataset directly from the configured path instead.
        dataset = load_from_disk(DATASET_PATH)
        # BUG FIX: dataset[:3] on an HF Dataset returns a column-oriented
        # dict of lists, so iterating it yields column names, not rows.
        # Index row-by-row (dataset[i] is a per-example dict), matching the
        # dataset[i]['title'] access used for the sources below.
        top_papers = [dataset[i] for i in range(min(3, len(dataset)))]
        # Get relevant context: first 1000 chars of each paper for better answers.
        context = "\n".join(paper["text"][:1000] for paper in top_papers)
        # Generate answer
        answer = generate_answer(query, context)
    if answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources Used:")
        for paper in top_papers:
            st.write(f"**Title:** {paper['title']}")
            st.write(f"**Summary:** {paper['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")