Update app.py
app.py CHANGED
@@ -16,19 +16,13 @@ from transformers import (
     BitsAndBytesConfig,
 )
 
-# PEFT (LoRA / QLoRA)
 from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel
 
-# For embeddings
 from sentence_transformers import SentenceTransformer
 
-##############################################################################
-# QLoRA Demo Setup
-##############################################################################
-
 TEXT_PIPELINE = None
 COMPARISON_PIPELINE = None
-NUM_EXAMPLES = 50
+NUM_EXAMPLES = 50
 
 @spaces.GPU(duration=300)
 def finetune_small_subset():
@@ -40,7 +34,6 @@ def finetune_small_subset():
     5) Reloads LoRA adapters for inference in a pipeline.
     """
 
-    # --- 1) Load a small subset of the Magpie dataset ---
     ds = load_dataset(
         "Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B",
         split="train"
@@ -52,10 +45,9 @@ def finetune_small_subset():
 
     ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))
 
-    # --- 2) Setup 4-bit quantization ---
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_compute_dtype=torch.bfloat16,
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type="nf4",
     )
@@ -75,14 +67,13 @@ def finetune_small_subset():
         "wuhp/myr1",
         subfolder="myr1",
         config=config,
-        quantization_config=bnb_config,
+        quantization_config=bnb_config,
         device_map="auto",
         trust_remote_code=True
     )
 
     base_model = prepare_model_for_kbit_training(base_model)
 
-    # --- 3) Create LoRA config & wrap the base model in LoRA ---
     lora_config = LoraConfig(
         r=16,
         lora_alpha=32,
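For reference, the 4-bit quantization plus LoRA wiring these hunks keep (minus the banner comments) follows the standard PEFT recipe. The sketch below is illustrative only, not part of app.py, and the model id in it is a placeholder:

# Illustrative QLoRA wiring (not from app.py); the model id below is a placeholder.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit NF4
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

base_model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",              # placeholder, not the Space's wuhp/myr1 checkout
    quantization_config=bnb_config,
    device_map="auto",
)
base_model = prepare_model_for_kbit_training(base_model)  # re-enable grads/casts for k-bit training

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    task_type=TaskType.CAUSAL_LM,           # target_modules left to PEFT's per-architecture defaults
)
lora_model = get_peft_model(base_model, lora_config)
lora_model.print_trainable_parameters()     # only the adapter weights are trainable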
@@ -93,7 +84,6 @@ def finetune_small_subset():
     )
     lora_model = get_peft_model(base_model, lora_config)
 
-    # --- 4) Tokenize dataset ---
     def tokenize_fn(ex):
         text = (
             f"Instruction: {ex['instruction']}\n\n"
@@ -106,7 +96,6 @@ def finetune_small_subset():
 
     collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
-    # Training args
     training_args = TrainingArguments(
         output_dir="finetuned_myr1",
         num_train_epochs=1,
@@ -126,11 +115,9 @@ def finetune_small_subset():
     )
     trainer.train()
 
-    # --- 5) Save LoRA adapter + tokenizer ---
     trainer.model.save_pretrained("finetuned_myr1")
     tokenizer.save_pretrained("finetuned_myr1")
 
-    # --- 6) Reload for inference
     base_model_2 = AutoModelForCausalLM.from_pretrained(
         "wuhp/myr1",
         subfolder="myr1",
@@ -235,9 +222,6 @@ def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     )
     return local_out[0]["generated_text"], comp_out[0]["generated_text"]
 
-###############################################################################
-# Retrieval-Augmented Memory with FAISS
-###############################################################################
 class ConversationRetriever:
     """
     A simple in-memory store + FAISS for retrieval of conversation chunks.
@@ -253,12 +237,10 @@ class ConversationRetriever:
         self.embed_model = SentenceTransformer(model_name)
         self.embed_dim = embed_dim
 
-        # We'll store (text, vector) in FAISS. For metadata, store in python list/dict.
-        # For a real app, you'd probably want a more robust store.
         self.index = faiss.IndexFlatL2(embed_dim)
-        self.texts = []
-        self.vectors = []
-        self.ids = []
+        self.texts = []
+        self.vectors = []
+        self.ids = []
 
         self.id_counter = 0
 
@@ -271,7 +253,7 @@ class ConversationRetriever:
             return
 
         emb = self.embed_model.encode([text], convert_to_numpy=True)
-        vec = emb[0].astype(np.float32)
+        vec = emb[0].astype(np.float32)
         self.index.add(vec.reshape(1, -1))
 
         self.texts.append(text)
@@ -288,17 +270,13 @@ class ConversationRetriever:
         q_vec = q_emb[0].reshape(1, -1)
         distances, indices = self.index.search(q_vec, top_k)
 
-        # indices is shape [1, top_k], distances is shape [1, top_k]
         results = []
         for dist, idx in zip(distances[0], indices[0]):
-            if idx < len(self.texts):
+            if idx < len(self.texts):
                 results.append((self.texts[idx], dist))
         return results
 
-
-# Build a Chat that uses RAG
-###############################################################################
-retriever = ConversationRetriever()  # global retriever instance
+retriever = ConversationRetriever()
 
 def build_rag_prompt(user_query, retrieved_chunks):
     """
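The ConversationRetriever hunks above rely on the usual FAISS flat-index round trip: encode a chunk, add a float32 row to the index, then search with a query vector. Below is a minimal self-contained sketch of that flow, with random vectors standing in for sentence-transformers embeddings; none of it comes from app.py:

# Minimal FAISS add/search round trip; random vectors stand in for real embeddings.
import faiss
import numpy as np

embed_dim = 384                       # e.g. the output size of all-MiniLM-L6-v2
index = faiss.IndexFlatL2(embed_dim)  # exact L2 search, no training step required
texts = []

def add_text(text, vec):
    index.add(vec.astype(np.float32).reshape(1, -1))  # FAISS wants float32, shape (n, dim)
    texts.append(text)

def search(query_vec, top_k=3):
    distances, indices = index.search(query_vec.astype(np.float32).reshape(1, -1), top_k)
    # indices has shape (1, top_k); FAISS pads with -1 when the index holds fewer than top_k rows
    return [(texts[i], d) for d, i in zip(distances[0], indices[0]) if 0 <= i < len(texts)]

rng = np.random.default_rng(0)
add_text("User: hello", rng.standard_normal(embed_dim))
add_text("Assistant: hi there", rng.standard_normal(embed_dim))
print(search(rng.standard_normal(embed_dim), top_k=3))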
@@ -331,18 +309,13 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
     """
     pipe = ensure_pipeline()
 
-    # 1) Add the user input as a chunk to the retriever DB.
     retriever.add_text(f"User: {user_input}")
 
-    # 2) Retrieve top-3 older chunks. We can skip the chunk we just added if we want to
-    #    (since it's the same text), but for simplicity let's just do a search for user_input.
     top_k = 3
     results = retriever.search(user_input, top_k=top_k)
 
-    # 3) Build final prompt
     prompt = build_rag_prompt(user_input, results)
 
-    # 4) Generate
     output = pipe(
         prompt,
         temperature=float(temperature),
@@ -352,23 +325,16 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
         do_sample=True
     )[0]["generated_text"]
 
-    # We only want the new part after "Assistant:"
-    # Because the pipeline output includes the entire prompt + new text.
     if output.startswith(prompt):
         assistant_reply = output[len(prompt):].strip()
     else:
         assistant_reply = output.strip()
 
-    # 5) Add the assistant's response to the DB as well
     retriever.add_text(f"Assistant: {assistant_reply}")
 
-    # 6) Update the chat history for display in the Gradio Chatbot
     history.append([user_input, assistant_reply])
     return history, history
 
-###############################################################################
-# Gradio UI
-###############################################################################
 with gr.Blocks() as demo:
     gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo")
 
@@ -377,7 +343,6 @@ with gr.Blocks() as demo:
 
     finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
-    # Simple generation UI (no retrieval):
    gr.Markdown("## Direct Generation (No Retrieval)")
     prompt_in = gr.Textbox(lines=3, label="Prompt")
     temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
@@ -393,7 +358,6 @@ with gr.Blocks() as demo:
         outputs=output_box
     )
 
-    # Comparison UI:
     gr.Markdown("## Compare myr1 vs DeepSeek")
     compare_btn = gr.Button("Compare")
     out_local = gr.Textbox(label="myr1 Output", lines=6)
@@ -404,12 +368,11 @@ with gr.Blocks() as demo:
         outputs=[out_local, out_deepseek]
     )
 
-    # RAG-based Chat
     gr.Markdown("## Chat with Retrieval-Augmented Memory")
     with gr.Row():
         with gr.Column():
             chatbot = gr.Chatbot(label="RAG Chat")
-            chat_state = gr.State([])
+            chat_state = gr.State([])
 
             user_input = gr.Textbox(
                 show_label=False,
@@ -418,7 +381,6 @@ with gr.Blocks() as demo:
             )
             send_btn = gr.Button("Send")
 
-    # On user submit, call chat_rag
     user_input.submit(
         fn=chat_rag,
         inputs=[user_input, chat_state, temperature, top_p, min_tokens, max_tokens],