Spaces:
Sleeping
Sleeping
File size: 4,885 Bytes
f1586e3 d3e32db f91cc3b 8903db2 0f8445a 5a09d5c 8903db2 5a09d5c 8108db5 f1586e3 0452175 f68ac31 d3e32db f68ac31 f99a008 5f3cb01 bf36826 42d1dd5 5f3cb01 d3e32db f68ac31 42d1dd5 f99a008 cc0b0d6 4a9703a cc0b0d6 cc41495 4a9703a 62b3157 54a5022 62b3157 8903db2 f99a008 8903db2 62b3157 8903db2 f99a008 f944585 cc0b0d6 d3e32db 62b3157 cc0b0d6 62b3157 f99a008 42d1dd5 d3e32db cc0b0d6 8903db2 42d1dd5 8903db2 cc0b0d6 42d1dd5 8903db2 42d1dd5 8903db2 cc0b0d6 f1586e3 f68ac31 cc41495 f68ac31 f1586e3 b45a04f d3e32db f99a008 b45a04f 62b3157 b45a04f 62b3157 b45a04f 62b3157 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import os
from datasets import load_from_disk, Dataset
import torch
import logging
import pandas as pd
# Emit INFO-level (and above) log records from the root logger.
logging.basicConfig(level=logging.INFO)

# Storage layout: use the "/data" directory when it exists on this host,
# otherwise fall back to the current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."

# Directory handed to the index builder, and the dataset location inside it.
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
# Cache models and dataset
@st.cache_resource
def load_models():
    """Load the FLAN-T5 tokenizer and seq2seq model used for answering.

    The function is wrapped in ``st.cache_resource``, so the returned objects
    are shared by the caching layer instead of being reloaded on each call.

    Returns:
        A ``(tokenizer, model)`` tuple for ``google/flan-t5-small``.
    """
    checkpoint = "google/flan-t5-small"  # Lighter model

    # Loading options aimed at a small memory footprint.
    # NOTE(review): float16 weights on a CPU-only host may be slow or
    # unsupported by some PyTorch builds -- confirm on the deployment target.
    load_options = {
        "torch_dtype": torch.float16,
        "low_cpu_mem_usage": True,
        "device_map": 'auto',
        "max_memory": {'cpu': '1GB'},
    }

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, **load_options)
    return tokenizer, model
@st.cache_data(ttl=3600)  # Cache for 1 hour
def load_dataset(query):
    """Fetch papers for *query*, index them, and return them as a DataFrame.

    The search string is prefixed with "autism" unless the query already
    mentions it (case-insensitive). Results are cached per query for one hour
    by ``st.cache_data``, so a repeated query within that window does not
    trigger a new search.

    Args:
        query: The user's question as typed.

    Returns:
        A ``pandas.DataFrame`` with columns ``title``, ``text``, ``url`` and
        ``published``. It is empty (columns only) when no papers were found.
    """
    with st.spinner("Searching research papers from arXiv and PubMed..."):
        import faiss_index.index as idx

        # Keep the search on-topic: add "autism" when the user did not.
        already_on_topic = 'autism' in query.lower()
        search_query = query if already_on_topic else f"autism {query}"

        papers = idx.fetch_papers(search_query, max_results=25)

        # Guard clause: nothing to index, hand back an empty frame.
        if not papers:
            st.warning("No relevant papers found. Please try rephrasing your question.")
            return pd.DataFrame(columns=['title', 'text', 'url', 'published'])

        idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)

        # Read the stored dataset back and flatten it into pandas.
        stored = load_from_disk(DATASET_PATH)

        # NOTE(review): assumes the stored rows line up one-to-one, in order,
        # with `papers`; otherwise the column lengths differ -- confirm
        # against build_faiss_index.
        frame_data = {'title': stored['title'], 'text': stored['text']}
        frame_data['url'] = [paper['url'] for paper in papers]
        frame_data['published'] = [paper['published'] for paper in papers]
        return pd.DataFrame(frame_data)
def generate_answer(question, context, max_length=150):
    """Generate a short answer to *question* from the supplied *context*.

    Args:
        question: The user's question.
        context: Text excerpts from retrieved papers. When empty, ``None`` or
            whitespace-only, the fallback message is returned without running
            the model, since there is nothing to ground an answer on.
        max_length: Maximum length, in tokens, of the generated output.

    Returns:
        The decoded model output, or the fixed "I cannot find ..." fallback
        sentence when there is no context, the output is blank, or the model
        itself reports that it cannot find the information.
    """
    fallback = "I cannot find specific information about this topic in the autism research papers."

    # No retrieved text means the model could only guess; skip inference.
    if not context or not context.strip():
        return fallback

    tokenizer, model = load_models()

    # Improve prompt to generate concise, summarized answers
    prompt = f"""Based on scientific research about autism, provide a brief, clear summary answering the following question.
Focus only on the most important findings and be concise.
If the context doesn't contain relevant information about autism, respond with 'I cannot find specific information about this topic in the autism research papers.'
Question: {question}
Context: {context}
Provide a concise summary:"""

    # NOTE(review): truncation cuts from the end of the prompt, so a long
    # context can push the closing instruction past the 512-token limit.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    # The model may have been placed on an accelerator at load time; the
    # inputs must live on the same device as the model's weights.
    inputs = inputs.to(model.device)

    with torch.inference_mode():
        # Deterministic beam search. Sampling knobs (temperature / top_p) are
        # intentionally not passed: they only take effect with do_sample=True
        # and otherwise just produce configuration warnings.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=2,
            repetition_penalty=1.2,
            early_stopping=True,
        )

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Clear GPU memory if possible
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Normalise blank output and model-side refusals to the single fallback.
    if not answer.strip() or "cannot find" in answer.lower():
        return fallback
    return answer
# Streamlit App: page header, question box, then the search/answer flow.
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    answer = None

    with st.status("Searching for answers...") as status:
        # Announce each step *before* running it so the status panel shows
        # what is currently in progress rather than what already finished.
        st.write("Searching for data in PubMed and arXiv...")
        df = load_dataset(query)

        # Only build a context and run the model when papers were found;
        # with nothing retrieved there is nothing to ground an answer on.
        if not df.empty:
            st.write("Data found!")
            # Use the first 1000 characters of each of the top three papers.
            context = "\n".join(text[:1000] for text in df['text'].head(3))

            st.write("Generating answer...")
            answer = generate_answer(query, context)

        status.update(
            label="Search complete!", state="complete", expanded=False
        )

    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    elif not answer or answer.isspace() or "cannot find" in answer.lower():
        # generate_answer reports "no answer" through its fixed fallback
        # sentence, so detect that here instead of presenting it as a success.
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
    else:
        st.success("Answer found!")
        st.write(answer)

        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")