import logging
import os

import pandas as pd
import streamlit as st
import torch
from datasets import load_from_disk
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Configure logging
logging.basicConfig(level=logging.INFO)
# Define data paths
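# On Hugging Face Spaces, persistent storage (when enabled) is mounted at
# /data; otherwise we fall back to the working directory for local runs.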
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
# Cache the tokenizer/model so they are loaded once per session
@st.cache_resource
def load_models():
    model_name = "google/flan-t5-small"  # Lighter model to fit the Space's memory limits
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # Half precision to halve the memory footprint
        low_cpu_mem_usage=True,
        device_map='auto',
        max_memory={'cpu': '1GB'},
    )
    return tokenizer, model
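# Note: device_map='auto' and max_memory are handled by the `accelerate`
# package, so it must be installed (e.g. listed in the Space's requirements.txt).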
@st.cache_data(ttl=3600)  # Cache for 1 hour
def load_dataset(query):
    # Create the initial dataset if it doesn't exist yet
    if not os.path.exists(DATASET_PATH):
        with st.spinner("Building initial dataset from autism research papers..."):
            import faiss_index.index as idx
            # Restrict the arXiv search to quantitative-biology categories
            papers = idx.fetch_arxiv_papers(
                f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN"
                " OR cat:q-bio.CB OR cat:q-bio.MN)",
                max_results=25,  # Reduced max results to keep the build fast
            )
            idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
    # Load the saved dataset and convert to pandas for easier handling
    dataset = load_from_disk(DATASET_PATH)
    df = pd.DataFrame({
        'title': dataset['title'],
        'text': dataset['text'],
    })
    return df
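# Note: the corpus is built only once, from the first query that reaches this
# function; later queries reuse whatever was already saved at DATASET_PATH.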
def generate_answer(question, context, max_length=150):  # Reduced max length
    tokenizer, model = load_models()
    # Frame the question with context about medical information
    prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
    # Truncate the input to the model's 512-token limit
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = inputs.to(model.device)  # Keep the inputs on the model's device
    with torch.inference_mode():  # More efficient than no_grad for inference
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=2,  # Reduced beam search
            do_sample=True,  # Needed for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            early_stopping=True,
        )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clear GPU memory if a GPU is in use
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if answer and not answer.isspace():
        return answer
    return "I cannot find a specific answer to this question in the provided context."
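# Note: float16 weights on CPU (see load_models) keep memory low, but some
# torch builds lack half-precision CPU kernels; if generation fails with a
# dtype error, dropping torch_dtype=torch.float16 is the usual fix.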
# Streamlit App
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    with st.status("Searching for answers..."):
        # Load (or build) the dataset
        df = load_dataset(query)
        # Take the first three documents as context, truncated to 1,000
        # characters each to fit the model's input window (no similarity
        # search against the FAISS index happens at this step)
        context = "\n".join(text[:1000] for text in df['text'].head(3))
        # Generate an answer from that context
        answer = generate_answer(query, context)
    if answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources Used:")
        for _, row in df.head(3).iterrows():
            st.write(f"**Title:** {row['title']}")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
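# The import in load_dataset assumes a local module `faiss_index/index.py`.
# A minimal sketch of the interface this app relies on (the signatures below
# are assumptions inferred from the call sites, not the actual implementation):
#
#   def fetch_arxiv_papers(query: str, max_results: int = 25) -> list[dict]:
#       """Query the arXiv API and return dicts with 'title' and 'text' keys."""
#
#   def build_faiss_index(papers: list[dict], dataset_dir: str) -> None:
#       """Embed the papers and save a `datasets.Dataset` under
#       `dataset_dir/dataset` so `load_from_disk(DATASET_PATH)` works."""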