fix: write
- app.py (+27 -14)
- faiss_index/index.py (+8 -5)
app.py
CHANGED
@@ -6,6 +6,7 @@ from datasets import load_from_disk
 import torch
 import logging
 import warnings
+from pathlib import Path
 
 # Configure logging
 logging.basicConfig(level=logging.WARNING)
@@ -18,14 +19,23 @@ st.title("🧩 AMA Austim")
 query = st.text_input("Please ask me anything about autism ✨")
 
 @st.cache_resource
-def load_rag_components(model_name="facebook/rag-sequence-nq"):
+def load_rag_components(_dataset_path=None, _index_path=None):
     """Load and cache RAG components to avoid reloading."""
+    model_name = "facebook/rag-sequence-nq"
     tokenizer = RagTokenizer.from_pretrained(model_name)
-    retriever = RagRetriever.from_pretrained(
-        model_name,
-        index_name=...,
-        use_dummy_dataset=...,
-    )
+
+    retriever_config = {
+        "index_name": "custom",
+        "use_dummy_dataset": True
+    }
+
+    if _dataset_path and _index_path:
+        retriever_config.update({
+            "passages_path": _dataset_path,
+            "index_path": _index_path
+        })
+
+    retriever = RagRetriever.from_pretrained(model_name, **retriever_config)
     model = RagSequenceForGeneration.from_pretrained(model_name)
     return tokenizer, retriever, model
 
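Note on this hunk: `st.cache_resource` excludes parameters whose names begin with an underscore from the cache key, so changing `_dataset_path` or `_index_path` will not rebuild the cached components; whatever the first call loaded stays cached. A minimal illustration of that Streamlit behavior (toy function, hypothetical names):

import streamlit as st

@st.cache_resource
def load(_path=None):
    # the body runs once; "_path" is not hashed into the cache key
    return object()

a = load("first")   # executes the function
b = load("second")  # cache hit despite the different argument
assert a is b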
@@ -37,17 +47,20 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
     initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
     dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
 
+    dataset_path = os.path.join(dataset_dir, "dataset")
+    index_path = os.path.join(dataset_dir, "embeddings.faiss")
+
     # Load the dataset and index
-    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
-    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
+    dataset = load_from_disk(dataset_path)
+    index = faiss.read_index(index_path)
 
-    return dataset, index
+    return dataset, index, dataset_path, index_path
 
 # RAG Pipeline
-def rag_pipeline(query, dataset, index):
+def rag_pipeline(query, dataset, index, dataset_path, index_path):
     try:
-        # Load cached components
-        tokenizer, retriever, model = load_rag_components()
+        # Load cached components with paths
+        tokenizer, retriever, model = load_rag_components(dataset_path, index_path)
 
         # Configure retriever with our dataset
         retriever.index.dataset = dataset
@@ -76,9 +89,9 @@ def rag_pipeline(query, dataset, index):
 if query:
     with st.status("Looking for data in the best sources...", expanded=True) as status:
         st.write("Still looking... this may take a while as we look at some prestigious papers...")
-        dataset, index = load_rag_dataset()
+        dataset, index, dataset_path, index_path = load_rag_dataset()
         st.write("Found the best sources!")
-        answer = rag_pipeline(query, dataset, index)
+        answer = rag_pipeline(query, dataset, index, dataset_path, index_path)
         st.write("Now answering your question...")
         status.update(
             label="Searching complete!",
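For reference, the configuration above lines up with the transformers custom-index pattern: with index_name="custom", RagRetriever.from_pretrained accepts passages_path (a datasets.Dataset saved via save_to_disk) and index_path (a serialized FAISS index). A minimal sketch, assuming those files already exist at the illustrative paths below; use_dummy_dataset only concerns the canned wiki_dpr indexes and should be moot once a custom index is supplied:

from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

model_name = "facebook/rag-sequence-nq"
retriever = RagRetriever.from_pretrained(
    model_name,
    index_name="custom",
    passages_path="rag_dataset/dataset",        # Dataset saved with save_to_disk
    index_path="rag_dataset/embeddings.faiss",  # index saved with faiss.write_index
)
tokenizer = RagTokenizer.from_pretrained(model_name)
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)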
faiss_index/index.py
CHANGED
@@ -20,11 +20,12 @@ def fetch_arxiv_papers(query, max_results=10):
 
 # Build and save dataset with FAISS index
 def build_faiss_index(papers, dataset_dir="rag_dataset"):
-    # Create dataset
+    # Create dataset with required columns for RAG
     dataset = Dataset.from_dict({
         "id": [p["id"] for p in papers],
+        "text": [p["text"] for p in papers],  # RAG expects 'text' field
         "title": [p["title"] for p in papers],
-        "
+        "embeddings": None,  # Will be filled later
     })
 
     # Initialize DPR context encoder (same as used by RAG)
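One caveat on the added placeholder line: Dataset.from_dict builds equal-length Arrow columns, so a bare None is likely to be rejected, and the later add_column("embeddings", ...) call refuses a column name that already exists. A sketch of the usual pattern (toy data, illustrative only): build the dataset without the placeholder and attach the column once, after encoding:

from datasets import Dataset

papers = [{"id": "1", "title": "t", "text": "abstract..."}]  # toy stand-in

dataset = Dataset.from_dict({
    "id": [p["id"] for p in papers],
    "text": [p["text"] for p in papers],
    "title": [p["title"] for p in papers],
})

embeddings = [[0.0] * 768 for _ in papers]  # stand-in for the DPR vectors
dataset = dataset.add_column("embeddings", embeddings)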
@@ -45,9 +46,6 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
 
     embeddings = np.vstack(embeddings)
 
-    # Add embeddings to dataset
-    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
-
     # Create FAISS index
     dimension = embeddings.shape[1]  # Should be 768 for DPR
     index = faiss.IndexFlatL2(dimension)
@@ -55,7 +53,12 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
 
     # Save dataset and index
     os.makedirs(dataset_dir, exist_ok=True)
+
+    # Save dataset with embeddings
+    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
     dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
+
+    # Save FAISS index
     faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))
 
     return dataset_dir
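As an aside, the datasets library can persist the FAISS index through its own API: add_faiss_index builds it over a column and save_faiss_index / load_faiss_index round-trip it, which keeps the dataset and index handling in one place. A self-contained sketch with toy data (paths illustrative); note that save_to_disk refuses a dataset with an attached index, hence the drop_index:

import numpy as np
from datasets import Dataset

dataset = Dataset.from_dict({"id": ["1"], "title": ["t"], "text": ["abstract..."]})
embeddings = np.random.rand(1, 768).astype("float32")  # stand-in for DPR vectors
dataset = dataset.add_column("embeddings", [e.tolist() for e in embeddings])

dataset.add_faiss_index(column="embeddings")                # in-memory FAISS index over the column
dataset.save_faiss_index("embeddings", "embeddings.faiss")  # serialize just the index
dataset.drop_index("embeddings")                            # detach before save_to_disk
dataset.save_to_disk("dataset")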