Spaces:
Sleeping
Sleeping
fix: dimension error
Browse files
- app.py (+24, -13)
- faiss_index/index.py (+19, -4)
app.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
|
3 |
import faiss
|
4 |
import os
|
5 |
from datasets import load_from_disk
|
|
|
6 |
|
7 |
# Title
|
8 |
st.title("AMA Austim 🧩")
|
@@ -29,33 +30,43 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
|
|
29 |
# RAG Pipeline
|
30 |
def rag_pipeline(query, dataset, index):
|
31 |
# Load pre-trained RAG model and configure retriever
|
32 |
-
|
|
|
|
|
|
|
33 |
retriever = RagRetriever.from_pretrained(
|
34 |
-
|
35 |
index_name="custom",
|
36 |
passages_path=os.path.join("rag_dataset", "dataset"),
|
37 |
-
index_path=os.path.join("rag_dataset", "embeddings.faiss")
|
|
|
38 |
)
|
39 |
-
|
|
|
|
|
40 |
|
41 |
# Generate answer using RAG
|
42 |
inputs = tokenizer(query, return_tensors="pt")
|
43 |
-
|
44 |
-
|
|
|
45 |
|
46 |
return answer
|
47 |
|
48 |
# Run the app
|
49 |
if query:
|
50 |
-
st.
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
54 |
answer = rag_pipeline(query, dataset, index)
|
55 |
-
|
56 |
st.write("### Answer:")
|
57 |
st.write(answer)
|
58 |
-
|
59 |
st.write("### Retrieved Papers:")
|
60 |
for i in range(min(5, len(dataset))):
|
61 |
st.write(f"**Title:** {dataset[i]['title']}")
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
|
3 |
import faiss
|
4 |
import os
|
5 |
from datasets import load_from_disk
|
6 |
+
import torch
|
7 |
|
8 |
# Title
|
9 |
st.title("AMA Austim 🧩")
|
|
|
30 |
# RAG Pipeline
|
31 |
def rag_pipeline(query, dataset, index):
|
32 |
# Load pre-trained RAG model and configure retriever
|
33 |
+
model_name = "facebook/rag-sequence-nq"
|
34 |
+
tokenizer = RagTokenizer.from_pretrained(model_name)
|
35 |
+
|
36 |
+
# Configure retriever with correct paths and question encoder
|
37 |
retriever = RagRetriever.from_pretrained(
|
38 |
+
model_name,
|
39 |
index_name="custom",
|
40 |
passages_path=os.path.join("rag_dataset", "dataset"),
|
41 |
+
index_path=os.path.join("rag_dataset", "embeddings.faiss"),
|
42 |
+
use_dummy_dataset=False
|
43 |
)
|
44 |
+
|
45 |
+
# Initialize the model with the configured retriever
|
46 |
+
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
|
47 |
|
48 |
# Generate answer using RAG
|
49 |
inputs = tokenizer(query, return_tensors="pt")
|
50 |
+
with torch.no_grad():
|
51 |
+
generated_ids = model.generate(inputs["input_ids"], max_length=200)
|
52 |
+
answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
53 |
|
54 |
return answer
|
55 |
|
56 |
# Run the app
|
57 |
if query:
|
58 |
+
with st.status("Looking for data in the best sources...", expanded=True) as status:
|
59 |
+
st.write("Still looking... this may take a while as we look at some prestigious papers...")
|
60 |
+
dataset, index = load_rag_dataset()
|
61 |
+
st.write("Found the best sources!")
|
62 |
+
status.update(
|
63 |
+
label="Download complete!",
|
64 |
+
state="complete",
|
65 |
+
expanded=False
|
66 |
+
)
|
67 |
answer = rag_pipeline(query, dataset, index)
|
|
|
68 |
st.write("### Answer:")
|
69 |
st.write(answer)
|
|
|
70 |
st.write("### Retrieved Papers:")
|
71 |
for i in range(min(5, len(dataset))):
|
72 |
st.write(f"**Title:** {dataset[i]['title']}")
|
faiss_index/index.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import numpy as np
|
2 |
import faiss
|
3 |
-
from sentence_transformers import SentenceTransformer
|
4 |
import arxiv
|
5 |
from datasets import Dataset
|
6 |
import os
|
|
|
|
|
7 |
|
8 |
# Fetch arXiv papers
|
9 |
def fetch_arxiv_papers(query, max_results=10):
|
@@ -26,15 +27,29 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
|
|
26 |
"text": [p["text"] for p in papers],
|
27 |
})
|
28 |
|
|
|
|
|
|
|
|
|
29 |
# Create embeddings
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
# Add embeddings to dataset
|
34 |
dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
|
35 |
|
36 |
# Create FAISS index
|
37 |
-
dimension = embeddings.shape[1]
|
38 |
index = faiss.IndexFlatL2(dimension)
|
39 |
index.add(embeddings.astype(np.float32))
|
40 |
|
|
|
1 |
import numpy as np
|
2 |
import faiss
|
|
|
3 |
import arxiv
|
4 |
from datasets import Dataset
|
5 |
import os
|
6 |
+
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
|
7 |
+
import torch
|
8 |
|
9 |
# Fetch arXiv papers
|
10 |
def fetch_arxiv_papers(query, max_results=10):
|
|
|
27 |
"text": [p["text"] for p in papers],
|
28 |
})
|
29 |
|
30 |
+
# Initialize DPR context encoder (same as used by RAG)
|
31 |
+
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
|
32 |
+
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
|
33 |
+
|
34 |
# Create embeddings
|
35 |
+
embeddings = []
|
36 |
+
batch_size = 8
|
37 |
+
|
38 |
+
for i in range(0, len(dataset), batch_size):
|
39 |
+
batch = dataset[i:i + batch_size]["text"]
|
40 |
+
inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
|
41 |
+
with torch.no_grad():
|
42 |
+
outputs = ctx_encoder(**inputs)
|
43 |
+
batch_embeddings = outputs.pooler_output.cpu().numpy()
|
44 |
+
embeddings.append(batch_embeddings)
|
45 |
+
|
46 |
+
embeddings = np.vstack(embeddings)
|
47 |
|
48 |
# Add embeddings to dataset
|
49 |
dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
|
50 |
|
51 |
# Create FAISS index
|
52 |
+
dimension = embeddings.shape[1] # Should be 768 for DPR
|
53 |
index = faiss.IndexFlatL2(dimension)
|
54 |
index.add(embeddings.astype(np.float32))
|
55 |
|