VictorTomas09 committed on
Commit
970694a
·
verified ·
1 Parent(s): ff2408d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -42
app.py CHANGED
@@ -28,20 +28,19 @@ DIST_THRESHOLD = float(os.getenv("DIST_THRESHOLD", 1.0))
28
  MAX_CTX_WORDS = int(os.getenv("MAX_CTX_WORDS", 200))
29
 
30
  DEVICE = 0 if torch.cuda.is_available() else -1
31
-
32
  os.makedirs(DATA_DIR, exist_ok=True)
33
 
34
- print(f"Using MODEL_NAME={MODEL_NAME}, EMBEDDER_MODEL={EMBEDDER_MODEL}, device={'GPU' if DEVICE==0 else 'CPU'}")
35
 
36
  # ── 2. Helpers ──
37
  def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
38
- out = []
39
  for c in contexts:
40
  words = c.split()
41
  if len(words) > max_words:
42
  c = " ".join(words[:max_words]) + " ... [truncated]"
43
- out.append(c)
44
- return out
45
 
46
  def chunk_text(text, max_tokens, stride=None):
47
  words = text.split()
@@ -57,20 +56,25 @@ def chunk_text(text, max_tokens, stride=None):
57
  # ── 3. Load & preprocess passages ──
58
  def load_passages():
59
  # 3.1 load raw corpora
60
- wiki = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")["passage"]
61
- squad = load_dataset("rajpurkar/squad_v2", split="train[:100]")["context"]
62
  trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
63
- trivia = []
 
 
 
64
  for ex in trivia_ds:
65
  for fld in ("wiki_context", "search_context"):
66
  txt = ex.get(fld) or ""
67
- if txt: trivia.append(txt)
 
68
 
69
- all_passages = list(dict.fromkeys(wiki + squad + trivia))
70
- # 3.2 chunk long passages
71
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
72
- max_tokens = tokenizer.model_max_length
73
 
 
 
 
74
  chunks = []
75
  for p in all_passages:
76
  toks = tokenizer.tokenize(p)
@@ -86,20 +90,24 @@ def load_passages():
86
 
87
  # ── 4. Build or load FAISS ──
88
  def load_faiss_index(passages):
89
- # sentence‐transformers embedder + cross‐encoder
90
  embedder = SentenceTransformer(EMBEDDER_MODEL)
91
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
92
 
93
  if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
94
- print("Loading FAISS index & embeddings from disk …")
95
- index = faiss.read_index(INDEX_PATH)
96
  embeddings = np.load(EMB_PATH)
97
  else:
98
- print("Encoding passages & building FAISS index …")
99
- embeddings = embedder.encode(passages, show_progress_bar=True, convert_to_numpy=True, batch_size=32)
 
 
 
 
 
100
  embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
101
 
102
- dim = embeddings.shape[1]
103
  index = faiss.IndexFlatIP(dim)
104
  index.add(embeddings)
105
 
@@ -108,9 +116,8 @@ def load_faiss_index(passages):
108
 
109
  return embedder, reranker, index
110
 
111
- # ── 5. Set up RAG pipeline ──
112
  def setup_rag():
113
- # 5.1 load or build index + embedder/reranker
114
  if os.path.exists(PCTX_PATH):
115
  with open(PCTX_PATH, "rb") as f:
116
  passages = pickle.load(f)
@@ -119,8 +126,7 @@ def setup_rag():
119
 
120
  embedder, reranker, index = load_faiss_index(passages)
121
 
122
- # 5.2 load generator model & HF pipeline
123
- tok = AutoTokenizer.from_pretrained(MODEL_NAME)
124
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
125
  qa_pipe = hf_pipeline(
126
  "text2text-generation",
@@ -129,28 +135,28 @@ def setup_rag():
129
  device=DEVICE,
130
  truncation=True,
131
  max_length=512,
132
- num_beams=4, # optional: enable beam search
133
  early_stopping=True
134
  )
135
 
136
  return passages, embedder, reranker, index, qa_pipe
137
 
138
- # ── 6. Retrieval + Generation ──
139
  def retrieve(question, passages, embedder, index, k=20, rerank_k=5):
140
- q_emb = embedder.encode([question], convert_to_numpy=True)
141
  distances, idxs = index.search(q_emb, k)
142
 
143
- cands = [passages[i] for i in idxs[0]]
144
  scores = reranker.predict([[question, c] for c in cands])
145
- top = np.argsort(scores)[-rerank_k:][::-1]
146
 
147
- final_ctxs = [cands[i] for i in top]
148
- final_dists = [distances[0][i] for i in top]
149
- return final_ctxs, final_dists
150
 
151
  def generate(question, contexts, qa_pipe):
152
- lines = [ f"Context {i+1}: {s}"
153
- for i,s in enumerate(make_context_snippets(contexts)) ]
 
 
154
  prompt = (
155
  "You are a helpful assistant. Use ONLY the following contexts to answer. "
156
  "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
@@ -160,20 +166,18 @@ def generate(question, contexts, qa_pipe):
160
  return qa_pipe(prompt)[0]["generated_text"].strip()
161
 
162
  def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
163
- ctxs, dists = retrieve(question, passages, embedder, index)
164
- if not ctxs or dists[0] > DIST_THRESHOLD:
165
  return "Sorry, I don't know.", []
166
- ans = generate(question, ctxs, qa_pipe)
167
- return ans, ctxs
168
 
169
- def answer_and_contexts(question,
170
- passages, embedder, reranker, index, qa_pipe):
171
  ans, ctxs = retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe)
172
  if not ctxs:
173
  return ans, ""
174
  snippets = [
175
- f"Context {i+1}: {s}"
176
- for i,s in enumerate(make_context_snippets(ctxs))
177
  ]
178
  return ans, "\n\n---\n\n".join(snippets)
179
 
@@ -191,7 +195,8 @@ def main():
191
  "When was Abraham Lincoln inaugurated?",
192
  "What is the capital of France?",
193
  "Who wrote '1984'?"
194
- ]
 
195
  )
196
  demo.launch()
197
 
 
28
  MAX_CTX_WORDS = int(os.getenv("MAX_CTX_WORDS", 200))
29
 
30
  DEVICE = 0 if torch.cuda.is_available() else -1
 
31
  os.makedirs(DATA_DIR, exist_ok=True)
32
 
33
+ print(f"MODEL={MODEL_NAME}, EMBEDDER={EMBEDDER_MODEL}, DEVICE={'GPU' if DEVICE==0 else 'CPU'}")
34
 
35
  # ── 2. Helpers ──
36
  def make_context_snippets(contexts, max_words=MAX_CTX_WORDS):
37
+ snippets = []
38
  for c in contexts:
39
  words = c.split()
40
  if len(words) > max_words:
41
  c = " ".join(words[:max_words]) + " ... [truncated]"
42
+ snippets.append(c)
43
+ return snippets
44
 
45
  def chunk_text(text, max_tokens, stride=None):
46
  words = text.split()
 
56
  # ── 3. Load & preprocess passages ──
57
  def load_passages():
58
  # 3.1 load raw corpora
59
+ wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages")
60
+ squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]")
61
  trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]")
62
+
63
+ wiki_passages = wiki_ds["passage"]
64
+ squad_passages = [ex["context"] for ex in squad_ds]
65
+ trivia_passages = []
66
  for ex in trivia_ds:
67
  for fld in ("wiki_context", "search_context"):
68
  txt = ex.get(fld) or ""
69
+ if txt:
70
+ trivia_passages.append(txt)
71
 
72
+ # dedupe
73
+ all_passages = list(dict.fromkeys(wiki_passages + squad_passages + trivia_passages))
 
 
74
 
75
+ # chunk long passages
76
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
77
+ max_tokens = tokenizer.model_max_length
78
  chunks = []
79
  for p in all_passages:
80
  toks = tokenizer.tokenize(p)
 
90
 
91
  # ── 4. Build or load FAISS ──
92
  def load_faiss_index(passages):
 
93
  embedder = SentenceTransformer(EMBEDDER_MODEL)
94
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
95
 
96
  if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH):
97
+ print("Loading FAISS index & embeddings…")
98
+ index = faiss.read_index(INDEX_PATH)
99
  embeddings = np.load(EMB_PATH)
100
  else:
101
+ print("Encoding passages & building FAISS index…")
102
+ embeddings = embedder.encode(
103
+ passages,
104
+ show_progress_bar=True,
105
+ convert_to_numpy=True,
106
+ batch_size=32
107
+ )
108
  embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
109
 
110
+ dim = embeddings.shape[1]
111
  index = faiss.IndexFlatIP(dim)
112
  index.add(embeddings)
113
 
 
116
 
117
  return embedder, reranker, index
118
 
119
+ # ── 5. Initialize RAG components ──
120
  def setup_rag():
 
121
  if os.path.exists(PCTX_PATH):
122
  with open(PCTX_PATH, "rb") as f:
123
  passages = pickle.load(f)
 
126
 
127
  embedder, reranker, index = load_faiss_index(passages)
128
 
129
+ tok = AutoTokenizer.from_pretrained(MODEL_NAME)
 
130
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
131
  qa_pipe = hf_pipeline(
132
  "text2text-generation",
 
135
  device=DEVICE,
136
  truncation=True,
137
  max_length=512,
138
+ num_beams=4,
139
  early_stopping=True
140
  )
141
 
142
  return passages, embedder, reranker, index, qa_pipe
143
 
144
+ # ── 6. Retrieval & generation ──
145
  def retrieve(question, passages, embedder, index, k=20, rerank_k=5):
146
+ q_emb = embedder.encode([question], convert_to_numpy=True)
147
  distances, idxs = index.search(q_emb, k)
148
 
149
+ cands = [passages[i] for i in idxs[0]]
150
  scores = reranker.predict([[question, c] for c in cands])
151
+ top = np.argsort(scores)[-rerank_k:][::-1]
152
 
153
+ return [cands[i] for i in top], [distances[0][i] for i in top]
 
 
154
 
155
  def generate(question, contexts, qa_pipe):
156
+ lines = [
157
+ f"Context {i+1}: {s}"
158
+ for i, s in enumerate(make_context_snippets(contexts))
159
+ ]
160
  prompt = (
161
  "You are a helpful assistant. Use ONLY the following contexts to answer. "
162
  "If the answer is not contained, say 'Sorry, I don't know.'\n\n"
 
166
  return qa_pipe(prompt)[0]["generated_text"].strip()
167
 
168
  def retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe):
169
+ contexts, dists = retrieve(question, passages, embedder, index)
170
+ if not contexts or dists[0] > DIST_THRESHOLD:
171
  return "Sorry, I don't know.", []
172
+ return generate(question, contexts, qa_pipe), contexts
 
173
 
174
+ def answer_and_contexts(question, passages, embedder, reranker, index, qa_pipe):
 
175
  ans, ctxs = retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe)
176
  if not ctxs:
177
  return ans, ""
178
  snippets = [
179
+ f"Context {i+1}: {s}"
180
+ for i, s in enumerate(make_context_snippets(ctxs))
181
  ]
182
  return ans, "\n\n---\n\n".join(snippets)
183
 
 
195
  "When was Abraham Lincoln inaugurated?",
196
  "What is the capital of France?",
197
  "Who wrote '1984'?"
198
+ ],
199
+ allow_flagging="never",
200
  )
201
  demo.launch()
202