import os

import faiss
import streamlit as st
from datasets import load_from_disk
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# Title
st.title("AMA Autism 🧩")

# Input: Query
query = st.text_input("Please ask me anything about autism ✨")


# Load or create the RAG dataset (corpus + FAISS index)
def load_rag_dataset(dataset_dir="rag_dataset"):
    if not os.path.exists(dataset_dir):
        # Import the build function from the companion module
        import faiss_index.index as faiss_index_index

        # Fetch some initial papers to build the index
        initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
        dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)

    # Load the dataset and its FAISS index from disk
    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
    return dataset, index


# RAG pipeline: retrieve passages and generate an answer
def rag_pipeline(query, dataset, index):
    # Load the pre-trained RAG model and point the retriever at the custom index.
    # Note: the retriever re-loads the same passages/index from disk, so the
    # `dataset` and `index` arguments are only used by the caller for display.
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        passages_path=os.path.join("rag_dataset", "dataset"),
        index_path=os.path.join("rag_dataset", "embeddings.faiss"),
    )
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

    # Generate an answer conditioned on the retrieved passages
    inputs = tokenizer(query, return_tensors="pt")
    generated_ids = model.generate(input_ids=inputs["input_ids"])
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return answer


# Run the app
if query:
    st.write("Loading or creating RAG dataset...")
    dataset, index = load_rag_dataset()

    st.write("Running RAG pipeline...")
    answer = rag_pipeline(query, dataset, index)

    st.write("### Answer:")
    st.write(answer)

    # NOTE: this lists the first few papers in the corpus, not the passages
    # the retriever actually selected for this particular query.
    st.write("### Retrieved Papers:")
    for i in range(min(5, len(dataset))):
        st.write(f"**Title:** {dataset[i]['title']}")
        st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
        st.write("---")
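
# ---------------------------------------------------------------------------
# Usage note (a sketch, not part of the original file). Assumptions: this
# script is saved as app.py and the helper module lives at faiss_index/index.py;
# adjust the names to match your actual layout.
#
#   pip install streamlit transformers datasets faiss-cpu torch
#   streamlit run app.py
#
# The first query triggers a one-time index build (fetching and embedding
# ~100 papers), so expect a noticeable delay on first use.
#
# For a custom index, RagRetriever expects `passages_path` to point at a
# datasets.Dataset saved with save_to_disk() that has "title", "text", and
# "embeddings" columns (DPR context-encoder vectors), and `index_path` to
# point at a FAISS index built over those embeddings; build_faiss_index()
# is assumed to produce exactly those two artifacts.
# ---------------------------------------------------------------------------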