fix: hf format
- app.py: +17 -14
- faiss_index/index.py: +29 -13
app.py
CHANGED
@@ -38,29 +38,32 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
         dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)

         # Load the dataset and index
+        dataset_path = os.path.join(dataset_dir, "dataset")
+        index_path = os.path.join(dataset_dir, "embeddings.faiss")
+
+        dataset = load_from_disk(dataset_path)
+
+        # Add FAISS index back to dataset
+        dataset.load_faiss_index('embeddings', index_path)
+
+        return dataset, dataset_path, index_path
     except Exception as e:
         st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None

 # RAG Pipeline
-def rag_pipeline(query, dataset,
+def rag_pipeline(query, dataset, dataset_path, index_path):
     try:
         # Initialize components
         model_name = "facebook/rag-sequence-nq"
         tokenizer = RagTokenizer.from_pretrained(model_name)

-        #
-        passages = prepare_rag_passages(dataset)
-
-        # Initialize retriever with passages
+        # Initialize retriever with correct paths
         retriever = RagRetriever.from_pretrained(
             model_name,
             index_name="custom",
+            passages_path=dataset_path,
+            index_path=index_path
         )

         # Initialize model with retriever
@@ -93,15 +96,15 @@ if query:
     with st.status("Looking for data in the best sources...", expanded=True) as status:
         try:
             st.write("Still looking... this may take a while as we look at some prestigious papers...")
-            dataset,
+            dataset, dataset_path, index_path = load_rag_dataset()

-            if dataset is None
+            if dataset is None:
                 st.error("Failed to load or create the dataset.")
                 status.update(label="Error loading data", state="error")
             else:
                 st.write("Found the best sources!")
                 st.write("Now answering your question...")
-                answer = rag_pipeline(query, dataset,
+                answer = rag_pipeline(query, dataset, dataset_path, index_path)

                 if answer:
                     status.update(label="Search complete!", state="complete", expanded=False)
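For reference, the retriever wiring this diff moves to can be exercised outside Streamlit. A minimal sketch, not part of the commit, assuming the dataset saved by faiss_index/index.py (id/text/title/embeddings columns) under the commit's default paths:

import os
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

dataset_dir = "rag_dataset"
dataset_path = os.path.join(dataset_dir, "dataset")
index_path = os.path.join(dataset_dir, "embeddings.faiss")

model_name = "facebook/rag-sequence-nq"
tokenizer = RagTokenizer.from_pretrained(model_name)

# With index_name="custom", the retriever loads the passages and the FAISS
# index from disk itself; calling load_faiss_index() first is not required here.
retriever = RagRetriever.from_pretrained(
    model_name,
    index_name="custom",
    passages_path=dataset_path,
    index_path=index_path,
)
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

inputs = tokenizer("What are transformers used for?", return_tensors="pt")
generated = model.generate(input_ids=inputs["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])

Because RagRetriever reads both files from disk on its own, load_rag_dataset() now returns dataset_path and index_path alongside the dataset, and rag_pipeline() passes them straight through.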
faiss_index/index.py
CHANGED
@@ -6,6 +6,7 @@ import os
 from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
 import torch
 import logging
+from datasets.utils.file_utils import DownloadConfig

 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -24,9 +25,9 @@ def fetch_arxiv_papers(query, max_results=10):

     for i, result in enumerate(results):
         papers.append({
-            "id": str(i),
-            "text": result.summary,
-            "title": result.title,
+            "id": str(i),
+            "text": result.summary,
+            "title": result.title,
         })

     logging.info(f"Fetched {len(papers)} papers from arXiv")
@@ -39,16 +40,16 @@ def fetch_arxiv_papers(query, max_results=10):
 def build_faiss_index(papers, dataset_dir="rag_dataset"):
     """Build and save dataset with FAISS index for RAG"""
     try:
-        # Create dataset
+        # Create dataset
         dataset = Dataset.from_dict({
             "id": [p["id"] for p in papers],
-            "text": [p["text"] for p in papers],
-            "title": [p["title"] for p in papers],
+            "text": [p["text"] for p in papers],
+            "title": [p["title"] for p in papers],
         })

         logging.info(f"Created dataset with {len(dataset)} papers")

-        # Initialize DPR encoder
+        # Initialize DPR encoder
         ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
         ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

@@ -74,17 +75,32 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
         embeddings = np.vstack(embeddings)
         logging.info(f"Created embeddings with shape {embeddings.shape}")

-        # Create FAISS index
-        dimension = embeddings.shape[1]
+        # Create an empty FAISS index; add_faiss_index() below fills it from the embeddings column
+        dimension = embeddings.shape[1]
         index = faiss.IndexFlatL2(dimension)
-        index.add(embeddings.astype(np.float32))

-        # Save
+        # Save everything
         os.makedirs(dataset_dir, exist_ok=True)
-        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
-        faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))

+        # Add embeddings to dataset, then index that column
+        dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
+        dataset = dataset.add_faiss_index(
+            column='embeddings',
+            custom_index=index,
+            device=0 if torch.cuda.is_available() else -1
+        )
+
+        # Save dataset and index
+        dataset_path = os.path.join(dataset_dir, "dataset")
+        index_path = os.path.join(dataset_dir, "embeddings.faiss")
+
+        dataset.get_index('embeddings').save(index_path)
+        # save_to_disk() rejects datasets with attached indexes, so detach first
+        dataset.drop_index('embeddings')
+        dataset.save_to_disk(dataset_path)
+
+        logging.info(f"Saved dataset to {dataset_path}")
+        logging.info(f"Saved index to {index_path}")
+
         return dataset_dir

     except Exception as e:
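The save/reload contract between build_faiss_index() and app.py's load_rag_dataset() can be checked end to end with toy vectors standing in for DPR embeddings. A minimal sketch, not part of the commit; the scratch directory name is made up, everything else mirrors the code above:

import os

import faiss
import numpy as np
from datasets import Dataset, load_from_disk

# Toy stand-ins for the DPR embeddings computed in build_faiss_index()
embeddings = np.random.rand(4, 8).astype(np.float32)
dataset = Dataset.from_dict({
    "id": [str(i) for i in range(4)],
    "text": ["paper summary"] * 4,
    "title": ["paper title"] * 4,
})
dataset = dataset.add_column("embeddings", [e.tolist() for e in embeddings])

# Attach a FAISS index over the embeddings column and save both pieces
scratch_dir = "rag_dataset_demo"  # hypothetical path, demo only
os.makedirs(scratch_dir, exist_ok=True)
index = faiss.IndexFlatL2(embeddings.shape[1])
dataset = dataset.add_faiss_index(column="embeddings", custom_index=index)
dataset.get_index("embeddings").save(os.path.join(scratch_dir, "embeddings.faiss"))
dataset.drop_index("embeddings")  # save_to_disk() requires no attached indexes
dataset.save_to_disk(os.path.join(scratch_dir, "dataset"))

# Reload the same way app.py's load_rag_dataset() now does
reloaded = load_from_disk(os.path.join(scratch_dir, "dataset"))
reloaded.load_faiss_index("embeddings", os.path.join(scratch_dir, "embeddings.faiss"))
scores, examples = reloaded.get_nearest_examples("embeddings", embeddings[0], k=2)
print(examples["title"])

If the round trip works, get_nearest_examples() returns the stored rows nearest to the query vector, confirming that the saved index and the saved dataset line up row for row.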