fix: rag

- app.py +5 -3
- faiss_index/index.py +16 -24
app.py
CHANGED
@@ -41,11 +41,13 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
         dataset_path = os.path.join(dataset_dir, "dataset")
         index_path = os.path.join(dataset_dir, "embeddings.faiss")
 
+        if not os.path.exists(dataset_path) or not os.path.exists(index_path):
+            raise ValueError("Dataset or index not found")
+
         dataset = load_from_disk(dataset_path)
+        index = faiss.read_index(index_path)
 
-
-        dataset.load_faiss_index('embeddings', index_path)
-
+        logging.info("Successfully loaded dataset and index")
         return dataset, dataset_path, index_path
     except Exception as e:
         st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
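For reference, a minimal sketch of how the patched loader might be used at query time; the retrieve helper, the k parameter, and the DPR question-encoder checkpoint are illustrative assumptions, not part of this commit. Since load_rag_dataset returns the index path rather than the index object, the caller re-opens it with faiss.read_index:

# Hypothetical query-side usage (illustration only, not part of the commit).
import faiss
import numpy as np
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

def retrieve(question, k=3):  # assumed helper name and default
    # load_rag_dataset is the function patched above; it returns the dataset
    # plus on-disk paths, so the caller re-opens the FAISS index itself.
    dataset, dataset_path, index_path = load_rag_dataset()
    index = faiss.read_index(index_path)

    # The question encoder is the query-side counterpart of the context
    # encoder used in faiss_index/index.py to build the stored embeddings.
    q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    inputs = q_tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
    query_emb = q_encoder(**inputs).pooler_output.detach().numpy().astype(np.float32)

    # IndexFlatL2 search: vector i in the index is row i of the dataset.
    distances, indices = index.search(query_emb, k)
    return [dataset[int(i)] for i in indices[0]]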
faiss_index/index.py
CHANGED
@@ -6,7 +6,6 @@ import os
 from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
 import torch
 import logging
-from datasets.utils.file_utils import DownloadConfig
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -40,25 +39,17 @@ def fetch_arxiv_papers(query, max_results=10):
 def build_faiss_index(papers, dataset_dir="rag_dataset"):
     """Build and save dataset with FAISS index for RAG"""
     try:
-        # Create dataset
-        dataset = Dataset.from_dict({
-            "id": [p["id"] for p in papers],
-            "text": [p["text"] for p in papers],
-            "title": [p["title"] for p in papers],
-        })
-
-        logging.info(f"Created dataset with {len(dataset)} papers")
-
         # Initialize DPR encoder
         ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
         ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
         # Create embeddings in batches
+        texts = [p["text"] for p in papers]
         embeddings = []
         batch_size = 8
 
-        for i in range(0, len(dataset), batch_size):
-            batch = dataset["text"][i:i + batch_size]
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
             inputs = ctx_tokenizer(
                 batch,
                 max_length=512,
@@ -75,28 +66,29 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
         embeddings = np.vstack(embeddings)
         logging.info(f"Created embeddings with shape {embeddings.shape}")
 
-        # Create FAISS index
+        # Create dataset with embeddings
+        dataset = Dataset.from_dict({
+            "id": [p["id"] for p in papers],
+            "text": [p["text"] for p in papers],
+            "title": [p["title"] for p in papers],
+            "embeddings": [emb.tolist() for emb in embeddings],
+        })
+
+        logging.info(f"Created dataset with {len(dataset)} papers")
+
+        # Create FAISS index from embeddings
         dimension = embeddings.shape[1]
         index = faiss.IndexFlatL2(dimension)
         index.add(embeddings.astype(np.float32))
 
         # Save everything
         os.makedirs(dataset_dir, exist_ok=True)
-
-        # Add embeddings to dataset
-        dataset = dataset.add_faiss_index(
-            column='embeddings',
-            custom_index=index,
-            device=0 if torch.cuda.is_available() else -1
-        )
-        dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
-
-        # Save dataset and index
         dataset_path = os.path.join(dataset_dir, "dataset")
         index_path = os.path.join(dataset_dir, "embeddings.faiss")
 
+        # Save dataset and index
         dataset.save_to_disk(dataset_path)
+        faiss.write_index(index, index_path)
-
 
         logging.info(f"Saved dataset to {dataset_path}")
         logging.info(f"Saved index to {index_path}")
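A small smoke test for the new build/save path might look like the following; the toy paper dicts stand in for fetch_arxiv_papers output and are not from the commit. Because the dataset rows and the index vectors are built from the same embeddings array in the same order, row i of the dataset lines up with vector i of the index:

# Hypothetical smoke test for build_faiss_index; the paper dicts are toy data.
import os
import faiss
from datasets import load_from_disk

papers = [
    {"id": "2401.00001", "title": "Paper A", "text": "Dense passage retrieval with DPR."},
    {"id": "2401.00002", "title": "Paper B", "text": "FAISS for efficient similarity search."},
]

build_faiss_index(papers, dataset_dir="rag_dataset")

dataset = load_from_disk(os.path.join("rag_dataset", "dataset"))
index = faiss.read_index(os.path.join("rag_dataset", "embeddings.faiss"))

assert index.ntotal == len(dataset)              # one vector per paper
assert index.d == len(dataset[0]["embeddings"])  # DPR base encoders emit 768-dim vectors

Writing the index to its own file with faiss.write_index keeps the two artifacts independent: datasets' save_to_disk does not persist an index attached via add_faiss_index (and the old code attached it to an embeddings column that did not exist yet), which is presumably the failure this commit addresses.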