fix: rag
Browse files
- app.py +45 -109
- faiss_index/index.py +54 -86
app.py
CHANGED
@@ -2,124 +2,60 @@ import streamlit as st
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
  import faiss
  import os
- from datasets import load_from_disk
  import torch
  import logging
- import traceback

  # Configure logging
  logging.basicConfig(level=logging.INFO)

- #
- st.
-     try:
-         if not os.path.exists(dataset_dir):
-             with st.spinner("Building initial dataset from autism research papers..."):
-                 import faiss_index.index as faiss_index_index
-                 initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
-                 dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
-
-         # Load the dataset and index
-         dataset_path = os.path.join(dataset_dir, "dataset")
-         index_path = os.path.join(dataset_dir, "embeddings.faiss")
-
-         if not os.path.exists(dataset_path) or not os.path.exists(index_path):
-             raise ValueError("Dataset or index not found")
-
-         dataset = load_from_disk(dataset_path)
-         index = faiss.read_index(index_path)
-
-         logging.info("Successfully loaded dataset and index")
-         return dataset, dataset_path, index_path
-     except Exception as e:
-         st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
-         return None, None, None

  # RAG Pipeline
- def rag_pipeline(query, dataset,
-         index_path=index_path
-     )
-
-     # Initialize model with retriever
-     model = RagSequenceForGeneration.from_pretrained(
-         model_name,
-         retriever=retriever,
-         use_auth_token=False
      )

-         outputs = model.generate(
-             inputs["input_ids"],
-             max_length=200,
-             min_length=50,
-             num_beams=5,
-             early_stopping=True,
-             no_repeat_ngram_size=3
-         )
-         answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-
-         return answer
-     except Exception as e:
-         st.error(f"Error generating answer: {str(e)}\n{traceback.format_exc()}")
-         return None

- # Run the app
  if query:
-     with st.status("
-         else:
-             st.write("Found the best sources!")
-             st.write("Now answering your question...")
-             answer = rag_pipeline(query, dataset, dataset_path, index_path)
-
-             if answer:
-                 status.update(label="Search complete!", state="complete", expanded=False)
-                 st.write("### Answer:")
-                 st.write_stream(answer)
-                 st.write("### Retrieved Papers:")
-                 for i in range(min(5, len(dataset))):
-                     st.write(f"**Title:** {dataset[i]['title']}")
-                     st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
-                     st.write("---")
-             else:
-                 status.update(label="Error generating answer", state="error")
-
-     except Exception as e:
-         st.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
-         status.update(label="Error", state="error")
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
  import faiss
  import os
+ from datasets import load_from_disk
  import torch
  import logging

  # Configure logging
  logging.basicConfig(level=logging.INFO)

+ # Cache models and dataset
+ @st.cache_resource  # Cache models in memory
+ def load_models():
+     tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+     retriever = RagRetriever.from_pretrained(
+         "facebook/rag-sequence-nq",
+         index_name="custom",
+         passages_path="/data/rag_dataset/dataset",
+         index_path="/data/rag_dataset/embeddings.faiss"
+     )
+     model = RagSequenceForGeneration.from_pretrained(
+         "facebook/rag-sequence-nq",
+         retriever=retriever,
+         device_map="auto"
+     )
+     return tokenizer, retriever, model
+
+ @st.cache_data  # Cache dataset on disk
+ def load_dataset():
+     return load_from_disk("/data/rag_dataset/dataset")

  # RAG Pipeline
+ def rag_pipeline(query, dataset, index):
+     tokenizer, retriever, model = load_models()
+     inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
+     with torch.no_grad():
+         outputs = model.generate(
+             inputs["input_ids"],
+             max_length=200,
+             min_length=50,
+             num_beams=5,
+             early_stopping=True,
+             no_repeat_ngram_size=3
          )
+     answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+     return answer

+ # Streamlit App
+ st.title("🧩 AMA Autism")
+ query = st.text_input("Please ask me anything about autism ✨")

  if query:
+     with st.status("Searching for answers..."):
+         dataset = load_dataset()
+         answer = rag_pipeline(query, dataset, index=None)
+         if answer:
+             st.success("Answer found!")
+             st.write(answer)
+         else:
+             st.error("Failed to generate an answer.")
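Aside: the two decorators above follow Streamlit's usual split between resource caching and data caching. A minimal, self-contained sketch of that pattern, separate from this Space (the function and object names below are placeholders, not code from this commit):

    import streamlit as st

    @st.cache_resource  # one shared object per process, reused across reruns and sessions
    def get_engine():
        # stands in for an expensive, non-serializable load such as a model
        return {"name": "toy-engine"}

    @st.cache_data  # return value is serialized and reused for identical arguments
    def get_rows(n: int):
        # stands in for loading a dataset from disk
        return list(range(n))

    engine = get_engine()
    rows = get_rows(10)
    st.write(engine["name"], len(rows))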
faiss_index/index.py
CHANGED
@@ -12,95 +12,63 @@ logging.basicConfig(level=logging.INFO)

  def fetch_arxiv_papers(query, max_results=10):
      """Fetch papers from arXiv and format them for RAG"""
-         for i, result in enumerate(results):
-             papers.append({
-                 "id": str(i),
-                 "text": result.summary,
-                 "title": result.title,
-             })
-
-         logging.info(f"Fetched {len(papers)} papers from arXiv")
-         return papers
-
-     except Exception as e:
-         logging.error(f"Error fetching papers: {str(e)}")
-         raise

- def build_faiss_index(papers, dataset_dir="rag_dataset"):
      """Build and save dataset with FAISS index for RAG"""
-         index.add(embeddings.astype(np.float32))
-
-         # Save everything
-         os.makedirs(dataset_dir, exist_ok=True)
-         dataset_path = os.path.join(dataset_dir, "dataset")
-         index_path = os.path.join(dataset_dir, "embeddings.faiss")
-
-         # Save dataset and index
-         dataset.save_to_disk(dataset_path)
-         faiss.write_index(index, index_path)
-
-         logging.info(f"Saved dataset to {dataset_path}")
-         logging.info(f"Saved index to {index_path}")
-
-         return dataset_dir
-
-     except Exception as e:
-         logging.error(f"Error building index: {str(e)}")
-         raise

  # Example usage
  if __name__ == "__main__":
-     query = "
-     papers = fetch_arxiv_papers(query)
      build_faiss_index(papers)

  def fetch_arxiv_papers(query, max_results=10):
      """Fetch papers from arXiv and format them for RAG"""
+     client = arxiv.Client()
+     search = arxiv.Search(
+         query=query,
+         max_results=max_results,
+         sort_by=arxiv.SortCriterion.SubmittedDate
+     )
+     results = list(client.results(search))
+     papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]
+     logging.info(f"Fetched {len(papers)} papers from arXiv")
+     return papers

+ def build_faiss_index(papers, dataset_dir="/data/rag_dataset"):
      """Build and save dataset with FAISS index for RAG"""
+     # Initialize DPR encoder
+     ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+     ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+
+     # Create embeddings
+     texts = [p["text"] for p in papers]
+     embeddings = []
+     batch_size = 8
+     for i in range(0, len(texts), batch_size):
+         batch = texts[i:i + batch_size]
+         inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
+         with torch.no_grad():
+             outputs = ctx_encoder(**inputs)
+         batch_embeddings = outputs.pooler_output.cpu().numpy()
+         embeddings.append(batch_embeddings)
+
+     embeddings = np.vstack(embeddings)
+     logging.info(f"Created embeddings with shape {embeddings.shape}")
+
+     # Create dataset
+     dataset = Dataset.from_dict({
+         "id": [p["id"] for p in papers],
+         "text": [p["text"] for p in papers],
+         "title": [p["title"] for p in papers],
+         "embeddings": [emb.tolist() for emb in embeddings],
+     })
+
+     # Create FAISS index
+     dimension = embeddings.shape[1]
+     index = faiss.IndexFlatL2(dimension)
+     index.add(embeddings.astype(np.float32))
+
+     # Save dataset and index
+     os.makedirs(dataset_dir, exist_ok=True)
+     dataset_path = os.path.join(dataset_dir, "dataset")
+     index_path = os.path.join(dataset_dir, "embeddings.faiss")
+     dataset.save_to_disk(dataset_path)
+     faiss.write_index(index, index_path)
+     logging.info(f"Saved dataset to {dataset_path}")
+     logging.info(f"Saved index to {index_path}")
+     return dataset_dir

  # Example usage
  if __name__ == "__main__":
+     query = "autism research"
+     papers = fetch_arxiv_papers(query, max_results=100)
      build_faiss_index(papers)
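Aside: assuming build_faiss_index() has written /data/rag_dataset as in the diff above, the saved artifacts can be sanity-checked by embedding a query with the matching DPR question encoder and searching the FAISS index directly. This is only an illustrative sketch (the query text and k value are arbitrary), not part of the commit:

    import faiss
    import numpy as np
    import torch
    from datasets import load_from_disk
    from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

    # Load the dataset and index written by build_faiss_index()
    dataset = load_from_disk("/data/rag_dataset/dataset")
    index = faiss.read_index("/data/rag_dataset/embeddings.faiss")

    # Question-side DPR model paired with the context encoder used above
    q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    inputs = q_tokenizer("early signs of autism", return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        query_emb = q_encoder(**inputs).pooler_output.cpu().numpy()

    # Nearest passages under the IndexFlatL2 built above
    distances, ids = index.search(query_emb.astype(np.float32), 3)
    for rank, doc_id in enumerate(ids[0], start=1):
        print(rank, dataset[int(doc_id)]["title"])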