Update app.py
app.py CHANGED
@@ -16,19 +16,13 @@ from transformers import (
     BitsAndBytesConfig,
 )
 
-# PEFT (LoRA / QLoRA)
 from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel
 
-# For embeddings
 from sentence_transformers import SentenceTransformer
 
-##############################################################################
-# QLoRA Demo Setup
-##############################################################################
-
 TEXT_PIPELINE = None
 COMPARISON_PIPELINE = None
-NUM_EXAMPLES = 50
+NUM_EXAMPLES = 50
 
 @spaces.GPU(duration=300)
 def finetune_small_subset():
@@ -40,7 +34,6 @@ def finetune_small_subset():
     5) Reloads LoRA adapters for inference in a pipeline.
     """
 
-    # --- 1) Load a small subset of the Magpie dataset ---
     ds = load_dataset(
         "Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B",
         split="train"
@@ -52,10 +45,9 @@ def finetune_small_subset():
 
     ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))
 
-    # --- 2) Setup 4-bit quantization ---
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_compute_dtype=torch.bfloat16,
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type="nf4",
     )
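
Note: the settings in this hunk follow the standard QLoRA recipe (NF4 weights, double quantization, bf16 compute). A minimal, self-contained sketch of the same setup, using a small placeholder model id rather than the Space's wuhp/myr1 checkpoint:

# Sketch only: same NF4 4-bit setup as above, with a placeholder model id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
)

model_id = "facebook/opt-125m"  # placeholder; any causal LM on the Hub works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)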
@@ -75,14 +67,13 @@ def finetune_small_subset():
         "wuhp/myr1",
         subfolder="myr1",
         config=config,
-        quantization_config=bnb_config,
+        quantization_config=bnb_config,
         device_map="auto",
         trust_remote_code=True
     )
 
     base_model = prepare_model_for_kbit_training(base_model)
 
-    # --- 3) Create LoRA config & wrap the base model in LoRA ---
     lora_config = LoraConfig(
         r=16,
         lora_alpha=32,
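
Note: continuing the sketch above, the k-bit preparation and LoRA wrapping configured in this hunk leave only a small set of adapter weights trainable. lora_dropout and target_modules below are assumptions; the real values fall outside the hunk:

# Sketch only: wrap the quantized model from the previous sketch in a LoRA adapter.
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)  # freeze base weights, prep layers for k-bit training

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,                    # assumption
    target_modules=["q_proj", "v_proj"],  # assumption; depends on the architecture
    task_type=TaskType.CAUSAL_LM,
)
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()  # only the adapter weights require grad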
@@ -93,7 +84,6 @@ def finetune_small_subset():
     )
     lora_model = get_peft_model(base_model, lora_config)
 
-    # --- 4) Tokenize dataset ---
     def tokenize_fn(ex):
         text = (
             f"Instruction: {ex['instruction']}\n\n"
@@ -106,7 +96,6 @@ def finetune_small_subset():
 
     collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
-    # Training args
     training_args = TrainingArguments(
         output_dir="finetuned_myr1",
         num_train_epochs=1,
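
Note: the training setup here is a plain causal-LM fine-tune. A sketch of the surrounding wiring, continuing from the previous sketches (batch size and other arguments are assumptions, as they sit outside the hunk):

# Sketch only: Trainer wiring around the collator and arguments shown above.
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  # causal LM, no masking

training_args = TrainingArguments(
    output_dir="finetuned_myr1",
    num_train_epochs=1,
    per_device_train_batch_size=1,   # assumption
    gradient_accumulation_steps=2,   # assumption
    logging_steps=5,                 # assumption
    bf16=True,
)

trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=ds,                # the tokenized Magpie subset
    data_collator=collator,
)
trainer.train()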
@@ -126,11 +115,9 @@ def finetune_small_subset():
     )
     trainer.train()
 
-    # --- 5) Save LoRA adapter + tokenizer ---
     trainer.model.save_pretrained("finetuned_myr1")
     tokenizer.save_pretrained("finetuned_myr1")
 
-    # --- 6) Reload for inference
     base_model_2 = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
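
Note: save_pretrained on a PEFT-wrapped model writes only the adapter, which is then re-attached to a freshly loaded 4-bit base for inference. A hedged sketch of that reload path, continuing the placeholder model from the earlier sketches:

# Sketch only: persist the adapter, reload a quantized base, and attach the adapter for inference.
from peft import PeftModel
from transformers import AutoModelForCausalLM, pipeline

trainer.model.save_pretrained("finetuned_myr1")   # adapter weights + config only
tokenizer.save_pretrained("finetuned_myr1")

base = AutoModelForCausalLM.from_pretrained(
    model_id,                         # placeholder id from the first sketch
    quantization_config=bnb_config,
    device_map="auto",
)
model_with_adapter = PeftModel.from_pretrained(base, "finetuned_myr1")

text_pipe = pipeline("text-generation", model=model_with_adapter, tokenizer=tokenizer)
print(text_pipe("Instruction: say hello\n\nResponse:", max_new_tokens=20)[0]["generated_text"])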
@@ -235,9 +222,6 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     )
     return local_out[0]["generated_text"], comp_out[0]["generated_text"]
 
-###############################################################################
-# Retrieval-Augmented Memory with FAISS
-###############################################################################
 class ConversationRetriever:
     """
     A simple in-memory store + FAISS for retrieval of conversation chunks.
@@ -253,12 +237,10 @@ class ConversationRetriever:
         self.embed_model = SentenceTransformer(model_name)
         self.embed_dim = embed_dim
 
-        # We'll store (text, vector) in FAISS. For metadata, store in python list/dict.
-        # For a real app, you'd probably want a more robust store.
         self.index = faiss.IndexFlatL2(embed_dim)
-        self.texts = []
-        self.vectors = []
-        self.ids = []
+        self.texts = []
+        self.vectors = []
+        self.ids = []
 
         self.id_counter = 0
 
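
Note: the class is a thin wrapper over the usual sentence-transformers + FAISS exact-search pattern. A standalone sketch (the MiniLM model name and its 384-dimensional output are assumptions here, not taken from the hunk):

# Sketch only: exact L2 search over sentence-transformer embeddings, as the class wraps.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # 384-dim embeddings
index = faiss.IndexFlatL2(384)

chunks = [
    "User: how do I bake bread?",
    "Assistant: start from flour, water, salt and yeast.",
]
vectors = embedder.encode(chunks, convert_to_numpy=True).astype(np.float32)
index.add(vectors)  # shape (n, 384)

query = embedder.encode(["bread recipe"], convert_to_numpy=True).astype(np.float32)
distances, indices = index.search(query, 2)  # smaller L2 distance = more similar
for dist, idx in zip(distances[0], indices[0]):
    print(round(float(dist), 3), chunks[idx])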
@@ -271,7 +253,7 @@ class ConversationRetriever:
             return
 
         emb = self.embed_model.encode([text], convert_to_numpy=True)
-        vec = emb[0].astype(np.float32)
+        vec = emb[0].astype(np.float32)
         self.index.add(vec.reshape(1, -1))
 
         self.texts.append(text)
@@ -288,17 +270,13 @@ class ConversationRetriever:
         q_vec = q_emb[0].reshape(1, -1)
         distances, indices = self.index.search(q_vec, top_k)
 
-        # indices is shape [1, top_k], distances is shape [1, top_k]
         results = []
         for dist, idx in zip(distances[0], indices[0]):
-            if idx < len(self.texts):
+            if idx < len(self.texts):
                 results.append((self.texts[idx], dist))
         return results
 
-
-# Build a Chat that uses RAG
-###############################################################################
-retriever = ConversationRetriever()  # global retriever instance
+retriever = ConversationRetriever()
 
 def build_rag_prompt(user_query, retrieved_chunks):
     """
@@ -331,18 +309,13 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_to
     """
     pipe = ensure_pipeline()
 
-    # 1) Add the user input as a chunk to the retriever DB.
     retriever.add_text(f"User: {user_input}")
 
-    # 2) Retrieve top-3 older chunks. We can skip the chunk we just added if we want to
-    # (since it's the same text), but for simplicity let's just do a search for user_input.
     top_k = 3
     results = retriever.search(user_input, top_k=top_k)
 
-    # 3) Build final prompt
     prompt = build_rag_prompt(user_input, results)
 
-    # 4) Generate
     output = pipe(
         prompt,
         temperature=float(temperature),
@@ -352,23 +325,16 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_to
         do_sample=True
     )[0]["generated_text"]
 
-    # We only want the new part after "Assistant:"
-    # Because the pipeline output includes the entire prompt + new text.
     if output.startswith(prompt):
         assistant_reply = output[len(prompt):].strip()
     else:
         assistant_reply = output.strip()
 
-    # 5) Add the assistant's response to the DB as well
     retriever.add_text(f"Assistant: {assistant_reply}")
 
-    # 6) Update the chat history for display in the Gradio Chatbot
     history.append([user_input, assistant_reply])
     return history, history
 
-###############################################################################
-# Gradio UI
-###############################################################################
 with gr.Blocks() as demo:
     gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo")
 
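
Note: taken together, the two hunks above implement one retrieve-then-generate turn. A toy version of that flow, with stand-ins for ensure_pipeline and build_rag_prompt (whose bodies are outside these hunks):

# Sketch only: one RAG chat turn with toy stand-ins for the app's pipeline and prompt builder.
def toy_build_rag_prompt(user_query, retrieved_chunks):
    context = "\n".join(f"- {text}" for text, _dist in retrieved_chunks)
    return f"Relevant earlier chunks:\n{context}\n\nUser: {user_query}\nAssistant:"

def toy_generate(prompt):
    # stand-in for pipe(prompt, ...)[0]["generated_text"], which echoes the prompt plus new text
    return prompt + " Sure, here is an answer."

retriever = ConversationRetriever()  # the class defined earlier in this file
user_input = "How do I bake bread?"

retriever.add_text(f"User: {user_input}")
results = retriever.search(user_input, top_k=3)
prompt = toy_build_rag_prompt(user_input, results)

output = toy_generate(prompt)
# keep only the newly generated text, since the pipeline output includes the prompt
reply = output[len(prompt):].strip() if output.startswith(prompt) else output.strip()
retriever.add_text(f"Assistant: {reply}")
print(reply)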
@@ -377,7 +343,6 @@ with gr.Blocks() as demo:
 
     finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
-    # Simple generation UI (no retrieval):
     gr.Markdown("## Direct Generation (No Retrieval)")
     prompt_in = gr.Textbox(lines=3, label="Prompt")
     temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
@@ -393,7 +358,6 @@ with gr.Blocks() as demo:
         outputs=output_box
     )
 
-    # Comparison UI:
     gr.Markdown("## Compare myr1 vs DeepSeek")
     compare_btn = gr.Button("Compare")
     out_local = gr.Textbox(label="myr1 Output", lines=6)
@@ -404,12 +368,11 @@ with gr.Blocks() as demo:
         outputs=[out_local, out_deepseek]
     )
 
-    # RAG-based Chat
     gr.Markdown("## Chat with Retrieval-Augmented Memory")
     with gr.Row():
         with gr.Column():
             chatbot = gr.Chatbot(label="RAG Chat")
-            chat_state = gr.State([])
+            chat_state = gr.State([])
 
     user_input = gr.Textbox(
         show_label=False,
@@ -418,7 +381,6 @@ with gr.Blocks() as demo:
     )
     send_btn = gr.Button("Send")
 
-    # On user submit, call chat_rag
     user_input.submit(
         fn=chat_rag,
         inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],
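
Note: the Chatbot + State + submit wiring these UI hunks set up can be reproduced in isolation. This sketch keeps the app's tuple-style chat history and uses an echo function in place of chat_rag:

# Sketch only: minimal Blocks UI in the same Chatbot + State + submit pattern,
# with an echo function standing in for chat_rag.
import gradio as gr

def echo_chat(user_message, history):
    history = history + [[user_message, f"echo: {user_message}"]]
    return history, history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="RAG Chat")
    chat_state = gr.State([])
    user_box = gr.Textbox(show_label=False, placeholder="Ask something...")
    send_btn = gr.Button("Send")

    user_box.submit(fn=echo_chat, inputs=[user_box, chat_state], outputs=[chatbot, chat_state])
    send_btn.click(fn=echo_chat, inputs=[user_box, chat_state], outputs=[chatbot, chat_state])

demo.launch()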