wakeupmh committed
Commit f91cc3b · 1 Parent(s): ce32c8e

feat: add dataset

Files changed (3):
  1. app.py +29 -39
  2. faiss_index/index.py +26 -10
  3. requirements.txt +2 -1
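
Taken together, the three changes replace the bare faiss_index.index file with a small on-disk RAG dataset. Under the default dataset_dir="rag_dataset" used in the diffs below, the build step writes roughly this layout (annotations are added here, not part of the commit):

```
rag_dataset/
├── dataset/            # Hugging Face Dataset saved with save_to_disk(), including an "embeddings" column
└── embeddings.faiss    # FAISS IndexFlatL2 built over the paper-summary embeddings
```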
app.py CHANGED
@@ -1,46 +1,41 @@
 import streamlit as st
 from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
-from sentence_transformers import SentenceTransformer
 import faiss
-import numpy as np
-import arxiv
+import os
+from datasets import load_from_disk
 
 # Title
-st.title("arXiv RAG with Streamlit")
+st.title("AMA Austim 🧩")
 
 # Input: Query
-query = st.text_input("Enter your query:")
+query = st.text_input("Please ask me anything about autism ✨")
 
-# Fetch arXiv papers
-def fetch_arxiv_papers(query, max_results=5):
-    client = arxiv.Client()
-    search = arxiv.Search(
-        query=query,
-        max_results=max_results,
-        sort_by=arxiv.SortCriterion.SubmittedDate
-    )
-    results = list(client.results(search))
-    papers = [{"title": result.title, "summary": result.summary, "pdf_url": result.pdf_url} for result in results]
-    return papers
-
-# Load FAISS index
-def load_faiss_index(index_file="faiss_index.index"):
-    import os
-    if not os.path.exists(index_file):
+# Load or create RAG dataset
+def load_rag_dataset(dataset_dir="rag_dataset"):
+    if not os.path.exists(dataset_dir):
         # Import the build function from the other file
         import faiss_index.index as faiss_index_index
 
         # Fetch some initial papers to build the index
         initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
-        faiss_index_index.build_faiss_index(initial_papers, index_file)
+        dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
 
-    return faiss.read_index(index_file)
+    # Load the dataset and index
+    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
+    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
+
+    return dataset, index
 
 # RAG Pipeline
-def rag_pipeline(query, papers, index):
-    # Load pre-trained RAG model
+def rag_pipeline(query, dataset, index):
+    # Load pre-trained RAG model and configure retriever
     tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
-    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom", passages=papers, index=index)
+    retriever = RagRetriever.from_pretrained(
+        "facebook/rag-sequence-nq",
+        index_name="custom",
+        passages_path=os.path.join("rag_dataset", "dataset"),
+        index_path=os.path.join("rag_dataset", "embeddings.faiss")
+    )
     model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
 
     # Generate answer using RAG
@@ -52,22 +47,17 @@ def rag_pipeline(query, papers, index):
 
 # Run the app
 if query:
-    st.write("Fetching arXiv papers...")
-    papers = fetch_arxiv_papers(query)
-    st.write(f"Found {len(papers)} papers.")
-
-    st.write("Loading FAISS index...")
-    index = load_faiss_index()
-
+    st.write("Loading or creating RAG dataset...")
+    dataset, index = load_rag_dataset()
+
     st.write("Running RAG pipeline...")
-    answer = rag_pipeline(query, papers, index)
+    answer = rag_pipeline(query, dataset, index)
 
     st.write("### Answer:")
     st.write(answer)
 
-    st.write("### Relevant Papers:")
-    for paper in papers:
-        st.write(f"**Title:** {paper['title']}")
-        st.write(f"**Summary:** {paper['summary']}")
-        st.write(f"**PDF URL:** {paper['pdf_url']}")
+    st.write("### Retrieved Papers:")
+    for i in range(min(5, len(dataset))):
+        st.write(f"**Title:** {dataset[i]['title']}")
+        st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
     st.write("---")
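
The generation step inside rag_pipeline() is unchanged, so the diff collapses it between the two hunks. For context, a minimal sketch of what that "# Generate answer using RAG" section typically looks like with the RagTokenizer and RagSequenceForGeneration objects created above (only query, tokenizer and model come from the diff; the other names are illustrative):

```python
# Hypothetical sketch of the collapsed generation step, assuming it follows
# the standard transformers RAG usage pattern.
inputs = tokenizer(query, return_tensors="pt")                  # encode the question
generated_ids = model.generate(input_ids=inputs["input_ids"])   # retrieve passages and generate
answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
```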
faiss_index/index.py CHANGED
@@ -2,6 +2,8 @@ import numpy as np
 import faiss
 from sentence_transformers import SentenceTransformer
 import arxiv
+from datasets import Dataset
+import os
 
 # Fetch arXiv papers
 def fetch_arxiv_papers(query, max_results=10):
@@ -12,22 +14,36 @@ def fetch_arxiv_papers(query, max_results=10):
         sort_by=arxiv.SortCriterion.SubmittedDate
     )
     results = list(client.results(search))
-    papers = [{"title": result.title, "summary": result.summary, "pdf_url": result.pdf_url} for result in results]
+    papers = [{"title": result.title, "text": result.summary, "id": str(i)} for i, result in enumerate(results)]
     return papers
 
-# Build and save FAISS index
-def build_faiss_index(papers, index_file="faiss_index.index"):
-    embedder = SentenceTransformer('all-MiniLM-L6-v2')
-    paper_embeddings = embedder.encode([paper["summary"] for paper in papers])
+# Build and save dataset with FAISS index
+def build_faiss_index(papers, dataset_dir="rag_dataset"):
+    # Create dataset
+    dataset = Dataset.from_dict({
+        "id": [p["id"] for p in papers],
+        "title": [p["title"] for p in papers],
+        "text": [p["text"] for p in papers],
+    })
 
+    # Create embeddings
+    embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+    embeddings = embedder.encode(dataset["text"], show_progress_bar=True)
+
+    # Add embeddings to dataset
+    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
+
     # Create FAISS index
-    dimension = paper_embeddings.shape[1]
+    dimension = embeddings.shape[1]
     index = faiss.IndexFlatL2(dimension)
-    index.add(paper_embeddings)
+    index.add(embeddings.astype(np.float32))
 
-    # Save index to disk
-    faiss.write_index(index, index_file)
-    print(f"FAISS index saved to {index_file}")
+    # Save dataset and index
+    os.makedirs(dataset_dir, exist_ok=True)
+    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
+    faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))
 
+    return dataset_dir
+
 # Example usage
 if __name__ == "__main__":
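
A quick way to sanity-check the artifacts that build_faiss_index() writes is to load them back and run a nearest-neighbour search with the same embedding model. This sketch is not part of the commit; the query string and k=3 are arbitrary:

```python
import os

import faiss
import numpy as np
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer

# Load the dataset and index saved by build_faiss_index()
dataset = load_from_disk(os.path.join("rag_dataset", "dataset"))
index = faiss.read_index(os.path.join("rag_dataset", "embeddings.faiss"))

# Embed a test query with the same model used at build time
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
query_vec = embedder.encode(["early signs of autism"]).astype(np.float32)

# Nearest neighbours by L2 distance (IndexFlatL2)
distances, ids = index.search(query_vec, 3)
for rank, i in enumerate(ids[0], start=1):
    print(rank, dataset[int(i)]["title"])
```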
requirements.txt CHANGED
@@ -3,4 +3,5 @@ transformers
 datasets
 sentence-transformers
 faiss-cpu
-arxiv
+arxiv
+torch