	fix: rag
- app.py +45 -109
- faiss_index/index.py +54 -86
    	
app.py
CHANGED

@@ -2,124 +2,60 @@ import streamlit as st
 from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 import faiss
 import os
-from datasets import load_from_disk
+from datasets import load_from_disk
 import torch
 import logging
-import traceback

 # Configure logging
 logging.basicConfig(level=logging.INFO)

-    try:
-        if not os.path.exists(dataset_dir):
-            with st.spinner("Building initial dataset from autism research papers..."):
-                import faiss_index.index as faiss_index_index
-                initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
-                dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
-
-        # Load the dataset and index
-        dataset_path = os.path.join(dataset_dir, "dataset")
-        index_path = os.path.join(dataset_dir, "embeddings.faiss")
-
-        if not os.path.exists(dataset_path) or not os.path.exists(index_path):
-            raise ValueError("Dataset or index not found")
-
-        dataset = load_from_disk(dataset_path)
-        index = faiss.read_index(index_path)
-
-        logging.info("Successfully loaded dataset and index")
-        return dataset, dataset_path, index_path
-    except Exception as e:
-        st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
-        return None, None, None
+# Cache models and dataset
+@st.cache_resource  # Cache models in memory
+def load_models():
+    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+    retriever = RagRetriever.from_pretrained(
+        "facebook/rag-sequence-nq",
+        index_name="custom",
+        passages_path="/data/rag_dataset/dataset",
+        index_path="/data/rag_dataset/embeddings.faiss"
+    )
+    model = RagSequenceForGeneration.from_pretrained(
+        "facebook/rag-sequence-nq",
+        retriever=retriever,
+        device_map="auto"
+    )
+    return tokenizer, retriever, model
+
+@st.cache_data  # Cache dataset on disk
+def load_dataset():
+    return load_from_disk("/data/rag_dataset/dataset")

 # RAG Pipeline
-def rag_pipeline(query, dataset,
-            index_path=index_path
-        )
-
-        # Initialize model with retriever
-        model = RagSequenceForGeneration.from_pretrained(
-            model_name,
-            retriever=retriever,
-            use_auth_token=False
+def rag_pipeline(query, dataset, index):
+    tokenizer, retriever, model = load_models()
+    inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
+    with torch.no_grad():
+        outputs = model.generate(
+            inputs["input_ids"],
+            max_length=200,
+            min_length=50,
+            num_beams=5,
+            early_stopping=True,
+            no_repeat_ngram_size=3
         )
-            outputs = model.generate(
-                inputs["input_ids"],
-                max_length=200,
-                min_length=50,
-                num_beams=5,
-                early_stopping=True,
-                no_repeat_ngram_size=3
-            )
-            answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-
-        return answer
-    except Exception as e:
-        st.error(f"Error generating answer: {str(e)}\n{traceback.format_exc()}")
-        return None
+        answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+    return answer

-# Run the app
+# Streamlit App
+st.title("🧩 AMA Autism")
+query = st.text_input("Please ask me anything about autism ✨")

 if query:
-    with st.status(
-            else:
-                st.write("Found the best sources!")
-                st.write("Now answering your question...")
-                answer = rag_pipeline(query, dataset, dataset_path, index_path)
-
-                if answer:
-                    status.update(label="Search complete!", state="complete", expanded=False)
-                    st.write("### Answer:")
-                    st.write_stream(answer)
-                    st.write("### Retrieved Papers:")
-                    for i in range(min(5, len(dataset))):
-                        st.write(f"**Title:** {dataset[i]['title']}")
-                        st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
-                        st.write("---")
-                else:
-                    status.update(label="Error generating answer", state="error")
-
-        except Exception as e:
-            st.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
-            status.update(label="Error", state="error")
+    with st.status("Searching for answers..."):
+        dataset = load_dataset()
+        answer = rag_pipeline(query, dataset, index=None)
+        if answer:
+            st.success("Answer found!")
+            st.write(answer)
+        else:
+            st.error("Failed to generate an answer.")
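Note that load_models() assumes the dataset and FAISS index already exist under /data/rag_dataset; the removed startup code rebuilt them on first run, but this version does not. A minimal guard in that spirit might look like the sketch below (the paths and the faiss_index.index helpers come from this diff; running such a check before load_models() is an assumption, not part of the commit):

import os
import faiss_index.index as faiss_index_index

DATASET_DIR = "/data/rag_dataset"  # same persistent path the retriever reads from

# Rebuild the dataset and index once if they are missing, before the cached retriever loads them
if not os.path.exists(os.path.join(DATASET_DIR, "embeddings.faiss")):
    papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
    faiss_index_index.build_faiss_index(papers, DATASET_DIR)

With the artifacts in place, the app starts as usual with: streamlit run app.py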
    	
faiss_index/index.py
CHANGED

@@ -12,95 +12,63 @@ logging.basicConfig(level=logging.INFO)

 def fetch_arxiv_papers(query, max_results=10):
     """Fetch papers from arXiv and format them for RAG"""
-        for i, result in enumerate(results):
-            papers.append({
-                "id": str(i),
-                "text": result.summary,
-                "title": result.title,
-            })
-
-        logging.info(f"Fetched {len(papers)} papers from arXiv")
-        return papers
-
-    except Exception as e:
-        logging.error(f"Error fetching papers: {str(e)}")
-        raise
+    client = arxiv.Client()
+    search = arxiv.Search(
+        query=query,
+        max_results=max_results,
+        sort_by=arxiv.SortCriterion.SubmittedDate
+    )
+    results = list(client.results(search))
+    papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]
+    logging.info(f"Fetched {len(papers)} papers from arXiv")
+    return papers

-def build_faiss_index(papers, dataset_dir="rag_dataset"):
+def build_faiss_index(papers, dataset_dir="/data/rag_dataset"):
     """Build and save dataset with FAISS index for RAG"""
-        index.add(embeddings.astype(np.float32))
-
-        # Save everything
-        os.makedirs(dataset_dir, exist_ok=True)
-        dataset_path = os.path.join(dataset_dir, "dataset")
-        index_path = os.path.join(dataset_dir, "embeddings.faiss")
-
-        # Save dataset and index
-        dataset.save_to_disk(dataset_path)
-        faiss.write_index(index, index_path)
-
-        logging.info(f"Saved dataset to {dataset_path}")
-        logging.info(f"Saved index to {index_path}")
-
-        return dataset_dir
-
-    except Exception as e:
-        logging.error(f"Error building index: {str(e)}")
-        raise
+    # Initialize DPR encoder
+    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+
+    # Create embeddings
+    texts = [p["text"] for p in papers]
+    embeddings = []
+    batch_size = 8
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i:i + batch_size]
+        inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
+        with torch.no_grad():
+            outputs = ctx_encoder(**inputs)
+            batch_embeddings = outputs.pooler_output.cpu().numpy()
+            embeddings.append(batch_embeddings)
+
+    embeddings = np.vstack(embeddings)
+    logging.info(f"Created embeddings with shape {embeddings.shape}")
+
+    # Create dataset
+    dataset = Dataset.from_dict({
+        "id": [p["id"] for p in papers],
+        "text": [p["text"] for p in papers],
+        "title": [p["title"] for p in papers],
+        "embeddings": [emb.tolist() for emb in embeddings],
+    })
+
+    # Create FAISS index
+    dimension = embeddings.shape[1]
+    index = faiss.IndexFlatL2(dimension)
+    index.add(embeddings.astype(np.float32))
+
+    # Save dataset and index
+    os.makedirs(dataset_dir, exist_ok=True)
+    dataset_path = os.path.join(dataset_dir, "dataset")
+    index_path = os.path.join(dataset_dir, "embeddings.faiss")
+    dataset.save_to_disk(dataset_path)
+    faiss.write_index(index, index_path)
+    logging.info(f"Saved dataset to {dataset_path}")
+    logging.info(f"Saved index to {index_path}")
+    return dataset_dir

 # Example usage
 if __name__ == "__main__":
-    query = "
-    papers = fetch_arxiv_papers(query)
+    query = "autism research"
+    papers = fetch_arxiv_papers(query, max_results=100)
     build_faiss_index(papers)
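build_faiss_index() stores DPR context embeddings in an IndexFlatL2 index next to the saved dataset. A quick smoke test that the saved artifacts load and return neighbours might look like this sketch (loading with load_from_disk and faiss.read_index mirrors the removed app.py code; the DPR question encoder and the example query are assumptions, and the real retrieval path in the app is RagRetriever):

import faiss
import torch
from datasets import load_from_disk
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Load the artifacts written by build_faiss_index()
dataset = load_from_disk("/data/rag_dataset/dataset")
index = faiss.read_index("/data/rag_dataset/embeddings.faiss")

# Embed a question with the DPR question encoder that pairs with the context encoder above
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
inputs = q_tokenizer("early signs of autism", return_tensors="pt")
with torch.no_grad():
    query_vec = q_encoder(**inputs).pooler_output.numpy()

# Nearest neighbours by L2 distance, matching the IndexFlatL2 built above
distances, ids = index.search(query_vec.astype("float32"), 3)
for i in ids[0]:
    print(dataset[int(i)]["title"])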
