wakeupmh committed on
Commit
0f8445a
·
1 Parent(s): f91cc3b

fix: dimension error

Browse files
Files changed (2) hide show
  1. app.py +24 -13
  2. faiss_index/index.py +19 -4
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import streamlit as st
2
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
  import faiss
4
  import os
5
  from datasets import load_from_disk
 
6
 
7
  # Title
8
  st.title("AMA Austim 🧩")
@@ -29,33 +30,43 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
29
  # RAG Pipeline
30
  def rag_pipeline(query, dataset, index):
31
  # Load pre-trained RAG model and configure retriever
32
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
 
 
 
33
  retriever = RagRetriever.from_pretrained(
34
- "facebook/rag-sequence-nq",
35
  index_name="custom",
36
  passages_path=os.path.join("rag_dataset", "dataset"),
37
- index_path=os.path.join("rag_dataset", "embeddings.faiss")
 
38
  )
39
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
 
 
40
 
41
  # Generate answer using RAG
42
  inputs = tokenizer(query, return_tensors="pt")
43
- generated_ids = model.generate(inputs["input_ids"])
44
- answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
45
 
46
  return answer
47
 
48
  # Run the app
49
  if query:
50
- st.write("Loading or creating RAG dataset...")
51
- dataset, index = load_rag_dataset()
52
-
53
- st.write("Running RAG pipeline...")
 
 
 
 
 
54
  answer = rag_pipeline(query, dataset, index)
55
-
56
  st.write("### Answer:")
57
  st.write(answer)
58
-
59
  st.write("### Retrieved Papers:")
60
  for i in range(min(5, len(dataset))):
61
  st.write(f"**Title:** {dataset[i]['title']}")
 
1
  import streamlit as st
2
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
3
  import faiss
4
  import os
5
  from datasets import load_from_disk
6
+ import torch
7
 
8
  # Title
9
  st.title("AMA Austim 🧩")
 
# RAG Pipeline
@st.cache_resource(show_spinner=False)
def _load_rag_components(model_name="facebook/rag-sequence-nq"):
    """Load and cache the RAG tokenizer, retriever, and generator model.

    Cached with ``st.cache_resource`` so the multi-gigabyte model is loaded
    once per process instead of on every Streamlit rerun / user query.

    Parameters
    ----------
    model_name : str
        Hugging Face model id for the RAG sequence model.

    Returns
    -------
    tuple
        ``(tokenizer, retriever, model)``.
    """
    tokenizer = RagTokenizer.from_pretrained(model_name)
    # Configure retriever against the locally built passages + FAISS index.
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="custom",
        passages_path=os.path.join("rag_dataset", "dataset"),
        index_path=os.path.join("rag_dataset", "embeddings.faiss"),
        use_dummy_dataset=False,
    )
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    return tokenizer, retriever, model


def rag_pipeline(query, dataset, index):
    """Generate an answer for ``query`` using the RAG model.

    Parameters
    ----------
    query : str
        Natural-language question to answer.
    dataset, index
        Accepted for interface compatibility with existing callers; retrieval
        itself is performed by the ``RagRetriever`` configured from the saved
        dataset/index paths on disk.

    Returns
    -------
    str
        The generated answer text.
    """
    tokenizer, _retriever, model = _load_rag_components("facebook/rag-sequence-nq")

    # Generate answer using RAG; no gradients are needed at inference time.
    inputs = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_ids"], max_length=200)
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return answer
55
 
56
  # Run the app
57
  if query:
58
+ with st.status("Looking for data in the best sources...", expanded=True) as status:
59
+ st.write("Still looking... this may take a while as we look at some prestigious papers...")
60
+ dataset, index = load_rag_dataset()
61
+ st.write("Found the best sources!")
62
+ status.update(
63
+ label="Download complete!",
64
+ state="complete",
65
+ expanded=False
66
+ )
67
  answer = rag_pipeline(query, dataset, index)
 
68
  st.write("### Answer:")
69
  st.write(answer)
 
70
  st.write("### Retrieved Papers:")
71
  for i in range(min(5, len(dataset))):
72
  st.write(f"**Title:** {dataset[i]['title']}")
faiss_index/index.py CHANGED
@@ -1,9 +1,10 @@
1
  import numpy as np
2
  import faiss
3
- from sentence_transformers import SentenceTransformer
4
  import arxiv
5
  from datasets import Dataset
6
  import os
 
 
7
 
8
  # Fetch arXiv papers
9
  def fetch_arxiv_papers(query, max_results=10):
@@ -26,15 +27,29 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
26
  "text": [p["text"] for p in papers],
27
  })
28
 
 
 
 
 
29
  # Create embeddings
30
- embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
31
- embeddings = embedder.encode(dataset["text"], show_progress_bar=True)
 
 
 
 
 
 
 
 
 
 
32
 
33
  # Add embeddings to dataset
34
  dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
35
 
36
  # Create FAISS index
37
- dimension = embeddings.shape[1]
38
  index = faiss.IndexFlatL2(dimension)
39
  index.add(embeddings.astype(np.float32))
40
 
 
1
  import numpy as np
2
  import faiss
 
3
  import arxiv
4
  from datasets import Dataset
5
  import os
6
+ from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
7
+ import torch
8
 
9
  # Fetch arXiv papers
10
  def fetch_arxiv_papers(query, max_results=10):
 
27
  "text": [p["text"] for p in papers],
28
  })
29
 
30
+ # Initialize DPR context encoder (same as used by RAG)
31
+ ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
32
+ ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
33
+
34
  # Create embeddings
35
+ embeddings = []
36
+ batch_size = 8
37
+
38
+ for i in range(0, len(dataset), batch_size):
39
+ batch = dataset[i:i + batch_size]["text"]
40
+ inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
41
+ with torch.no_grad():
42
+ outputs = ctx_encoder(**inputs)
43
+ batch_embeddings = outputs.pooler_output.cpu().numpy()
44
+ embeddings.append(batch_embeddings)
45
+
46
+ embeddings = np.vstack(embeddings)
47
 
48
  # Add embeddings to dataset
49
  dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
50
 
51
  # Create FAISS index
52
+ dimension = embeddings.shape[1] # Should be 768 for DPR
53
  index = faiss.IndexFlatL2(dimension)
54
  index.add(embeddings.astype(np.float32))
55