wakeupmh committed on
Commit
0452175
·
1 Parent(s): 884d2bd

fix: run in hf

Browse files
Files changed (2) hide show
  1. app.py +15 -3
  2. faiss_index/index.py +5 -7
app.py CHANGED
@@ -9,6 +9,12 @@ import logging
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO)
11
 
 
 
 
 
 
 
12
  # Cache models and dataset
13
  @st.cache_resource # Cache models in memory
14
  def load_models():
@@ -16,8 +22,8 @@ def load_models():
16
  retriever = RagRetriever.from_pretrained(
17
  "facebook/rag-sequence-nq",
18
  index_name="custom",
19
- passages_path="rag_dataset/dataset",
20
- index_path="rag_dataset/embeddings.faiss"
21
  )
22
  model = RagSequenceForGeneration.from_pretrained(
23
  "facebook/rag-sequence-nq",
@@ -28,7 +34,13 @@ def load_models():
28
 
29
  @st.cache_data # Cache dataset on disk
30
  def load_dataset():
31
- return load_from_disk("rag_dataset/dataset")
 
 
 
 
 
 
32
 
33
  # RAG Pipeline
34
  def rag_pipeline(query, dataset, index):
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO)
11
 
12
+ # Define data paths
13
+ DATA_DIR = "/data" if os.path.exists("/data") else "."
14
+ DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
15
+ DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
16
+ INDEX_PATH = os.path.join(DATASET_DIR, "embeddings.faiss")
17
+
18
  # Cache models and dataset
19
  @st.cache_resource # Cache models in memory
20
  def load_models():
 
22
  retriever = RagRetriever.from_pretrained(
23
  "facebook/rag-sequence-nq",
24
  index_name="custom",
25
+ passages_path=DATASET_PATH,
26
+ index_path=INDEX_PATH
27
  )
28
  model = RagSequenceForGeneration.from_pretrained(
29
  "facebook/rag-sequence-nq",
 
34
 
35
  @st.cache_data # Cache dataset on disk
36
  def load_dataset():
37
+ # Create initial dataset if it doesn't exist
38
+ if not os.path.exists(DATASET_PATH):
39
+ with st.spinner("Building initial dataset from autism research papers..."):
40
+ import faiss_index.index as idx
41
+ papers = idx.fetch_arxiv_papers("autism research", max_results=100)
42
+ idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
43
+ return load_from_disk(DATASET_PATH)
44
 
45
  # RAG Pipeline
46
  def rag_pipeline(query, dataset, index):
faiss_index/index.py CHANGED
@@ -10,6 +10,10 @@ import logging
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
12
 
 
 
 
 
13
  def fetch_arxiv_papers(query, max_results=10):
14
  """Fetch papers from arXiv and format them for RAG"""
15
  client = arxiv.Client()
@@ -23,7 +27,7 @@ def fetch_arxiv_papers(query, max_results=10):
23
  logging.info(f"Fetched {len(papers)} papers from arXiv")
24
  return papers
25
 
26
- def build_faiss_index(papers, dataset_dir="rag_dataset"):
27
  """Build and save dataset with FAISS index for RAG"""
28
  # Initialize DPR encoder
29
  ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
@@ -66,9 +70,3 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
66
  logging.info(f"Saved dataset to {dataset_path}")
67
  logging.info(f"Saved index to {index_path}")
68
  return dataset_dir
69
-
70
- # Example usage
71
- if __name__ == "__main__":
72
- query = "autism research"
73
- papers = fetch_arxiv_papers(query, max_results=100)
74
- build_faiss_index(papers)
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
12
 
13
+ # Define data paths
14
+ DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
15
+ DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
16
+
17
  def fetch_arxiv_papers(query, max_results=10):
18
  """Fetch papers from arXiv and format them for RAG"""
19
  client = arxiv.Client()
 
27
  logging.info(f"Fetched {len(papers)} papers from arXiv")
28
  return papers
29
 
30
+ def build_faiss_index(papers, dataset_dir=DATASET_DIR):
31
  """Build and save dataset with FAISS index for RAG"""
32
  # Initialize DPR encoder
33
  ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
70
  logging.info(f"Saved dataset to {dataset_path}")
71
  logging.info(f"Saved index to {index_path}")
72
  return dataset_dir