wakeupmh committed on
Commit 8108db5 · 1 Parent(s): 975c327
Files changed (2)
  1. app.py +81 -70
  2. faiss_index/index.py +80 -52
app.py CHANGED
@@ -2,15 +2,13 @@ import streamlit as st
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
  import faiss
  import os
- from datasets import load_from_disk
+ from datasets import load_from_disk, Dataset
  import torch
  import logging
- import warnings
- from pathlib import Path
+ import traceback

  # Configure logging
- logging.basicConfig(level=logging.WARNING)
- warnings.filterwarnings('ignore')
+ logging.basicConfig(level=logging.INFO)

  # Title
  st.title("🧩 AMA Austim")
@@ -18,92 +16,105 @@ st.title("🧩 AMA Austim")
  # Input: Query
  query = st.text_input("Please ask me anything about autism ✨")

- @st.cache_resource
- def load_rag_components(_dataset_path=None, _index_path=None):
-     """Load and cache RAG components to avoid reloading."""
-     model_name = "facebook/rag-sequence-nq"
-     tokenizer = RagTokenizer.from_pretrained(model_name)
-
-     retriever_config = {
-         "index_name": "custom",
-         "use_dummy_dataset": True
-     }
-
-     if _dataset_path and _index_path:
-         retriever_config.update({
-             "passages_path": _dataset_path,
-             "index_path": _index_path
-         })
-
-     retriever = RagRetriever.from_pretrained(model_name, **retriever_config)
-     model = RagSequenceForGeneration.from_pretrained(model_name)
-     return tokenizer, retriever, model
+ def prepare_rag_passages(dataset):
+     """Convert dataset to the format expected by RAG"""
+     return [
+         {
+             "id": str(i),
+             "text": row["text"],
+             "title": row["title"],
+             "document_id": int(row["id"])
+         }
+         for i, row in enumerate(dataset)
+     ]

  # Load or create RAG dataset
  def load_rag_dataset(dataset_dir="rag_dataset"):
-     if not os.path.exists(dataset_dir):
-         with st.spinner("Building initial dataset from autism research papers..."):
-             import faiss_index.index as faiss_index_index
-             initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
-             dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
-
-     dataset_path = os.path.join(dataset_dir, "dataset")
-     index_path = os.path.join(dataset_dir, "embeddings.faiss")
-
-     # Load the dataset and index
-     dataset = load_from_disk(dataset_path)
-     index = faiss.read_index(index_path)
-
-     return dataset, index, dataset_path, index_path
+     try:
+         if not os.path.exists(dataset_dir):
+             with st.spinner("Building initial dataset from autism research papers..."):
+                 import faiss_index.index as faiss_index_index
+                 initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
+                 dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
+
+         # Load the dataset and index
+         dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
+         index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
+         return dataset, index
+     except Exception as e:
+         st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
+         return None, None

  # RAG Pipeline
- def rag_pipeline(query, dataset, index, dataset_path, index_path):
+ def rag_pipeline(query, dataset, index):
      try:
-         # Load cached components with paths
-         tokenizer, retriever, model = load_rag_components(dataset_path, index_path)
+         # Initialize components
+         model_name = "facebook/rag-sequence-nq"
+         tokenizer = RagTokenizer.from_pretrained(model_name)

-         # Configure retriever with our dataset
-         retriever.index.dataset = dataset
-         retriever.index.index = index
-         model.retriever = retriever
+         # Convert dataset to passages format
+         passages = prepare_rag_passages(dataset)
+
+         # Initialize retriever with passages
+         retriever = RagRetriever.from_pretrained(
+             model_name,
+             index_name="custom",
+             passages=passages,
+             index=index
+         )
+
+         # Initialize model with retriever
+         model = RagSequenceForGeneration.from_pretrained(
+             model_name,
+             retriever=retriever,
+             use_auth_token=False
+         )

          # Generate answer
          inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
          with torch.no_grad():
-             generated_ids = model.generate(
+             outputs = model.generate(
                  inputs["input_ids"],
                  max_length=200,
                  min_length=50,
                  num_beams=5,
-                 early_stopping=True
+                 early_stopping=True,
+                 no_repeat_ngram_size=3
              )
-         answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+         answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

          return answer
-
      except Exception as e:
-         st.error(f"An error occurred while processing your query: {str(e)}")
+         st.error(f"Error generating answer: {str(e)}\n{traceback.format_exc()}")
          return None

  # Run the app
  if query:
      with st.status("Looking for data in the best sources...", expanded=True) as status:
-         st.write("Still looking... this may take a while as we look at some prestigious papers...")
-         dataset, index, dataset_path, index_path = load_rag_dataset()
-         st.write("Found the best sources!")
-         answer = rag_pipeline(query, dataset, index, dataset_path, index_path)
-         st.write("Now answering your question...")
-         status.update(
-             label="Searching complete!",
-             state="complete",
-             expanded=False
-         )
-
-         if answer:
-             st.write("### Answer:")
-             st.write_stream(answer)
-             st.write("### Retrieved Papers:")
-             for i in range(min(5, len(dataset))):
-                 st.write(f"**Title:** {dataset[i]['title']}")
-                 st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
-                 st.write("---")
+         try:
+             st.write("Still looking... this may take a while as we look at some prestigious papers...")
+             dataset, index = load_rag_dataset()
+
+             if dataset is None or index is None:
+                 st.error("Failed to load or create the dataset.")
+                 status.update(label="Error loading data", state="error")
+             else:
+                 st.write("Found the best sources!")
+                 st.write("Now answering your question...")
+                 answer = rag_pipeline(query, dataset, index)
+
+                 if answer:
+                     status.update(label="Search complete!", state="complete", expanded=False)
+                     st.write("### Answer:")
+                     st.write_stream(answer)
+                     st.write("### Retrieved Papers:")
+                     for i in range(min(5, len(dataset))):
+                         st.write(f"**Title:** {dataset[i]['title']}")
+                         st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
+                         st.write("---")
+                 else:
+                     status.update(label="Error generating answer", state="error")
+
+         except Exception as e:
+             st.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
+             status.update(label="Error", state="error")
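For reference, a minimal sketch (not part of the commit) of wiring the generator to the on-disk dataset and FAISS index through the documented custom-index arguments (passages_path / index_path), which is the route the removed load_rag_components() took. The paths assume the rag_dataset/ layout produced by build_faiss_index(); note that this loading path expects the saved dataset to carry an "embeddings" column, which the pre-commit build step added via add_column.

```python
# Sketch only: load a custom RAG index from disk instead of passing in-memory objects.
import os
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

dataset_dir = "rag_dataset"  # created by faiss_index.index.build_faiss_index
model_name = "facebook/rag-sequence-nq"

tokenizer = RagTokenizer.from_pretrained(model_name)
retriever = RagRetriever.from_pretrained(
    model_name,
    index_name="custom",
    passages_path=os.path.join(dataset_dir, "dataset"),          # saved datasets.Dataset
    index_path=os.path.join(dataset_dir, "embeddings.faiss"),    # saved FAISS index
)
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

inputs = tokenizer("What are common early signs of autism?", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(inputs["input_ids"], max_length=200, num_beams=5)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```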
faiss_index/index.py CHANGED
@@ -5,63 +5,91 @@ from datasets import Dataset
  import os
  from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
  import torch
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)

- # Fetch arXiv papers
  def fetch_arxiv_papers(query, max_results=10):
-     client = arxiv.Client()
-     search = arxiv.Search(
-         query=query,
-         max_results=max_results,
-         sort_by=arxiv.SortCriterion.SubmittedDate
-     )
-     results = list(client.results(search))
-     papers = [{"title": result.title, "text": result.summary, "id": str(i)} for i, result in enumerate(results)]
-     return papers
+     """Fetch papers from arXiv and format them for RAG"""
+     try:
+         client = arxiv.Client()
+         search = arxiv.Search(
+             query=query,
+             max_results=max_results,
+             sort_by=arxiv.SortCriterion.SubmittedDate
+         )
+         results = list(client.results(search))
+         papers = []
+
+         for i, result in enumerate(results):
+             papers.append({
+                 "id": str(i),  # Unique identifier
+                 "text": result.summary,  # Main content for embedding
+                 "title": result.title,  # Title for display
+             })
+
+         logging.info(f"Fetched {len(papers)} papers from arXiv")
+         return papers
+
+     except Exception as e:
+         logging.error(f"Error fetching papers: {str(e)}")
+         raise

- # Build and save dataset with FAISS index
  def build_faiss_index(papers, dataset_dir="rag_dataset"):
-     # Create dataset with required columns for RAG
-     dataset = Dataset.from_dict({
-         "id": [p["id"] for p in papers],
-         "text": [p["text"] for p in papers],  # RAG expects 'text' field
-         "title": [p["title"] for p in papers],
-         "embeddings": None,  # Will be filled later
-     })
-
-     # Initialize DPR context encoder (same as used by RAG)
-     ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
-     ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
-
-     # Create embeddings
-     embeddings = []
-     batch_size = 8
-
-     for i in range(0, len(dataset), batch_size):
-         batch = dataset[i:i + batch_size]["text"]
-         inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
-         with torch.no_grad():
-             outputs = ctx_encoder(**inputs)
-             batch_embeddings = outputs.pooler_output.cpu().numpy()
-         embeddings.append(batch_embeddings)
-
-     embeddings = np.vstack(embeddings)
-
-     # Create FAISS index
-     dimension = embeddings.shape[1]  # Should be 768 for DPR
-     index = faiss.IndexFlatL2(dimension)
-     index.add(embeddings.astype(np.float32))
-
-     # Save dataset and index
-     os.makedirs(dataset_dir, exist_ok=True)
-
-     # Save dataset with embeddings
-     dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
-     dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
-
-     # Save FAISS index
-     faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))
+     """Build and save dataset with FAISS index for RAG"""
+     try:
+         # Create dataset with required fields for RAG
+         dataset = Dataset.from_dict({
+             "id": [p["id"] for p in papers],
+             "text": [p["text"] for p in papers],  # Main content field
+             "title": [p["title"] for p in papers],  # Additional metadata
+         })
+
+         logging.info(f"Created dataset with {len(dataset)} papers")

-     return dataset_dir
+         # Initialize DPR encoder (same as used by RAG)
+         ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+         ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+
+         # Create embeddings in batches
+         embeddings = []
+         batch_size = 8
+
+         for i in range(0, len(dataset), batch_size):
+             batch = dataset[i:i + batch_size]["text"]
+             inputs = ctx_tokenizer(
+                 batch,
+                 max_length=512,
+                 padding=True,
+                 truncation=True,
+                 return_tensors="pt"
+             )
+
+             with torch.no_grad():
+                 outputs = ctx_encoder(**inputs)
+                 batch_embeddings = outputs.pooler_output.cpu().numpy()
+                 embeddings.append(batch_embeddings)
+
+         embeddings = np.vstack(embeddings)
+         logging.info(f"Created embeddings with shape {embeddings.shape}")
+
+         # Create FAISS index (L2 distance)
+         dimension = embeddings.shape[1]  # Should be 768 for DPR
+         index = faiss.IndexFlatL2(dimension)
+         index.add(embeddings.astype(np.float32))
+
+         # Save dataset and index
+         os.makedirs(dataset_dir, exist_ok=True)
+         dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
+         faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))
+
+         logging.info(f"Saved dataset and index to {dataset_dir}")
+         return dataset_dir
+
+     except Exception as e:
+         logging.error(f"Error building index: {str(e)}")
+         raise

  # Example usage
  if __name__ == "__main__":
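As a usage note (not part of the diff), the artifacts written by build_faiss_index() can be queried directly with the matching DPR question encoder. A minimal sketch, assuming the rag_dataset/ layout above and the standard facebook/dpr-question_encoder-single-nq-base checkpoint; the query string and k are illustrative:

```python
# Sketch only: embed a query with the DPR question encoder and search the saved index.
import os
import faiss
import numpy as np
import torch
from datasets import load_from_disk
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

dataset_dir = "rag_dataset"
dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))

q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

query = "early intervention for autism"
inputs = q_tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
    q_emb = q_encoder(**inputs).pooler_output.cpu().numpy().astype(np.float32)

# IndexFlatL2 returns squared L2 distances; smaller means closer.
distances, ids = index.search(q_emb, 5)
for rank, (dist, i) in enumerate(zip(distances[0], ids[0]), start=1):
    print(f"{rank}. {dataset[int(i)]['title']} (distance={dist:.2f})")
```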