Update app.py
app.py
CHANGED
@@ -27,7 +27,7 @@ import shutil
 
 # --- Configuration ---
 YOUR_HF_USERNAME = "Twelve2five"
-MODEL_REPO_NAME = "llama-3-
+MODEL_REPO_NAME = "llama-3.2-1b-rvq"
 DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"
 
 hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
@@ -329,21 +329,14 @@ def train_model(
     model_repo_name,
     dataset_repo_name,
     epochs=1,
-    batch_size=
-    grad_accum_steps=
-    learning_rate=
+    batch_size=2,  # Increased batch size since model is smaller
+    grad_accum_steps=4,
+    learning_rate=2e-4,  # Slightly higher learning rate for smaller model
     progress=gr.Progress()
 ):
     progress(0, desc="Setting up environment...")
     log = []
 
-    # Aggressive memory cleanup
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        # Reset peak memory stats
-        torch.cuda.reset_peak_memory_stats()
-
     # Clean up any existing model files to save space
     if os.path.exists("./model_files"):
         try:
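Note on the new defaults: the batch the optimizer actually steps on is the per-device batch size times the gradient-accumulation steps times the GPU count. A quick check of the numbers (illustration only; a single GPU is assumed here):

```python
# Effective optimizer-level batch size under the new defaults.
batch_size = 2        # per-device batch size
grad_accum_steps = 4  # gradient accumulation steps
n_gpus = 1            # assumed for illustration
print(batch_size * grad_accum_steps * n_gpus)  # -> 8
```

This is the same product the DeepSpeed config in a later hunk stores as "train_batch_size".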
@@ -371,8 +364,8 @@ def train_model(
     from huggingface_hub import snapshot_download
     import torch
     import transformers
-    from transformers import AutoModelForCausalLM,
-    from transformers import BitsAndBytesConfig, TrainingArguments, Trainer
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from transformers import BitsAndBytesConfig, TrainingArguments, Trainer
     from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
     log.append(f"Transformers version: {transformers.__version__}")
@@ -393,40 +386,65 @@ def train_model(
     n_gpus = torch.cuda.device_count()
     log.append(f"Number of GPUs available: {n_gpus}")
 
-    # ---
-
-
+    # --- DeepSpeed Configuration ---
+    # Create DeepSpeed config file
+    progress(0.1, desc="Setting up DeepSpeed configuration...")
+
+    # Create a conservative DeepSpeed config for the smaller model
+    ds_config = {
+        "fp16": {
+            "enabled": "auto",
+            "loss_scale": 0,
+            "loss_scale_window": 1000,
+            "initial_scale_power": 16,
+            "hysteresis": 2,
+            "min_loss_scale": 1
+        },
+        "bf16": {
+            "enabled": "auto"
+        },
+        "zero_optimization": {
+            "stage": 2,  # Lower stage for smaller model
+            "offload_optimizer": {
+                "device": "cpu",
+                "pin_memory": True
+            },
+            "contiguous_gradients": True,
+            "overlap_comm": True
+        },
+        "gradient_accumulation_steps": grad_accum_steps,
+        "gradient_clipping": 1.0,
+        "train_batch_size": batch_size * grad_accum_steps * max(1, n_gpus)
+    }
+
+    ds_config_path = "ds_config.json"
+    with open(ds_config_path, "w") as f:
+        json.dump(ds_config, f, indent=4)
+
+    log.append("DeepSpeed configuration created successfully")
+
+    # --- Download and Load Model ---
+    progress(0.15, desc="Downloading model...")
+
     try:
-        # Download
+        # Download model files
+        local_model_path = "./model_files"
         snapshot_download(
             repo_id=hf_model_repo_id,
             local_dir=local_model_path,
-            local_dir_use_symlinks=False
         )
         log.append(f"Model files downloaded to {local_model_path}")
 
-        #
+        # First, load the config
        config_path = os.path.join(local_model_path, "config.json")
         with open(config_path, "r") as f:
             config_data = json.load(f)
 
-
+        # Check model architecture
+        model_type = config_data.get("model_type", "").lower()
         log.append(f"Model architecture type: {model_type}")
 
-        #
-        if model_type != "llama":
-            config_data["model_type"] = "llama"
-            # Also ensure architectures is set correctly
-            config_data["architectures"] = ["LlamaForCausalLM"]
-            with open(config_path, "w") as f:
-                json.dump(config_data, f, indent=2)
-            log.append("Updated config.json to use llama model_type")
-
-        # Load the config first
-        config = LlamaConfig.from_pretrained(local_model_path)
-        log.append(f"Successfully loaded config: {config.model_type}")
-
-        # Use 4-bit quantization for extreme memory savings
+        # Set 4-bit quantization config
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_use_double_quant=True,
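The JSON written above is consumed by the Trainer through TrainingArguments(deepspeed=...), which this commit wires up in a later hunk. A minimal sketch of that handshake, under the assumption that deepspeed is installed and ds_config.json is the file produced here (the "auto" values for fp16/bf16 are resolved by the Trainer against its own flags):

```python
import json
from transformers import TrainingArguments

# Load the config this hunk wrote; a plain dict can also be passed.
with open("ds_config.json") as f:
    ds_config = json.load(f)

args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,  # must stay consistent with the
    gradient_accumulation_steps=4,  # config's "train_batch_size"
    deepspeed="ds_config.json",
)
```

DeepSpeed validates that train_batch_size equals micro-batch times accumulation steps times world size, which is exactly how the hunk computes it.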
@@ -434,25 +452,28 @@ def train_model(
             bnb_4bit_compute_dtype=torch.bfloat16
         )
 
-        # Load
-
-
-        # Explicit device map to enable CPU offloading
-        max_memory = {0: "40GB", "cpu": "64GB"}
-
-        # Load the model with extreme memory optimization
-        model = LlamaForCausalLM.from_pretrained(
+        # Load the model with 4-bit quantization
+        model = AutoModelForCausalLM.from_pretrained(
             local_model_path,
-            config=config,
             quantization_config=bnb_config,
             device_map="auto",
-
-
-            low_cpu_mem_usage=True
+            trust_remote_code=False,
+            use_cache=False  # No KV caching during training
         )
 
-
-
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(local_model_path)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Log model info
+        if hasattr(model, "config"):
+            log.append(f"Loaded model vocab size: {model.config.vocab_size}")
+        if hasattr(model, "get_input_embeddings"):
+            embedding = model.get_input_embeddings()
+            if hasattr(embedding, "weight"):
+                log.append(f"Input embedding shape: {embedding.weight.shape}")
+
     except Exception as e:
         error_msg = f"Error loading model: {str(e)}"
         log.append(error_msg)
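After a 4-bit load like the one above, a couple of quick sanity checks can confirm that quantization and device placement took effect (a sketch; get_memory_footprint() and hf_device_map are standard transformers attributes, and the size estimate is a rough assumption for a 1B-parameter model):

```python
# Roughly 0.5-1 GiB is expected for a 1B model quantized to 4 bits.
print(f"{model.get_memory_footprint() / 1024**3:.2f} GiB")
# Which modules device_map="auto" put on GPU vs. CPU.
print(model.hf_device_map)
```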
@@ -464,138 +485,77 @@ def train_model(
         model = prepare_model_for_kbit_training(model)
         log.append("Model prepared for k-bit training")
 
-        #
+        # For Llama 3.2 1B the target modules might be slightly different
         lora_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
-            r=8,  # Reduced from 16 to
-            lora_alpha=16,  # Reduced from 32
+            r=8,  # Reduced from 16 due to smaller model
+            lora_alpha=16,  # Reduced from 32
             lora_dropout=0.05,
             bias="none",
-
-            target_modules=["q_proj", "v_proj"]  # Reduced target modules
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
         )
-
-        # Apply LoRA
         peft_model = get_peft_model(model, lora_config)
+        trainable_params = peft_model.print_trainable_parameters()
+        log.append(f"LoRA applied to model")
         model_to_train = peft_model
-        log.append("LoRA applied to model")
-
-        # Free memory
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
     except Exception as e:
         error_msg = f"Error preparing model for training: {str(e)}"
         log.append(error_msg)
         return "\n".join(log)
 
-    # ---
-    progress(0.
+    # --- Load and Prepare Dataset ---
+    progress(0.3, desc="Loading and preparing dataset...")
     try:
-        # Download
-
+        # Download dataset
+        dataset_path = "./downloaded_dataset_files"
         snapshot_download(
             repo_id=hf_dataset_repo_id,
-            local_dir=
-            local_dir_use_symlinks=False
+            local_dir=dataset_path,
         )
-        log.append(f"Dataset repository content downloaded to: {
+        log.append(f"Dataset repository content downloaded to: {dataset_path}")
 
-        # Find all RVQ
-        rvq_pair_files = glob.glob(os.path.join(
+        # Find all RVQ pairs files
+        rvq_pair_files = glob.glob(os.path.join(dataset_path, "*_rvq_pairs.pt"))
         log.append(f"Found {len(rvq_pair_files)} RVQ pair files.")
 
-        # Load
-
-
-        # For memory conservation, use only half the dataset for now
-        max_file_count = min(12, len(rvq_pair_files))
+        # Load the pytorch files
+        all_pairs = []
 
-        for
-
-
-            training_pairs.extend(pairs)
-        except Exception as e:
-            log.append(f"Warning: Could not load {pair_file}: {e}")
+        for file_path in rvq_pair_files:
+            pairs = torch.load(file_path)
+            all_pairs.extend(pairs)
 
-
+        # You may want to limit the data size for quicker testing
+        # all_pairs = all_pairs[:1000]  # Uncomment to limit data size
 
-
-        dataset = Dataset.from_dict({
-            "input_ids": [pair[0].tolist() for pair in training_pairs],
-            "labels": [pair[1].tolist() for pair in training_pairs]
-        })
+        log.append(f"Loaded a total of {len(all_pairs)} training pairs into memory.")
 
-        #
-
-
-
+        # Convert to HF dataset format
+        dataset_dict = {
+            "input_ids": [pair["source"] for pair in all_pairs],
+            "labels": [pair["target"] for pair in all_pairs]
+        }
 
-
-        max_length = 512  # Reduced max sequence length
+        train_dataset = Dataset.from_dict(dataset_dict)
 
-        # Create data collator
-
-            # Convert lists back to tensors
-            for i in range(len(examples)):
-                examples[i]["input_ids"] = torch.tensor(examples[i]["input_ids"], dtype=torch.long)
-                examples[i]["labels"] = torch.tensor(examples[i]["labels"], dtype=torch.long)
-
-            # Get max length in this batch
-            batch_max_length = min(
-                max(len(example["input_ids"]) for example in examples),
-                max_length
-            )
-
-            batch = {
-                "input_ids": [],
-                "attention_mask": [],
-                "labels": []
-            }
-
-            # Prepare sequences
-            for example in examples:
-                input_ids = example["input_ids"][:batch_max_length]
-                labels = example["labels"][:batch_max_length]
-
-                # Pad sequences
-                padding_length = batch_max_length - len(input_ids)
-                attention_mask = torch.ones_like(input_ids)
-
-                if padding_length > 0:
-                    padding = torch.ones(padding_length, dtype=input_ids.dtype) * tokenizer.pad_token_id
-                    input_ids = torch.cat([input_ids, padding])
-                    labels = torch.cat([labels, padding * -100])  # -100 to ignore in loss computation
-                    attention_mask = torch.cat([attention_mask, torch.zeros(padding_length, dtype=attention_mask.dtype)])
-
-                batch["input_ids"].append(input_ids)
-                batch["attention_mask"].append(attention_mask)
-                batch["labels"].append(labels)
-
-            # Convert lists to tensors
-            for key in batch:
-                batch[key] = torch.stack(batch[key])
-
-            return batch
+        # Create data collator for padding
+        from transformers import DataCollatorForLanguageModeling
 
-
-
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer,
+            mlm=False
+        )
 
-        # Free memory
-        del dataset
-        gc.collect()
-        torch.cuda.empty_cache()
     except Exception as e:
         error_msg = f"Error loading dataset: {str(e)}"
         log.append(error_msg)
         return "\n".join(log)
 
     # --- Training Arguments ---
-    progress(0.
+    progress(0.75, desc="Setting up training arguments...")
     output_dir = f"./results_{model_repo_name}"
     os.makedirs(output_dir, exist_ok=True)
 
-    # Super-aggressive memory conservation
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=float(epochs),
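One behavioral detail worth flagging in this hunk: DataCollatorForLanguageModeling with mlm=False rebuilds labels from input_ids at collation time (with padding positions set to -100), so the separate "labels" column built from the target half of each RVQ pair may be overwritten. If the source/target split matters, a hand-rolled collator in the spirit of the one this commit removes would preserve it. A sketch, assuming pre-tokenized integer sequences of equal length per pair and the old 512 cap:

```python
import torch

def pad_pairs_collator(examples, pad_id, max_length=512):
    """Pad pre-tokenized (input_ids, labels) pairs to the batch max length."""
    batch_max = min(max(len(e["input_ids"]) for e in examples), max_length)
    input_ids, attention_mask, labels = [], [], []
    for e in examples:
        ids = torch.tensor(e["input_ids"][:batch_max], dtype=torch.long)
        lbl = torch.tensor(e["labels"][:batch_max], dtype=torch.long)
        pad_len = batch_max - len(ids)
        mask = torch.ones(len(ids), dtype=torch.long)
        if pad_len > 0:
            ids = torch.cat([ids, torch.full((pad_len,), pad_id, dtype=torch.long)])
            lbl = torch.cat([lbl, torch.full((pad_len,), -100, dtype=torch.long)])  # -100 is ignored by the loss
            mask = torch.cat([mask, torch.zeros(pad_len, dtype=torch.long)])
        input_ids.append(ids)
        attention_mask.append(mask)
        labels.append(lbl)
    return {
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attention_mask),
        "labels": torch.stack(labels),
    }
```

It could be handed to the Trainer as data_collator=lambda ex: pad_pairs_collator(ex, tokenizer.pad_token_id).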
@@ -604,90 +564,43 @@ def train_model(
         learning_rate=learning_rate,
         weight_decay=0.01,
         logging_dir=f"{output_dir}/logs",
-        logging_steps=
-        save_steps=
-        save_total_limit=
+        logging_steps=10,
+        save_steps=100,
+        save_total_limit=3,
         remove_unused_columns=False,
         push_to_hub=False,
         disable_tqdm=False,
         warmup_ratio=0.03,
         lr_scheduler_type="cosine",
         report_to="tensorboard",
-        bf16=True,
-        fp16=False,
-
-        # Memory optimization
+        bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={'use_reentrant': False},
-
-
-
-        # Optimizer settings for memory efficiency
-        optim="adamw_torch",
-        adam_beta1=0.9,
-        adam_beta2=0.999,
-        adam_epsilon=1e-8,
-
-        # Evaluation settings
-        do_eval=False,
-        evaluation_strategy="no",
-
-        # Set this for smaller chunks of data processing
-        dataloader_num_workers=1,
-
-        # For memory efficiency when loading datasets
-        dataloader_drop_last=True,
+        ddp_find_unused_parameters=False,
+        deepspeed=ds_config_path if n_gpus > 1 else None,
     )
 
     # --- Initialize Trainer ---
-    progress(0.
-
-    # Use optimizer that requires less memory
-    class MemoryEfficientTrainer(Trainer):
-        def create_optimizer(self):
-            # Create optimizer with reduced memory footprint
-            optimizer = super().create_optimizer()
-            # Force optimizer to use CPU offloading for states
-            for param_group in optimizer.param_groups:
-                for param in param_group['params']:
-                    if param.requires_grad:
-                        param.data = param.data.to("cpu")
-                        if param.grad is not None:
-                            param.grad.data = param.grad.data.to("cpu")
-            return optimizer
-
-        def training_step(self, *args, **kwargs):
-            # Memory cleanup before each training step
-            gc.collect()
-            torch.cuda.empty_cache()
-            return super().training_step(*args, **kwargs)
-
-    trainer = MemoryEfficientTrainer(
+    progress(0.8, desc="Initializing trainer...")
+    trainer = Trainer(
         model=model_to_train,
         args=training_args,
         train_dataset=train_dataset,
         data_collator=data_collator,
     )
 
-    log.append("Trainer initialized
+    log.append("Trainer initialized for training.")
 
     # --- Start Training ---
+    # Clear cache before starting
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
     try:
-
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        progress(0.5, desc="Starting training...")
-        log.append("Starting training with extreme memory optimization...")
-
-        # Train in smaller chunks to manage memory better
-        total_steps = len(train_dataset) // (batch_size * grad_accum_steps)
-        log.append(f"Total training steps: {total_steps}")
-
-        # Train the model
+        progress(0.85, desc="Starting training...")
+        log.append("Starting training...")
         train_result = trainer.train()
-
         progress(0.95, desc="Saving model...")
 
         # Save final model (adapter weights) and training state
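The removed code logged a rough step count before training started; if that readout is still wanted, it can be recomputed from the function's own arguments (a sketch meant to sit inside train_model, where train_dataset, batch_size, grad_accum_steps, epochs, and log are all in scope):

```python
import math

# Optimizer steps per epoch; dataloader_drop_last was dropped from the
# arguments, so round up rather than down.
steps_per_epoch = math.ceil(len(train_dataset) / (batch_size * grad_accum_steps))
total_steps = steps_per_epoch * int(epochs)
log.append(f"Estimated total optimizer steps: {total_steps}")
```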
@@ -703,49 +616,33 @@ def train_model(
 
         for key, value in metrics.items():
             log.append(f"{key}: {value}")
-
-        # Print peak memory usage
-        if torch.cuda.is_available():
-            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
-            log.append(f"Peak GPU memory usage: {peak_memory:.2f} GB")
 
     except Exception as e:
-        error_msg = f"An error occurred during training: {
+        error_msg = f"An error occurred during training: {e}"
         log.append(error_msg)
-
-        # Try to save checkpoint even if training failed
-        try:
-            # Save whatever we have
-            log.append("Attempting to save partial checkpoint...")
-            emergency_save_path = os.path.join(training_args.output_dir, "emergency_checkpoint")
-            trainer.save_model(emergency_save_path)
-            log.append(f"Saved emergency checkpoint to {emergency_save_path}")
-        except Exception as save_error:
-            log.append(f"Could not save emergency checkpoint: {save_error}")
-
         return "\n".join(log)
 
     progress(1.0, desc="Training complete!")
-    log.append("Training process complete
+    log.append("Training process complete.")
     return "\n".join(log)
 
 # Define the Gradio interface
 def create_interface():
-    with gr.Blocks(title="Llama 3
-    gr.Markdown("# Llama 3
-    gr.Markdown("Fine-tune a Llama 3
+    with gr.Blocks(title="Llama 3.2 1B RVQ Fine-tuning") as demo:
+        gr.Markdown("# Llama 3.2 1B RVQ LoRA Fine-tuning")
+        gr.Markdown("Fine-tune a Llama 3.2 1B model with RVQ token embeddings using LoRA")
 
         with gr.Row():
             with gr.Column():
                 hf_username = gr.Textbox(label="HuggingFace Username", value="Twelve2five")
-                model_repo = gr.Textbox(label="Model Repository Name", value="llama-3-
+                model_repo = gr.Textbox(label="Model Repository Name", value="llama-3.2-1b-rvq")
                 dataset_repo = gr.Textbox(label="Dataset Repository Name", value="podcast-dialogue-rvq-pairs-3items")
 
             with gr.Column():
-                epochs = gr.Number(label="Number of Epochs", value=
-                batch_size = gr.Number(label="Batch Size per Device", value=
-                grad_accum = gr.Number(label="Gradient Accumulation Steps", value=
-                lr = gr.Number(label="Learning Rate", value=
+                epochs = gr.Number(label="Number of Epochs", value=2, minimum=1, maximum=10)
+                batch_size = gr.Number(label="Batch Size per Device", value=2, minimum=1, maximum=8)
+                grad_accum = gr.Number(label="Gradient Accumulation Steps", value=4, minimum=1, maximum=16)
+                lr = gr.Number(label="Learning Rate", value=2e-4)
 
         start_btn = gr.Button("Start Training")
         output = gr.Textbox(label="Training Log", lines=20)
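The diff ends before the button is wired to the training function; for completeness, a hypothetical hookup consistent with the train_model signature shown earlier (the real .click call sits outside the hunks, so the argument order here is an assumption):

```python
# Hypothetical wiring; not part of this commit's visible hunks.
start_btn.click(
    fn=train_model,
    inputs=[model_repo, dataset_repo, epochs, batch_size, grad_accum, lr],
    outputs=output,
)
```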