fix: run in hf
- app.py +15 -3
- faiss_index/index.py +5 -7
app.py
CHANGED
@@ -9,6 +9,12 @@ import logging
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 
+# Define data paths
+DATA_DIR = "/data" if os.path.exists("/data") else "."
+DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
+DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
+INDEX_PATH = os.path.join(DATASET_DIR, "embeddings.faiss")
+
 # Cache models and dataset
 @st.cache_resource  # Cache models in memory
 def load_models():
@@ -16,8 +22,8 @@ def load_models():
     retriever = RagRetriever.from_pretrained(
         "facebook/rag-sequence-nq",
         index_name="custom",
-        passages_path=
-        index_path=
+        passages_path=DATASET_PATH,
+        index_path=INDEX_PATH
     )
     model = RagSequenceForGeneration.from_pretrained(
         "facebook/rag-sequence-nq",
@@ -28,7 +34,13 @@ def load_models():
 
 @st.cache_data  # Cache dataset on disk
 def load_dataset():
-
+    # Create initial dataset if it doesn't exist
+    if not os.path.exists(DATASET_PATH):
+        with st.spinner("Building initial dataset from autism research papers..."):
+            import faiss_index.index as idx
+            papers = idx.fetch_arxiv_papers("autism research", max_results=100)
+            idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
+    return load_from_disk(DATASET_PATH)
 
 # RAG Pipeline
 def rag_pipeline(query, dataset, index):
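
Net effect in app.py: the dataset and FAISS index paths resolve to the Space's persistent storage (/data) when it is mounted, otherwise to the working directory, and load_dataset() now builds the corpus on first run instead of assuming it already exists. Below is a minimal sketch of that build-once-then-load pattern with the Streamlit caching stripped out; the function name get_dataset and the build_corpus callable are illustrative placeholders for the faiss_index.index helpers, not part of the diff.

import os
from datasets import load_from_disk  # same loader app.py uses

# Path resolution mirrors the diff: prefer the persistent /data mount on Spaces.
DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")

def get_dataset(build_corpus):
    """Build the dataset once if it is missing, then always load it from disk.

    build_corpus is a placeholder for the fetch-papers + build-index step
    that app.py delegates to faiss_index.index.
    """
    if not os.path.exists(DATASET_PATH):
        build_corpus(DATASET_DIR)  # writes the dataset and embeddings.faiss
    return load_from_disk(DATASET_PATH)

Because the existence check is against DATASET_PATH on disk, a warm restart of the Space skips the arXiv fetch and index build entirely.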
faiss_index/index.py
CHANGED
@@ -10,6 +10,10 @@ import logging
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 
+# Define data paths
+DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
+DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
+
 def fetch_arxiv_papers(query, max_results=10):
     """Fetch papers from arXiv and format them for RAG"""
     client = arxiv.Client()
@@ -23,7 +27,7 @@ def fetch_arxiv_papers(query, max_results=10):
     logging.info(f"Fetched {len(papers)} papers from arXiv")
     return papers
 
-def build_faiss_index(papers, dataset_dir="rag_dataset"):
+def build_faiss_index(papers, dataset_dir=DATASET_DIR):
     """Build and save dataset with FAISS index for RAG"""
     # Initialize DPR encoder
     ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
@@ -66,9 +70,3 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
     logging.info(f"Saved dataset to {dataset_path}")
     logging.info(f"Saved index to {index_path}")
     return dataset_dir
-
-# Example usage
-if __name__ == "__main__":
-    query = "autism research"
-    papers = fetch_arxiv_papers(query, max_results=100)
-    build_faiss_index(papers)
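
With the __main__ block removed, index building is expected to be triggered from app.py, but the module can still be driven by hand, for example to pre-populate persistent storage before the Space starts. The sketch below is a rough equivalent of the deleted example usage, run from the repository root; exporting DATA_DIR=/data is only an example of steering the new path logic, not something the diff requires.

# One-off manual build, roughly what the removed "Example usage" block did.
# Run with e.g. DATA_DIR=/data exported so the output lands on persistent storage.
import faiss_index.index as idx

papers = idx.fetch_arxiv_papers("autism research", max_results=100)
idx.build_faiss_index(papers)  # default dataset_dir now comes from DATA_DIR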