Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import os | |
| from datasets import load_from_disk, Dataset | |
| import torch | |
| import logging | |
| import pandas as pd | |
# Emit log records at INFO level and above
logging.basicConfig(level=logging.INFO)

# Storage layout: use /data when that path exists, otherwise fall back to
# the current working directory.
if os.path.exists("/data"):
    DATA_DIR = "/data"
else:
    DATA_DIR = "."
# Directory handed to the index builder, and the dataset saved beneath it
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
| # Cache models and dataset | |
@st.cache_resource
def load_models():
    """Load the tokenizer and seq2seq model used to generate answers.

    The function is wrapped in ``st.cache_resource`` because Streamlit
    re-executes this script on every interaction and ``generate_answer``
    calls ``load_models()`` for every question; without the cache the
    weights would be reloaded each time.

    Returns:
        tuple: ``(tokenizer, model)`` for ``google/flan-t5-small``.
    """
    model_name = "google/flan-t5-small"  # Lighter model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): half precision combined with a CPU-only memory budget
    # relies on the installed PyTorch supporting float16 ops on CPU - confirm.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map='auto',
        max_memory={'cpu': '1GB'},
    )
    return tokenizer, model
# NOTE: no cache is applied to load_dataset; every query triggers a fresh fetch
def load_dataset(query):
    """Fetch papers for ``query``, rebuild the index, and return them as a table.

    The search term is prefixed with "autism" unless the query already
    mentions it (case-insensitive). Up to 25 papers are requested from
    ``faiss_index.index.fetch_papers``; the index is then rebuilt under
    ``DATASET_DIR`` and the dataset saved at ``DATASET_PATH`` is read back.

    Args:
        query: The user's question as typed into the app.

    Returns:
        pandas.DataFrame with columns ``title``, ``text``, ``url`` and
        ``published``. The frame is empty (columns only) when no papers
        were returned, in which case a Streamlit warning is shown as well.
    """
    result_columns = ['title', 'text', 'url', 'published']

    # No result caching here: the search runs again for each query.
    with st.spinner("Searching research papers from arXiv and PubMed..."):
        import faiss_index.index as idx

        # Keep the search anchored on autism while preserving the user's terms
        mentions_autism = 'autism' in query.lower()
        search_query = query if mentions_autism else f"autism {query}"

        papers = idx.fetch_papers(search_query, max_results=25)
        if not papers:
            st.warning("No relevant papers found. Please try rephrasing your question.")
            return pd.DataFrame(columns=result_columns)

        idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)

        # Read the freshly written dataset back and flatten it for pandas
        stored = load_from_disk(DATASET_PATH)
        titles = stored['title']
        texts = stored['text']
        urls = [paper['url'] for paper in papers]
        published_dates = [paper['published'] for paper in papers]
        return pd.DataFrame({
            'title': titles,
            'text': texts,
            'url': urls,
            'published': published_dates,
        })
def generate_answer(question, context, max_length=300):
    """Generate a structured answer to ``question`` from ``context``.

    Args:
        question: The user's question.
        context: Text excerpts from the retrieved papers. When it is empty
            or whitespace-only there is nothing to ground an answer on, so
            the fallback message is returned without loading the model.
        max_length: Maximum length passed to ``model.generate``.

    Returns:
        str: The decoded answer with a line break after each ". " and
        before each bullet, or the fixed "I cannot find ..." message when
        the context is empty, the model output is blank, or the output
        contains "cannot find".
    """
    fallback = "I cannot find specific information about this topic in the autism research papers."

    # Without any retrieved text the model could only produce an ungrounded
    # answer, so skip generation (and the model load) entirely.
    if not context or not context.strip():
        return fallback

    tokenizer, model = load_models()

    # Enhanced prompt for more detailed and structured answers
    prompt = f"""Based on scientific research about autism, provide a comprehensive and structured summary answering the following question.
Include the following aspects when relevant:
1. Main findings and conclusions
2. Supporting evidence or research methods
3. Clinical implications or practical applications
4. Any limitations or areas needing further research
Use clear headings and bullet points when appropriate to organize the information.
If the context doesn't contain relevant information about autism, respond with 'I cannot find specific information about this topic in the autism research papers.'
Question: {question}
Context: {context}
Detailed summary:"""

    # Long prompts are truncated to 768 tokens
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768)

    with torch.inference_mode():
        # Beam search only: temperature/top_p were dropped because they take
        # effect only with do_sample=True and otherwise just raise warnings.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            repetition_penalty=1.3,
            length_penalty=1.2,
            early_stopping=True,
        )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Clear GPU memory if possible
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Blank output or an explicit "cannot find" reply maps to the fixed message
    if not answer.strip() or "cannot find" in answer.lower():
        return fallback

    # Format the answer with proper line breaks and structure
    return answer.replace(". ", ".\n").replace("• ", "\n• ")
| # Streamlit App | |
st.title("🧩 AMA Autism")
st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
query = st.text_input("Please ask me anything about autism ✨")

if query:
    answer = ""
    with st.status("Searching for answers...") as status:
        # Each progress line is written before the step it describes so the
        # status panel reflects what is currently running.
        st.write("Searching for data in PubMed and arXiv...")
        df = load_dataset(query)

        # Only claim success and run the model when papers were retrieved;
        # an empty frame leaves nothing to build a context from.
        if not df.empty:
            st.write("Data found!")
            # Context: the first 1000 characters of each of the top 3 papers
            context = "\n".join(text[:1000] for text in df['text'].head(3))

            st.write("Generating answer...")
            answer = generate_answer(query, context)

        status.update(
            label="Search complete!", state="complete", expanded=False
        )

    # generate_answer returns a fixed "I cannot find ..." message instead of
    # an empty string, so that message must not be reported as a success.
    has_answer = bool(answer.strip()) and "cannot find" not in answer.lower()

    if df.empty:
        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
    elif has_answer:
        st.success("Answer found!")
        st.write(answer)
        st.write("### Sources used:")
        for _, row in df.head(3).iterrows():
            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
            st.write(f"**Summary:** {row['text'][:200]}...")
            st.write("---")
    else:
        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")