wakeupmh committed
Commit f68ac31 · Parent(s): 6f43c31
Files changed (2)
  1. app.py +45 -109
  2. faiss_index/index.py +54 -86
app.py CHANGED
@@ -2,124 +2,60 @@ import streamlit as st
 from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 import faiss
 import os
-from datasets import load_from_disk, Dataset
+from datasets import load_from_disk
 import torch
 import logging
-import traceback
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 
-# Title
-st.title("🧩 AMA Austim")
-
-# Input: Query
-query = st.text_input("Please ask me anything about autism ✨")
-
-def prepare_rag_passages(dataset):
-    """Convert dataset to the format expected by RAG"""
-    return [
-        {
-            "id": str(i),
-            "text": row["text"],
-            "title": row["title"],
-            "document_id": int(row["id"])
-        }
-        for i, row in enumerate(dataset)
-    ]
-
-# Load or create RAG dataset
-def load_rag_dataset(dataset_dir="rag_dataset"):
-    try:
-        if not os.path.exists(dataset_dir):
-            with st.spinner("Building initial dataset from autism research papers..."):
-                import faiss_index.index as faiss_index_index
-                initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
-                dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)
-
-        # Load the dataset and index
-        dataset_path = os.path.join(dataset_dir, "dataset")
-        index_path = os.path.join(dataset_dir, "embeddings.faiss")
-
-        if not os.path.exists(dataset_path) or not os.path.exists(index_path):
-            raise ValueError("Dataset or index not found")
-
-        dataset = load_from_disk(dataset_path)
-        index = faiss.read_index(index_path)
-
-        logging.info("Successfully loaded dataset and index")
-        return dataset, dataset_path, index_path
-    except Exception as e:
-        st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
-        return None, None, None
+# Cache models and dataset
+@st.cache_resource  # Cache models in memory
+def load_models():
+    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+    retriever = RagRetriever.from_pretrained(
+        "facebook/rag-sequence-nq",
+        index_name="custom",
+        passages_path="/data/rag_dataset/dataset",
+        index_path="/data/rag_dataset/embeddings.faiss"
+    )
+    model = RagSequenceForGeneration.from_pretrained(
+        "facebook/rag-sequence-nq",
+        retriever=retriever,
+        device_map="auto"
+    )
+    return tokenizer, retriever, model
+
+@st.cache_data  # Cache dataset on disk
+def load_dataset():
+    return load_from_disk("/data/rag_dataset/dataset")
 
 # RAG Pipeline
-def rag_pipeline(query, dataset, dataset_path, index_path):
-    try:
-        # Initialize components
-        model_name = "facebook/rag-sequence-nq"
-        tokenizer = RagTokenizer.from_pretrained(model_name)
-
-        # Initialize retriever with correct paths
-        retriever = RagRetriever.from_pretrained(
-            model_name,
-            index_name="custom",
-            passages_path=dataset_path,
-            index_path=index_path
-        )
-
-        # Initialize model with retriever
-        model = RagSequenceForGeneration.from_pretrained(
-            model_name,
-            retriever=retriever,
-            use_auth_token=False
-        )
-
-        # Generate answer
-        inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
-        with torch.no_grad():
-            outputs = model.generate(
-                inputs["input_ids"],
-                max_length=200,
-                min_length=50,
-                num_beams=5,
-                early_stopping=True,
-                no_repeat_ngram_size=3
-            )
-        answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-
-        return answer
-    except Exception as e:
-        st.error(f"Error generating answer: {str(e)}\n{traceback.format_exc()}")
-        return None
+def rag_pipeline(query, dataset, index):
+    tokenizer, retriever, model = load_models()
+    inputs = tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
+    with torch.no_grad():
+        outputs = model.generate(
+            inputs["input_ids"],
+            max_length=200,
+            min_length=50,
+            num_beams=5,
+            early_stopping=True,
+            no_repeat_ngram_size=3
+        )
+    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+    return answer
 
-# Run the app
+# Streamlit App
+st.title("🧩 AMA Autism")
+query = st.text_input("Please ask me anything about autism ✨")
+
 if query:
-    with st.status("Looking for data in the best sources...", expanded=True) as status:
-        try:
-            st.write("Still looking... this may take a while as we look at some prestigious papers...")
-            dataset, dataset_path, index_path = load_rag_dataset()
-
-            if dataset is None:
-                st.error("Failed to load or create the dataset.")
-                status.update(label="Error loading data", state="error")
-            else:
-                st.write("Found the best sources!")
-                st.write("Now answering your question...")
-                answer = rag_pipeline(query, dataset, dataset_path, index_path)
-
-                if answer:
-                    status.update(label="Search complete!", state="complete", expanded=False)
-                    st.write("### Answer:")
-                    st.write_stream(answer)
-                    st.write("### Retrieved Papers:")
-                    for i in range(min(5, len(dataset))):
-                        st.write(f"**Title:** {dataset[i]['title']}")
-                        st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
-                        st.write("---")
-                else:
-                    status.update(label="Error generating answer", state="error")
-
-        except Exception as e:
-            st.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
-            status.update(label="Error", state="error")
+    with st.status("Searching for answers..."):
+        dataset = load_dataset()
+        answer = rag_pipeline(query, dataset, index=None)
+        if answer:
+            st.success("Answer found!")
+            st.write(answer)
+        else:
+            st.error("Failed to generate an answer.")
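Note on the rewrite above: app.py now builds the tokenizer, retriever, and model once per process via Streamlit's caching decorators instead of reconstructing them on every rerun. A minimal, self-contained sketch of that caching pattern (the slow_load_* names are illustrative, not part of this commit):

import time
import streamlit as st

@st.cache_resource  # one shared instance per process; suited to unpicklable objects such as models
def slow_load_model():
    time.sleep(2)  # stand-in for loading a large model from disk
    return {"name": "demo-model"}

@st.cache_data  # memoized per input arguments; suited to picklable data
def slow_load_rows(n: int):
    time.sleep(1)  # stand-in for reading a dataset
    return list(range(n))

model = slow_load_model()   # slow only on the first run
rows = slow_load_rows(10)   # later reruns hit the cache
st.write(model["name"], len(rows))

Reruns triggered by widget interaction then skip both loads, which is what makes calling rag_pipeline() on every query affordable.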
faiss_index/index.py CHANGED
@@ -12,95 +12,63 @@ logging.basicConfig(level=logging.INFO)
 
 def fetch_arxiv_papers(query, max_results=10):
     """Fetch papers from arXiv and format them for RAG"""
-    try:
-        client = arxiv.Client()
-        search = arxiv.Search(
-            query=query,
-            max_results=max_results,
-            sort_by=arxiv.SortCriterion.SubmittedDate
-        )
-        results = list(client.results(search))
-        papers = []
-
-        for i, result in enumerate(results):
-            papers.append({
-                "id": str(i),
-                "text": result.summary,
-                "title": result.title,
-            })
-
-        logging.info(f"Fetched {len(papers)} papers from arXiv")
-        return papers
-
-    except Exception as e:
-        logging.error(f"Error fetching papers: {str(e)}")
-        raise
+    client = arxiv.Client()
+    search = arxiv.Search(
+        query=query,
+        max_results=max_results,
+        sort_by=arxiv.SortCriterion.SubmittedDate
+    )
+    results = list(client.results(search))
+    papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]
+    logging.info(f"Fetched {len(papers)} papers from arXiv")
+    return papers
 
-def build_faiss_index(papers, dataset_dir="rag_dataset"):
+def build_faiss_index(papers, dataset_dir="/data/rag_dataset"):
     """Build and save dataset with FAISS index for RAG"""
-    try:
-        # Initialize DPR encoder
-        ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
-        ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
-
-        # Create embeddings in batches
-        texts = [p["text"] for p in papers]
-        embeddings = []
-        batch_size = 8
-
-        for i in range(0, len(texts), batch_size):
-            batch = texts[i:i + batch_size]
-            inputs = ctx_tokenizer(
-                batch,
-                max_length=512,
-                padding=True,
-                truncation=True,
-                return_tensors="pt"
-            )
-
-            with torch.no_grad():
-                outputs = ctx_encoder(**inputs)
-                batch_embeddings = outputs.pooler_output.cpu().numpy()
-                embeddings.append(batch_embeddings)
-
-        embeddings = np.vstack(embeddings)
-        logging.info(f"Created embeddings with shape {embeddings.shape}")
-
-        # Create dataset with embeddings
-        dataset = Dataset.from_dict({
-            "id": [p["id"] for p in papers],
-            "text": [p["text"] for p in papers],
-            "title": [p["title"] for p in papers],
-            "embeddings": [emb.tolist() for emb in embeddings],
-        })
-
-        logging.info(f"Created dataset with {len(dataset)} papers")
-
-        # Create FAISS index from embeddings
-        dimension = embeddings.shape[1]
-        index = faiss.IndexFlatL2(dimension)
-        index.add(embeddings.astype(np.float32))
-
-        # Save everything
-        os.makedirs(dataset_dir, exist_ok=True)
-        dataset_path = os.path.join(dataset_dir, "dataset")
-        index_path = os.path.join(dataset_dir, "embeddings.faiss")
-
-        # Save dataset and index
-        dataset.save_to_disk(dataset_path)
-        faiss.write_index(index, index_path)
-
-        logging.info(f"Saved dataset to {dataset_path}")
-        logging.info(f"Saved index to {index_path}")
-
-        return dataset_dir
-
-    except Exception as e:
-        logging.error(f"Error building index: {str(e)}")
-        raise
+    # Initialize DPR encoder
+    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+
+    # Create embeddings
+    texts = [p["text"] for p in papers]
+    embeddings = []
+    batch_size = 8
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i:i + batch_size]
+        inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
+        with torch.no_grad():
+            outputs = ctx_encoder(**inputs)
+            batch_embeddings = outputs.pooler_output.cpu().numpy()
+            embeddings.append(batch_embeddings)
+
+    embeddings = np.vstack(embeddings)
+    logging.info(f"Created embeddings with shape {embeddings.shape}")
+
+    # Create dataset
+    dataset = Dataset.from_dict({
+        "id": [p["id"] for p in papers],
+        "text": [p["text"] for p in papers],
+        "title": [p["title"] for p in papers],
+        "embeddings": [emb.tolist() for emb in embeddings],
+    })
+
+    # Create FAISS index
+    dimension = embeddings.shape[1]
+    index = faiss.IndexFlatL2(dimension)
+    index.add(embeddings.astype(np.float32))
+
+    # Save dataset and index
+    os.makedirs(dataset_dir, exist_ok=True)
+    dataset_path = os.path.join(dataset_dir, "dataset")
+    index_path = os.path.join(dataset_dir, "embeddings.faiss")
+    dataset.save_to_disk(dataset_path)
+    faiss.write_index(index, index_path)
+    logging.info(f"Saved dataset to {dataset_path}")
+    logging.info(f"Saved index to {index_path}")
+    return dataset_dir
 
 # Example usage
 if __name__ == "__main__":
-    query = "quantum computing"
-    papers = fetch_arxiv_papers(query)
+    query = "autism research"
+    papers = fetch_arxiv_papers(query, max_results=100)
    build_faiss_index(papers)
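
Since build_faiss_index() writes a plain IndexFlatL2 over DPR context embeddings, a quick way to sanity-check the saved artifacts is to embed a question with the matching DPR question encoder and search the index directly. A rough sketch, assuming the default /data/rag_dataset paths above (the query string is arbitrary):

import faiss
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Load the index saved by build_faiss_index()
index = faiss.read_index("/data/rag_dataset/embeddings.faiss")
encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

inputs = tokenizer("early intervention outcomes in autism", return_tensors="pt")
with torch.no_grad():
    query_emb = encoder(**inputs).pooler_output.numpy()  # (1, 768), same dimension as the context encoder
distances, ids = index.search(query_emb.astype("float32"), 5)  # 5 nearest passages by L2 distance
print(ids[0])  # row indices into the dataset saved alongside the index

One caveat on the design: DPR encoders are trained for inner-product similarity, so an IndexFlatIP would arguably match the training objective better than IndexFlatL2; the sketch simply queries the index as the commit builds it.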