wakeupmh committed on
Commit
884d2bd
·
1 Parent(s): f68ac31

fix: faiss

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. faiss_index/index.py +1 -1
app.py CHANGED
@@ -16,8 +16,8 @@ def load_models():
16
  retriever = RagRetriever.from_pretrained(
17
  "facebook/rag-sequence-nq",
18
  index_name="custom",
19
- passages_path="/data/rag_dataset/dataset",
20
- index_path="/data/rag_dataset/embeddings.faiss"
21
  )
22
  model = RagSequenceForGeneration.from_pretrained(
23
  "facebook/rag-sequence-nq",
@@ -28,7 +28,7 @@ def load_models():
28
 
29
  @st.cache_data # Cache dataset on disk
30
  def load_dataset():
31
- return load_from_disk("/data/rag_dataset/dataset")
32
 
33
  # RAG Pipeline
34
  def rag_pipeline(query, dataset, index):
 
16
  retriever = RagRetriever.from_pretrained(
17
  "facebook/rag-sequence-nq",
18
  index_name="custom",
19
+ passages_path="rag_dataset/dataset",
20
+ index_path="rag_dataset/embeddings.faiss"
21
  )
22
  model = RagSequenceForGeneration.from_pretrained(
23
  "facebook/rag-sequence-nq",
 
28
 
29
  @st.cache_data # Cache dataset on disk
30
  def load_dataset():
31
+ return load_from_disk("rag_dataset/dataset")
32
 
33
  # RAG Pipeline
34
  def rag_pipeline(query, dataset, index):
faiss_index/index.py CHANGED
@@ -23,7 +23,7 @@ def fetch_arxiv_papers(query, max_results=10):
23
  logging.info(f"Fetched {len(papers)} papers from arXiv")
24
  return papers
25
 
26
- def build_faiss_index(papers, dataset_dir="/data/rag_dataset"):
27
  """Build and save dataset with FAISS index for RAG"""
28
  # Initialize DPR encoder
29
  ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
23
  logging.info(f"Fetched {len(papers)} papers from arXiv")
24
  return papers
25
 
26
+ def build_faiss_index(papers, dataset_dir="rag_dataset"):
27
  """Build and save dataset with FAISS index for RAG"""
28
  # Initialize DPR encoder
29
  ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")