wuhp committed
Commit 81dea5d · verified · 1 Parent(s): 2555047

Update app.py

Files changed (1)
  1. app.py +10 -48
app.py CHANGED
@@ -16,19 +16,13 @@ from transformers import (
     BitsAndBytesConfig,
 )
 
-# PEFT (LoRA / QLoRA)
 from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel
 
-# For embeddings
 from sentence_transformers import SentenceTransformer
 
-##############################################################################
-# QLoRA Demo Setup
-##############################################################################
-
 TEXT_PIPELINE = None
 COMPARISON_PIPELINE = None
-NUM_EXAMPLES = 50 # We'll train on 50 rows for demonstration
+NUM_EXAMPLES = 50
 
 @spaces.GPU(duration=300)
 def finetune_small_subset():
@@ -40,7 +34,6 @@ def finetune_small_subset():
     5) Reloads LoRA adapters for inference in a pipeline.
     """
 
-    # --- 1) Load a small subset of the Magpie dataset ---
     ds = load_dataset(
         "Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B",
         split="train"
@@ -52,10 +45,9 @@ def finetune_small_subset():
 
     ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))
 
-    # --- 2) Setup 4-bit quantization ---
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.bfloat16, # or torch.float16
+        bnb_4bit_compute_dtype=torch.bfloat16,
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type="nf4",
     )
@@ -75,14 +67,13 @@ def finetune_small_subset():
         "wuhp/myr1",
         subfolder="myr1",
         config=config,
-        quantization_config=bnb_config, # <--- QLoRA 4-bit
+        quantization_config=bnb_config,
         device_map="auto",
         trust_remote_code=True
     )
 
     base_model = prepare_model_for_kbit_training(base_model)
 
-    # --- 3) Create LoRA config & wrap the base model in LoRA ---
     lora_config = LoraConfig(
         r=16,
         lora_alpha=32,
@@ -93,7 +84,6 @@ def finetune_small_subset():
     )
     lora_model = get_peft_model(base_model, lora_config)
 
-    # --- 4) Tokenize dataset ---
     def tokenize_fn(ex):
         text = (
             f"Instruction: {ex['instruction']}\n\n"
@@ -106,7 +96,6 @@ def finetune_small_subset():
 
     collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
-    # Training args
     training_args = TrainingArguments(
         output_dir="finetuned_myr1",
         num_train_epochs=1,
@@ -126,11 +115,9 @@ def finetune_small_subset():
     )
     trainer.train()
 
-    # --- 5) Save LoRA adapter + tokenizer ---
     trainer.model.save_pretrained("finetuned_myr1")
     tokenizer.save_pretrained("finetuned_myr1")
 
-    # --- 6) Reload for inference
     base_model_2 = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
@@ -235,9 +222,6 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     )
     return local_out[0]["generated_text"], comp_out[0]["generated_text"]
 
-###############################################################################
-# Retrieval-Augmented Memory with FAISS
-###############################################################################
 class ConversationRetriever:
     """
     A simple in-memory store + FAISS for retrieval of conversation chunks.
@@ -253,12 +237,10 @@ class ConversationRetriever:
         self.embed_model = SentenceTransformer(model_name)
         self.embed_dim = embed_dim
 
-        # We'll store (text, vector) in FAISS. For metadata, store in python list/dict.
-        # For a real app, you'd probably want a more robust store.
         self.index = faiss.IndexFlatL2(embed_dim)
-        self.texts = [] # store the raw text chunks
-        self.vectors = [] # store vectors (redundant but simpler to show)
-        self.ids = [] # store an integer ID or similar
+        self.texts = []
+        self.vectors = []
+        self.ids = []
 
         self.id_counter = 0
 
@@ -271,7 +253,7 @@ class ConversationRetriever:
             return
 
         emb = self.embed_model.encode([text], convert_to_numpy=True)
-        vec = emb[0].astype(np.float32) # shape [embed_dim]
+        vec = emb[0].astype(np.float32)
         self.index.add(vec.reshape(1, -1))
 
         self.texts.append(text)
@@ -288,17 +270,13 @@ class ConversationRetriever:
         q_vec = q_emb[0].reshape(1, -1)
         distances, indices = self.index.search(q_vec, top_k)
 
-        # indices is shape [1, top_k], distances is shape [1, top_k]
         results = []
        for dist, idx in zip(distances[0], indices[0]):
-            if idx < len(self.texts): # safety check
+            if idx < len(self.texts):
                 results.append((self.texts[idx], dist))
         return results
 
-###############################################################################
-# Build a Chat that uses RAG
-###############################################################################
-retriever = ConversationRetriever() # global retriever instance
+retriever = ConversationRetriever()
 
 def build_rag_prompt(user_query, retrieved_chunks):
     """
@@ -331,18 +309,13 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
     """
     pipe = ensure_pipeline()
 
-    # 1) Add the user input as a chunk to the retriever DB.
     retriever.add_text(f"User: {user_input}")
 
-    # 2) Retrieve top-3 older chunks. We can skip the chunk we just added if we want to
-    #    (since it's the same text), but for simplicity let's just do a search for user_input.
     top_k = 3
     results = retriever.search(user_input, top_k=top_k)
 
-    # 3) Build final prompt
     prompt = build_rag_prompt(user_input, results)
 
-    # 4) Generate
     output = pipe(
         prompt,
         temperature=float(temperature),
@@ -352,23 +325,16 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
         do_sample=True
     )[0]["generated_text"]
 
-    # We only want the new part after "Assistant:"
-    # Because the pipeline output includes the entire prompt + new text.
     if output.startswith(prompt):
         assistant_reply = output[len(prompt):].strip()
     else:
         assistant_reply = output.strip()
 
-    # 5) Add the assistant's response to the DB as well
     retriever.add_text(f"Assistant: {assistant_reply}")
 
-    # 6) Update the chat history for display in the Gradio Chatbot
     history.append([user_input, assistant_reply])
     return history, history
 
-###############################################################################
-# Gradio UI
-###############################################################################
 with gr.Blocks() as demo:
     gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo")
 
@@ -377,7 +343,6 @@ with gr.Blocks() as demo:
 
     finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
-    # Simple generation UI (no retrieval):
     gr.Markdown("## Direct Generation (No Retrieval)")
     prompt_in = gr.Textbox(lines=3, label="Prompt")
     temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
@@ -393,7 +358,6 @@ with gr.Blocks() as demo:
         outputs=output_box
     )
 
-    # Comparison UI:
     gr.Markdown("## Compare myr1 vs DeepSeek")
     compare_btn = gr.Button("Compare")
     out_local = gr.Textbox(label="myr1 Output", lines=6)
@@ -404,12 +368,11 @@ with gr.Blocks() as demo:
         outputs=[out_local, out_deepseek]
     )
 
-    # RAG-based Chat
     gr.Markdown("## Chat with Retrieval-Augmented Memory")
     with gr.Row():
         with gr.Column():
             chatbot = gr.Chatbot(label="RAG Chat")
-            chat_state = gr.State([]) # just for display
+            chat_state = gr.State([])
 
             user_input = gr.Textbox(
                 show_label=False,
@@ -418,7 +381,6 @@ with gr.Blocks() as demo:
             )
             send_btn = gr.Button("Send")
 
-    # On user submit, call chat_rag
     user_input.submit(
         fn=chat_rag,
        inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
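
The ConversationRetriever edits only strip trailing comments; the retrieval flow itself is unchanged. For reference, this is a minimal sketch of that add/search flow with plain FAISS. The embedding model name and 384-dim size are assumptions for illustration; the class's actual defaults are not visible in this diff.

# Minimal sketch of the add/search flow behind ConversationRetriever
# (assumption: a 384-dim MiniLM sentence-embedding model).
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
index = faiss.IndexFlatL2(384)   # exact L2 search over 384-dim vectors
texts = []                       # raw chunks, row-aligned with the index

def add_text(text):
    vec = embed_model.encode([text], convert_to_numpy=True).astype(np.float32)
    index.add(vec)               # vec has shape (1, 384)
    texts.append(text)

def search(query, top_k=3):
    q = embed_model.encode([query], convert_to_numpy=True).astype(np.float32)
    distances, indices = index.search(q, top_k)
    return [(texts[i], d) for d, i in zip(distances[0], indices[0]) if 0 <= i < len(texts)]

add_text("User: how does the QLoRA demo fine-tune the model?")
add_text("Assistant: it loads the base model in 4-bit and trains LoRA adapters on 50 rows.")
print(search("QLoRA fine-tuning", top_k=2))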