# ama-autism / app.py
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk
import torch
import logging
import pandas as pd
# Configure logging
logging.basicConfig(level=logging.INFO)

# Define data paths
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

# Cache models and dataset
@st.cache_resource
def load_models():
    model_name = "google/flan-t5-small"  # Lighter model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map='auto',
        max_memory={'cpu': '1GB'}
    )
    return tokenizer, model
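
# Note: st.cache_resource means load_models() loads the tokenizer/model pair
# once per process; later calls (e.g. from generate_answer below) reuse the
# cached objects instead of reloading the weights.
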
@st.cache_data(ttl=3600)  # Cache for 1 hour
def load_dataset(query):
    # Create initial dataset if it doesn't exist
    if not os.path.exists(DATASET_PATH):
        with st.spinner("Building initial dataset from autism research papers..."):
            import faiss_index.index as idx
            papers = idx.fetch_arxiv_papers(
                f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)",
                max_results=25,  # Reduced max results
            )
            idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
    # Load and convert to pandas for easier handling
    dataset = load_from_disk(DATASET_PATH)
    df = pd.DataFrame({
        'title': dataset['title'],
        'text': dataset['text']
    })
    return df
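
# For reference: faiss_index/index.py lives elsewhere in this repo and isn't
# shown here. One plausible shape for its two helpers, inferred only from the
# call sites above (an assumption, not the actual module), could be:
#
#   import os
#   import arxiv  # pip install arxiv
#   from datasets import Dataset
#   from sentence_transformers import SentenceTransformer
#
#   def fetch_arxiv_papers(query, max_results=25):
#       search = arxiv.Search(query=query, max_results=max_results)
#       return [{"title": r.title, "text": r.summary}
#               for r in arxiv.Client().results(search)]
#
#   def build_faiss_index(papers, dataset_dir):
#       os.makedirs(dataset_dir, exist_ok=True)
#       encoder = SentenceTransformer("all-MiniLM-L6-v2")
#       ds = Dataset.from_dict({
#           "title": [p["title"] for p in papers],
#           "text": [p["text"] for p in papers],
#       })
#       ds = ds.map(lambda row: {"embeddings": encoder.encode(row["text"])})
#       ds.add_faiss_index(column="embeddings")
#       ds.save_faiss_index("embeddings", f"{dataset_dir}/index.faiss")
#       ds.drop_index("embeddings")  # indexes can't be serialized by save_to_disk
#       ds.save_to_disk(f"{dataset_dir}/dataset")
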
def generate_answer(question, context, max_length=150):  # Reduced max length
    tokenizer, model = load_models()
    # Add context about medical information
    prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
    # Optimize input processing
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.inference_mode():  # More efficient than no_grad
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=2,  # Reduced beam search
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            early_stopping=True
        )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clear GPU memory if possible
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Return "" when the model produced no usable text, so the caller's
    # "couldn't find an answer" branch can actually trigger (returning a
    # fallback sentence here would always pass the caller's check)
    return answer if answer and not answer.isspace() else ""
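
# Illustrative usage from a Python shell (hypothetical question and context;
# note that importing app.py also executes the Streamlit calls at the bottom
# of the file):
#
#   >>> from app import generate_answer
#   >>> generate_answer(
#   ...     "What do studies say about sensory sensitivity?",
#   ...     "Excerpt from a retrieved paper abstract goes here...",
#   ... )
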
# Streamlit App
st.title("🧩 AMA Autism")
query = st.text_input("Please ask me anything about autism ✨")
if query:
    with st.status("Searching for answers..."):
        # Load dataset
        df = load_dataset(query)
        # Get relevant context
        context = "\n".join(text[:1000] for text in df['text'].head(3))
        # Generate answer
        answer = generate_answer(query, context)
    if answer and not answer.isspace():
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources Used:")
        for _, row in df.head(3).iterrows():
            st.write(f"**Title:** {row['title']}")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
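
# To run locally: pip install streamlit transformers datasets torch pandas
# (plus the faiss_index module's dependencies), then `streamlit run app.py`.
# On Hugging Face Spaces, the DATA_DIR check at the top picks up /data
# automatically when the Space has persistent storage enabled.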