Spaces:
Sleeping
Sleeping
fix: rag
Browse files- app.py +81 -70
- faiss_index/index.py +80 -52
app.py
CHANGED
@@ -2,15 +2,13 @@ import streamlit as st
|
|
2 |
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
|
3 |
import faiss
|
4 |
import os
|
5 |
-
from datasets import load_from_disk
|
6 |
import torch
|
7 |
import logging
|
8 |
-
import
|
9 |
-
from pathlib import Path
|
10 |
|
11 |
# Configure logging
|
12 |
-
logging.basicConfig(level=logging.
|
13 |
-
warnings.filterwarnings('ignore')
|
14 |
|
15 |
# Title
|
16 |
st.title("🧩 AMA Austim")
|
@@ -18,92 +16,105 @@ st.title("🧩 AMA Austim")
|
|
18 |
# Input: Query
|
19 |
query = st.text_input("Please ask me anything about autism ✨")
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
if _dataset_path and _index_path:
|
33 |
-
retriever_config.update({
|
34 |
-
"passages_path": _dataset_path,
|
35 |
-
"index_path": _index_path
|
36 |
-
})
|
37 |
-
|
38 |
-
retriever = RagRetriever.from_pretrained(model_name, **retriever_config)
|
39 |
-
model = RagSequenceForGeneration.from_pretrained(model_name)
|
40 |
-
return tokenizer, retriever, model
|
41 |
|
42 |
# Load or create RAG dataset
|
43 |
def load_rag_dataset(dataset_dir="rag_dataset"):
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
|
59 |
# RAG Pipeline
|
60 |
-
def rag_pipeline(query, dataset, index
|
61 |
try:
|
62 |
-
#
|
63 |
-
|
|
|
64 |
|
65 |
-
#
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
# Generate answer
|
71 |
inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
|
72 |
with torch.no_grad():
|
73 |
-
|
74 |
inputs["input_ids"],
|
75 |
max_length=200,
|
76 |
min_length=50,
|
77 |
num_beams=5,
|
78 |
-
early_stopping=True
|
|
|
79 |
)
|
80 |
-
answer = tokenizer.batch_decode(
|
81 |
|
82 |
return answer
|
83 |
-
|
84 |
except Exception as e:
|
85 |
-
st.error(f"
|
86 |
return None
|
87 |
|
88 |
# Run the app
|
89 |
if query:
|
90 |
with st.status("Looking for data in the best sources...", expanded=True) as status:
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
|
3 |
import faiss
|
4 |
import os
|
5 |
+
from datasets import load_from_disk, Dataset
|
6 |
import torch
|
7 |
import logging
|
8 |
+
import traceback
|
|
|
9 |
|
10 |
# Configure logging
logging.basicConfig(level=logging.INFO)

# Title (the previous text read "Austim", a typo for "Autism")
st.title("🧩 AMA Autism")

# Input: Query
query = st.text_input("Please ask me anything about autism ✨")
|
18 |
|
19 |
+
def prepare_rag_passages(dataset):
    """Convert dataset rows into a list of passage dicts.

    Args:
        dataset: Iterable of row mappings, each providing the keys
            ``"id"``, ``"text"`` and ``"title"``.

    Returns:
        A list with one dict per row containing:
        ``"id"`` (the row position as a string), ``"text"``, ``"title"``
        and ``"document_id"`` (the row's own ``"id"`` converted with ``int``).

    Raises:
        KeyError: If a row lacks one of the required keys.
        ValueError: If a row's ``"id"`` cannot be converted to ``int``.
    """
    return [
        {
            "id": str(position),
            "text": row["text"],
            "title": row["title"],
            "document_id": int(row["id"]),
        }
        for position, row in enumerate(dataset)
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
# Load or create RAG dataset
|
32 |
def load_rag_dataset(dataset_dir="rag_dataset"):
    """Load the saved passages dataset and FAISS index, building them if missing.

    Args:
        dataset_dir: Directory expected to contain a ``dataset`` folder and an
            ``embeddings.faiss`` file.

    Returns:
        ``(dataset, index)`` on success. On any failure the error and its
        traceback are shown with ``st.error`` and ``(None, None)`` is returned.
    """
    try:
        dataset_path = os.path.join(dataset_dir, "dataset")
        index_path = os.path.join(dataset_dir, "embeddings.faiss")

        # Rebuild when either artifact is missing, not only when the directory
        # itself is absent; otherwise a half-written directory would make every
        # later load fail without ever triggering a rebuild.
        if not (os.path.isdir(dataset_path) and os.path.isfile(index_path)):
            with st.spinner("Building initial dataset from autism research papers..."):
                import faiss_index.index as faiss_index_index

                initial_papers = faiss_index_index.fetch_arxiv_papers(
                    "autism research", max_results=100
                )
                dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
            # build_faiss_index returns the directory it wrote to; re-derive paths.
            dataset_path = os.path.join(dataset_dir, "dataset")
            index_path = os.path.join(dataset_dir, "embeddings.faiss")

        # Load the dataset and index
        dataset = load_from_disk(dataset_path)
        index = faiss.read_index(index_path)
        return dataset, index
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
        return None, None
|
47 |
|
48 |
# RAG Pipeline
|
49 |
+
def _attach_faiss_index(dataset, index):
    """Return ``dataset`` carrying an "embeddings" column and FAISS index.

    The RAG retriever's custom-index mode expects a dataset with "title",
    "text" and "embeddings" columns plus a FAISS index named "embeddings".
    """
    if "embeddings" not in dataset.column_names:
        # The vectors only live in the FAISS index; read them back so they can
        # be stored as a column.
        # NOTE(review): reconstruct_n is available on flat indexes such as the
        # IndexFlatL2 this project builds; confirm if the index type changes.
        vectors = index.reconstruct_n(0, index.ntotal)
        dataset = dataset.add_column("embeddings", vectors.tolist())

    if "embeddings" not in dataset.list_indexes():
        # A fresh index is built from the column (same metric as the one passed
        # in). Re-using `index` via custom_index would append every vector a
        # second time, because add_faiss_index always adds the column's vectors.
        dataset.add_faiss_index(column="embeddings", metric_type=index.metric_type)

    return dataset


def rag_pipeline(query, dataset, index):
    """Answer ``query`` with the facebook/rag-sequence-nq model over ``dataset``.

    Args:
        query: The user's question.
        dataset: ``datasets.Dataset`` of passages with "title" and "text"
            columns (and optionally "embeddings").
        index: FAISS index holding one vector per dataset row.

    Returns:
        The generated answer string, or ``None`` if anything fails (the error
        and traceback are shown with ``st.error``).
    """
    try:
        # Initialize components
        model_name = "facebook/rag-sequence-nq"
        tokenizer = RagTokenizer.from_pretrained(model_name)

        # RagRetriever has no `passages=` / `index=` arguments; a custom
        # knowledge base is supplied as an indexed dataset instead.
        indexed_dataset = _attach_faiss_index(dataset, index)
        retriever = RagRetriever.from_pretrained(
            model_name,
            indexed_dataset=indexed_dataset,
        )

        # Initialize model with retriever
        model = RagSequenceForGeneration.from_pretrained(
            model_name,
            retriever=retriever,
        )

        # Generate answer
        inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                max_length=200,
                min_length=50,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=3,
            )
        answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

        return answer
    except Exception as e:
        st.error(f"Error generating answer: {str(e)}\n{traceback.format_exc()}")
        return None
|
90 |
|
91 |
# Run the app
if query:
    answer = None
    dataset = None

    # The status box only reports progress; results are rendered after it,
    # because it is collapsed on success and would hide anything written inside.
    with st.status("Looking for data in the best sources...", expanded=True) as status:
        try:
            st.write("Still looking... this may take a while as we look at some prestigious papers...")
            dataset, index = load_rag_dataset()

            if dataset is None or index is None:
                st.error("Failed to load or create the dataset.")
                status.update(label="Error loading data", state="error")
            else:
                st.write("Found the best sources!")
                st.write("Now answering your question...")
                answer = rag_pipeline(query, dataset, index)

                if answer:
                    status.update(label="Search complete!", state="complete", expanded=False)
                else:
                    status.update(label="Error generating answer", state="error")

        except Exception as e:
            st.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
            status.update(label="Error", state="error")

    if answer:
        st.write("### Answer:")
        # `answer` is a plain string; st.write_stream only accepts generators
        # or stream-like objects and rejects str, so use st.write.
        st.write(answer)
        st.write("### Retrieved Papers:")
        # NOTE(review): these are the first five rows of the dataset, not the
        # passages the retriever actually selected for this query.
        for i in range(min(5, len(dataset))):
            paper = dataset[i]
            st.write(f"**Title:** {paper['title']}")
            st.write(f"**Summary:** {paper['text'][:200]}...")
            st.write("---")
|
faiss_index/index.py
CHANGED
@@ -5,63 +5,91 @@ from datasets import Dataset
|
|
5 |
import os
|
6 |
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
|
7 |
import torch
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
# Fetch arXiv papers
|
10 |
def fetch_arxiv_papers(query, max_results=10):
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
# Build and save dataset with FAISS index
|
22 |
def build_faiss_index(papers, dataset_dir="rag_dataset"):
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
|
34 |
-
|
35 |
-
# Create embeddings
|
36 |
-
embeddings = []
|
37 |
-
batch_size = 8
|
38 |
-
|
39 |
-
for i in range(0, len(dataset), batch_size):
|
40 |
-
batch = dataset[i:i + batch_size]["text"]
|
41 |
-
inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
|
42 |
-
with torch.no_grad():
|
43 |
-
outputs = ctx_encoder(**inputs)
|
44 |
-
batch_embeddings = outputs.pooler_output.cpu().numpy()
|
45 |
-
embeddings.append(batch_embeddings)
|
46 |
-
|
47 |
-
embeddings = np.vstack(embeddings)
|
48 |
-
|
49 |
-
# Create FAISS index
|
50 |
-
dimension = embeddings.shape[1] # Should be 768 for DPR
|
51 |
-
index = faiss.IndexFlatL2(dimension)
|
52 |
-
index.add(embeddings.astype(np.float32))
|
53 |
-
|
54 |
-
# Save dataset and index
|
55 |
-
os.makedirs(dataset_dir, exist_ok=True)
|
56 |
-
|
57 |
-
# Save dataset with embeddings
|
58 |
-
dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
|
59 |
-
dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
|
60 |
-
|
61 |
-
# Save FAISS index
|
62 |
-
faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))
|
63 |
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
# Example usage
|
67 |
if __name__ == "__main__":
|
|
|
5 |
import os
|
6 |
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
|
7 |
import torch
|
8 |
+
import logging
|
9 |
+
|
10 |
+
# Configure logging
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
|
|
|
13 |
def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG.

    Args:
        query: Search string passed to ``arxiv.Search``.
        max_results: Maximum number of results requested from arXiv.

    Returns:
        A list of dicts, one per result in the order returned, with keys
        ``"id"`` (position as a string), ``"text"`` (the paper summary, used
        as the content to embed) and ``"title"``.

    Raises:
        Exception: Any error from the arXiv client is logged with its
            traceback and re-raised unchanged.
    """
    try:
        client = arxiv.Client()
        search = arxiv.Search(
            query=query,
            max_results=max_results,
            sort_by=arxiv.SortCriterion.SubmittedDate,
        )

        papers = [
            {
                "id": str(position),      # Unique identifier
                "text": result.summary,   # Main content for embedding
                "title": result.title,    # Title for display
            }
            for position, result in enumerate(client.results(search))
        ]

        logging.info("Fetched %d papers from arXiv", len(papers))
        return papers

    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error fetching papers: %s", e)
        raise
|
38 |
|
|
|
39 |
def build_faiss_index(papers, dataset_dir="rag_dataset"):
    """Build and save a passages dataset plus FAISS index for RAG.

    Args:
        papers: List of dicts with ``"id"``, ``"text"`` and ``"title"`` keys.
        dataset_dir: Output directory; receives a ``dataset`` folder and an
            ``embeddings.faiss`` file.

    Returns:
        ``dataset_dir``, the directory the artifacts were written to.

    Raises:
        ValueError: If ``papers`` is empty.
        Exception: Any other failure is logged with its traceback and
            re-raised unchanged.
    """
    try:
        # Fail before loading the DPR models; otherwise np.vstack raises a much
        # less helpful ValueError on the empty embeddings list later on.
        if not papers:
            raise ValueError("Cannot build a FAISS index from an empty list of papers")

        # Create dataset with required fields for RAG
        dataset = Dataset.from_dict({
            "id": [p["id"] for p in papers],
            "text": [p["text"] for p in papers],    # Main content field
            "title": [p["title"] for p in papers],  # Additional metadata
        })

        logging.info(f"Created dataset with {len(dataset)} papers")

        # Initialize DPR encoder (same as used by RAG)
        ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

        # Create embeddings in batches
        batch_vectors = []
        batch_size = 8

        for start in range(0, len(dataset), batch_size):
            texts = dataset[start:start + batch_size]["text"]
            inputs = ctx_tokenizer(
                texts,
                max_length=512,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )

            with torch.no_grad():
                outputs = ctx_encoder(**inputs)
            batch_vectors.append(outputs.pooler_output.cpu().numpy())

        # FAISS expects float32 vectors.
        embeddings = np.vstack(batch_vectors).astype(np.float32)
        logging.info(f"Created embeddings with shape {embeddings.shape}")

        # Create FAISS index (L2 distance)
        dimension = embeddings.shape[1]  # Should be 768 for DPR
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)

        # Save dataset and index
        os.makedirs(dataset_dir, exist_ok=True)
        # Keep the vectors next to the passages: the RAG custom index reads an
        # "embeddings" column from the dataset in addition to the FAISS index.
        dataset = dataset.add_column("embeddings", embeddings.tolist())
        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
        faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))

        logging.info(f"Saved dataset and index to {dataset_dir}")
        return dataset_dir

    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error building index: %s", e)
        raise
|
93 |
|
94 |
# Example usage
|
95 |
if __name__ == "__main__":
|