Twelve2five committed
Commit 1c688b1 · verified · 1 Parent(s): fd09ea6

Update app.py

Files changed (1):
  1. app.py +130 -233
app.py CHANGED
@@ -27,7 +27,7 @@ import shutil
 
  # --- Configuration ---
  YOUR_HF_USERNAME = "Twelve2five"
- MODEL_REPO_NAME = "llama-3-8b-rvq-resized"
  DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"
 
  hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
@@ -329,21 +329,14 @@ def train_model(
  model_repo_name,
  dataset_repo_name,
  epochs=1,
- batch_size=1,
- grad_accum_steps=16, # Increased from 8 to 16
- learning_rate=1e-4,
  progress=gr.Progress()
  ):
  progress(0, desc="Setting up environment...")
  log = []
 
- # Aggressive memory cleanup
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- # Reset peak memory stats
- torch.cuda.reset_peak_memory_stats()
-
  # Clean up any existing model files to save space
  if os.path.exists("./model_files"):
  try:
@@ -371,8 +364,8 @@ def train_model(
  from huggingface_hub import snapshot_download
  import torch
  import transformers
- from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
- from transformers import BitsAndBytesConfig, TrainingArguments, Trainer, AutoTokenizer
  from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
  log.append(f"Transformers version: {transformers.__version__}")
@@ -393,40 +386,65 @@ def train_model(
  n_gpus = torch.cuda.device_count()
  log.append(f"Number of GPUs available: {n_gpus}")
 
- # --- Load Base Model (with extreme quantization) ---
- progress(0.1, desc="Loading base model...")
- local_model_path = "./model_files"
  try:
- # Download the model files
  snapshot_download(
  repo_id=hf_model_repo_id,
  local_dir=local_model_path,
- local_dir_use_symlinks=False
  )
  log.append(f"Model files downloaded to {local_model_path}")
 
- # Ensure model_type is set correctly in the config
  config_path = os.path.join(local_model_path, "config.json")
  with open(config_path, "r") as f:
  config_data = json.load(f)
 
- model_type = config_data.get("model_type", "")
  log.append(f"Model architecture type: {model_type}")
 
- # Force model_type to be "llama" if it's not already
- if model_type != "llama":
- config_data["model_type"] = "llama"
- # Also ensure architectures is set correctly
- config_data["architectures"] = ["LlamaForCausalLM"]
- with open(config_path, "w") as f:
- json.dump(config_data, f, indent=2)
- log.append("Updated config.json to use llama model_type")
-
- # Load the config first
- config = LlamaConfig.from_pretrained(local_model_path)
- log.append(f"Successfully loaded config: {config.model_type}")
-
- # Use 4-bit quantization for extreme memory savings
  bnb_config = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_use_double_quant=True,
@@ -434,25 +452,28 @@ def train_model(
  bnb_4bit_compute_dtype=torch.bfloat16
  )
 
- # Load tokenizer first (needed for dataset preparation)
- tokenizer = AutoTokenizer.from_pretrained(local_model_path)
-
- # Explicit device map to enable CPU offloading
- max_memory = {0: "40GB", "cpu": "64GB"}
-
- # Load the model with extreme memory optimization
- model = LlamaForCausalLM.from_pretrained(
  local_model_path,
- config=config,
  quantization_config=bnb_config,
  device_map="auto",
- max_memory=max_memory,
- torch_dtype=torch.bfloat16,
- low_cpu_mem_usage=True
  )
 
- log.append(f"Loaded model vocab size: {model.config.vocab_size}")
- log.append(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
  except Exception as e:
  error_msg = f"Error loading model: {str(e)}"
  log.append(error_msg)
@@ -464,138 +485,77 @@ def train_model(
  model = prepare_model_for_kbit_training(model)
  log.append("Model prepared for k-bit training")
 
- # Use minimal LoRA configuration with fewer parameters
  lora_config = LoraConfig(
  task_type=TaskType.CAUSAL_LM,
- r=8, # Reduced from 16 to 8
- lora_alpha=16, # Reduced from 32 to 16
  lora_dropout=0.05,
  bias="none",
- # Target only key modules to reduce memory usage
- target_modules=["q_proj", "v_proj"] # Reduced target modules
  )
-
- # Apply LoRA
  peft_model = get_peft_model(model, lora_config)
  model_to_train = peft_model
- log.append("LoRA applied to model")
-
- # Free memory
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
  except Exception as e:
  error_msg = f"Error preparing model for training: {str(e)}"
  log.append(error_msg)
  return "\n".join(log)
 
- # --- Download and Process Dataset ---
- progress(0.2, desc="Loading dataset...")
  try:
- # Download the dataset files
- dataset_dir = os.path.join(os.getcwd(), "downloaded_dataset_files")
  snapshot_download(
  repo_id=hf_dataset_repo_id,
- local_dir=dataset_dir,
- local_dir_use_symlinks=False
  )
- log.append(f"Dataset repository content downloaded to: {dataset_dir}")
 
- # Find all RVQ pair files
- rvq_pair_files = glob.glob(os.path.join(dataset_dir, "*_rvq_pairs.pt"))
  log.append(f"Found {len(rvq_pair_files)} RVQ pair files.")
 
- # Load training pairs from the dataset
- training_pairs = []
-
- # For memory conservation, use only half the dataset for now
- max_file_count = min(12, len(rvq_pair_files))
 
- for i, pair_file in enumerate(rvq_pair_files[:max_file_count]):
- try:
- pairs = torch.load(pair_file)
- training_pairs.extend(pairs)
- except Exception as e:
- log.append(f"Warning: Could not load {pair_file}: {e}")
 
- log.append(f"Loaded a total of {len(training_pairs)} training pairs into memory.")
 
- # Prepare dataset
- dataset = Dataset.from_dict({
- "input_ids": [pair[0].tolist() for pair in training_pairs],
- "labels": [pair[1].tolist() for pair in training_pairs]
- })
 
- # Clear the training_pairs to free memory
- training_pairs = None
- gc.collect()
- torch.cuda.empty_cache()
 
- # Use a smaller max_length to reduce memory pressure
- max_length = 512 # Reduced max sequence length
 
- # Create data collator that handles padding
- def data_collator(examples):
- # Convert lists back to tensors
- for i in range(len(examples)):
- examples[i]["input_ids"] = torch.tensor(examples[i]["input_ids"], dtype=torch.long)
- examples[i]["labels"] = torch.tensor(examples[i]["labels"], dtype=torch.long)
-
- # Get max length in this batch
- batch_max_length = min(
- max(len(example["input_ids"]) for example in examples),
- max_length
- )
-
- batch = {
- "input_ids": [],
- "attention_mask": [],
- "labels": []
- }
-
- # Prepare sequences
- for example in examples:
- input_ids = example["input_ids"][:batch_max_length]
- labels = example["labels"][:batch_max_length]
-
- # Pad sequences
- padding_length = batch_max_length - len(input_ids)
- attention_mask = torch.ones_like(input_ids)
-
- if padding_length > 0:
- padding = torch.ones(padding_length, dtype=input_ids.dtype) * tokenizer.pad_token_id
- input_ids = torch.cat([input_ids, padding])
- labels = torch.cat([labels, padding * -100]) # -100 to ignore in loss computation
- attention_mask = torch.cat([attention_mask, torch.zeros(padding_length, dtype=attention_mask.dtype)])
-
- batch["input_ids"].append(input_ids)
- batch["attention_mask"].append(attention_mask)
- batch["labels"].append(labels)
-
- # Convert lists to tensors
- for key in batch:
- batch[key] = torch.stack(batch[key])
-
- return batch
 
- # Convert to training dataset
- train_dataset = dataset
 
- # Free memory
- del dataset
- gc.collect()
- torch.cuda.empty_cache()
  except Exception as e:
  error_msg = f"Error loading dataset: {str(e)}"
  log.append(error_msg)
  return "\n".join(log)
 
  # --- Training Arguments ---
- progress(0.3, desc="Setting up training arguments...")
  output_dir = f"./results_{model_repo_name}"
  os.makedirs(output_dir, exist_ok=True)
 
- # Super-aggressive memory conservation
  training_args = TrainingArguments(
  output_dir=output_dir,
  num_train_epochs=float(epochs),
@@ -604,90 +564,43 @@ def train_model(
  learning_rate=learning_rate,
  weight_decay=0.01,
  logging_dir=f"{output_dir}/logs",
- logging_steps=1, # Log frequently to see progress
- save_steps=25, # Save checkpoints more frequently
- save_total_limit=1, # Keep only one checkpoint to save space
  remove_unused_columns=False,
  push_to_hub=False,
  disable_tqdm=False,
  warmup_ratio=0.03,
  lr_scheduler_type="cosine",
  report_to="tensorboard",
- bf16=True,
- fp16=False,
-
- # Memory optimization
  gradient_checkpointing=True,
  gradient_checkpointing_kwargs={'use_reentrant': False},
- max_grad_norm=0.3, # Reduced from default 1.0
- dataloader_pin_memory=False, # Reduce memory pressure
-
- # Optimizer settings for memory efficiency
- optim="adamw_torch",
- adam_beta1=0.9,
- adam_beta2=0.999,
- adam_epsilon=1e-8,
-
- # Evaluation settings
- do_eval=False,
- evaluation_strategy="no",
-
- # Set this for smaller chunks of data processing
- dataloader_num_workers=1,
-
- # For memory efficiency when loading datasets
- dataloader_drop_last=True,
  )
 
  # --- Initialize Trainer ---
- progress(0.4, desc="Initializing trainer...")
-
- # Use optimizer that requires less memory
- class MemoryEfficientTrainer(Trainer):
- def create_optimizer(self):
- # Create optimizer with reduced memory footprint
- optimizer = super().create_optimizer()
- # Force optimizer to use CPU offloading for states
- for param_group in optimizer.param_groups:
- for param in param_group['params']:
- if param.requires_grad:
- param.data = param.data.to("cpu")
- if param.grad is not None:
- param.grad.data = param.grad.data.to("cpu")
- return optimizer
-
- def training_step(self, *args, **kwargs):
- # Memory cleanup before each training step
- gc.collect()
- torch.cuda.empty_cache()
- return super().training_step(*args, **kwargs)
-
- trainer = MemoryEfficientTrainer(
  model=model_to_train,
  args=training_args,
  train_dataset=train_dataset,
  data_collator=data_collator,
  )
 
- log.append("Trainer initialized with memory-efficient settings")
 
  # --- Start Training ---
  try:
- # Final memory cleanup before training
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
- progress(0.5, desc="Starting training...")
- log.append("Starting training with extreme memory optimization...")
-
- # Train in smaller chunks to manage memory better
- total_steps = len(train_dataset) // (batch_size * grad_accum_steps)
- log.append(f"Total training steps: {total_steps}")
-
- # Train the model
  train_result = trainer.train()
-
  progress(0.95, desc="Saving model...")
 
  # Save final model (adapter weights) and training state
@@ -703,49 +616,33 @@ def train_model(
 
  for key, value in metrics.items():
  log.append(f"{key}: {value}")
-
- # Print peak memory usage
- if torch.cuda.is_available():
- peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
- log.append(f"Peak GPU memory usage: {peak_memory:.2f} GB")
 
  except Exception as e:
- error_msg = f"An error occurred during training: {str(e)}"
  log.append(error_msg)
-
- # Try to save checkpoint even if training failed
- try:
- # Save whatever we have
- log.append("Attempting to save partial checkpoint...")
- emergency_save_path = os.path.join(training_args.output_dir, "emergency_checkpoint")
- trainer.save_model(emergency_save_path)
- log.append(f"Saved emergency checkpoint to {emergency_save_path}")
- except Exception as save_error:
- log.append(f"Could not save emergency checkpoint: {save_error}")
-
  return "\n".join(log)
 
  progress(1.0, desc="Training complete!")
- log.append("Training process complete successfully.")
  return "\n".join(log)
 
  # Define the Gradio interface
  def create_interface():
- with gr.Blocks(title="Llama 3 8B RVQ Fine-tuning") as demo:
- gr.Markdown("# Llama 3 8B RVQ LoRA Fine-tuning")
- gr.Markdown("Fine-tune a Llama 3 8B model with RVQ token embeddings using LoRA with extreme memory optimization")
 
  with gr.Row():
  with gr.Column():
  hf_username = gr.Textbox(label="HuggingFace Username", value="Twelve2five")
- model_repo = gr.Textbox(label="Model Repository Name", value="llama-3-8b-rvq-resized")
  dataset_repo = gr.Textbox(label="Dataset Repository Name", value="podcast-dialogue-rvq-pairs-3items")
 
  with gr.Column():
- epochs = gr.Number(label="Number of Epochs", value=1, minimum=1, maximum=10)
- batch_size = gr.Number(label="Batch Size per Device", value=1, minimum=1, maximum=8)
- grad_accum = gr.Number(label="Gradient Accumulation Steps", value=16, minimum=8, maximum=32)
- lr = gr.Number(label="Learning Rate", value=1e-4)
 
  start_btn = gr.Button("Start Training")
  output = gr.Textbox(label="Training Log", lines=20)
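The removed dataset-loading code above treats each RVQ pair as a tuple (pair[0], pair[1]), while the updated side of the diff, which follows below, switches to dict-style pairs (pair["source"], pair["target"]). A small, hypothetical inspection snippet, assuming only the download directory and file pattern that appear in the diff, can confirm which layout a given shard actually contains before launching training:

# Hypothetical helper, not part of the commit: inspect one downloaded RVQ pair
# shard and report whether its items are tuples or dicts, and their shapes.
import glob
import torch

shard_paths = sorted(glob.glob("./downloaded_dataset_files/*_rvq_pairs.pt"))
if shard_paths:
    pairs = torch.load(shard_paths[0])
    print(f"{shard_paths[0]}: {len(pairs)} pairs, first element type: {type(pairs[0])}")
    first = pairs[0]
    if isinstance(first, dict):
        print({k: getattr(v, "shape", v) for k, v in first.items()})
    else:
        print([getattr(x, "shape", x) for x in first])
else:
    print("No *_rvq_pairs.pt shards found; run the snapshot_download step first.")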
 
 
  # --- Configuration ---
  YOUR_HF_USERNAME = "Twelve2five"
+ MODEL_REPO_NAME = "llama-3.2-1b-rvq"
  DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"
 
  hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
 
  model_repo_name,
  dataset_repo_name,
  epochs=1,
+ batch_size=2, # Increased batch size since model is smaller
+ grad_accum_steps=4,
+ learning_rate=2e-4, # Slightly higher learning rate for smaller model
  progress=gr.Progress()
  ):
  progress(0, desc="Setting up environment...")
  log = []
 
  # Clean up any existing model files to save space
  if os.path.exists("./model_files"):
  try:
 
  from huggingface_hub import snapshot_download
  import torch
  import transformers
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import BitsAndBytesConfig, TrainingArguments, Trainer
  from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
  log.append(f"Transformers version: {transformers.__version__}")
 
  n_gpus = torch.cuda.device_count()
  log.append(f"Number of GPUs available: {n_gpus}")
 
+ # --- DeepSpeed Configuration ---
+ # Create DeepSpeed config file
+ progress(0.1, desc="Setting up DeepSpeed configuration...")
+
+ # Create a conservative DeepSpeed config for the smaller model
+ ds_config = {
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "zero_optimization": {
+ "stage": 2, # Lower stage for smaller model
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": True
+ },
+ "contiguous_gradients": True,
+ "overlap_comm": True
+ },
+ "gradient_accumulation_steps": grad_accum_steps,
+ "gradient_clipping": 1.0,
+ "train_batch_size": batch_size * grad_accum_steps * max(1, n_gpus)
+ }
+
+ ds_config_path = "ds_config.json"
+ with open(ds_config_path, "w") as f:
+ json.dump(ds_config, f, indent=4)
+
+ log.append("DeepSpeed configuration created successfully")
+
+ # --- Download and Load Model ---
+ progress(0.15, desc="Downloading model...")
+
  try:
+ # Download model files
+ local_model_path = "./model_files"
  snapshot_download(
  repo_id=hf_model_repo_id,
  local_dir=local_model_path,
  )
  log.append(f"Model files downloaded to {local_model_path}")
 
+ # First, load the config
  config_path = os.path.join(local_model_path, "config.json")
  with open(config_path, "r") as f:
  config_data = json.load(f)
 
+ # Check model architecture
+ model_type = config_data.get("model_type", "").lower()
  log.append(f"Model architecture type: {model_type}")
 
+ # Set 4-bit quantization config
  bnb_config = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_use_double_quant=True,
 
  bnb_4bit_compute_dtype=torch.bfloat16
  )
 
+ # Load the model with 4-bit quantization
+ model = AutoModelForCausalLM.from_pretrained(
  local_model_path,
  quantization_config=bnb_config,
  device_map="auto",
+ trust_remote_code=False,
+ use_cache=False # No KV caching during training
  )
 
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(local_model_path)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Log model info
+ if hasattr(model, "config"):
+ log.append(f"Loaded model vocab size: {model.config.vocab_size}")
+ if hasattr(model, "get_input_embeddings"):
+ embedding = model.get_input_embeddings()
+ if hasattr(embedding, "weight"):
+ log.append(f"Input embedding shape: {embedding.weight.shape}")
+
  except Exception as e:
  error_msg = f"Error loading model: {str(e)}"
  log.append(error_msg)
 
  model = prepare_model_for_kbit_training(model)
  log.append("Model prepared for k-bit training")
 
+ # For Llama 3.2 1B the target modules might be slightly different
  lora_config = LoraConfig(
  task_type=TaskType.CAUSAL_LM,
+ r=8, # Reduced from 16 due to smaller model
+ lora_alpha=16, # Reduced from 32
  lora_dropout=0.05,
  bias="none",
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
  )
  peft_model = get_peft_model(model, lora_config)
+ trainable_params = peft_model.print_trainable_parameters()
+ log.append(f"LoRA applied to model")
  model_to_train = peft_model
  except Exception as e:
  error_msg = f"Error preparing model for training: {str(e)}"
  log.append(error_msg)
  return "\n".join(log)
 
+ # --- Load and Prepare Dataset ---
+ progress(0.3, desc="Loading and preparing dataset...")
  try:
+ # Download dataset
+ dataset_path = "./downloaded_dataset_files"
  snapshot_download(
  repo_id=hf_dataset_repo_id,
+ local_dir=dataset_path,
  )
+ log.append(f"Dataset repository content downloaded to: {dataset_path}")
 
+ # Find all RVQ pairs files
+ rvq_pair_files = glob.glob(os.path.join(dataset_path, "*_rvq_pairs.pt"))
  log.append(f"Found {len(rvq_pair_files)} RVQ pair files.")
 
+ # Load the pytorch files
+ all_pairs = []
 
+ for file_path in rvq_pair_files:
+ pairs = torch.load(file_path)
+ all_pairs.extend(pairs)
 
+ # You may want to limit the data size for quicker testing
+ # all_pairs = all_pairs[:1000] # Uncomment to limit data size
 
+ log.append(f"Loaded a total of {len(all_pairs)} training pairs into memory.")
 
+ # Convert to HF dataset format
+ dataset_dict = {
+ "input_ids": [pair["source"] for pair in all_pairs],
+ "labels": [pair["target"] for pair in all_pairs]
+ }
 
+ train_dataset = Dataset.from_dict(dataset_dict)
 
+ # Create data collator for padding
+ from transformers import DataCollatorForLanguageModeling
 
+ data_collator = DataCollatorForLanguageModeling(
+ tokenizer=tokenizer,
+ mlm=False
+ )
 
  except Exception as e:
  error_msg = f"Error loading dataset: {str(e)}"
  log.append(error_msg)
  return "\n".join(log)
 
  # --- Training Arguments ---
+ progress(0.75, desc="Setting up training arguments...")
  output_dir = f"./results_{model_repo_name}"
  os.makedirs(output_dir, exist_ok=True)
 
  training_args = TrainingArguments(
  output_dir=output_dir,
  num_train_epochs=float(epochs),
 
  learning_rate=learning_rate,
  weight_decay=0.01,
  logging_dir=f"{output_dir}/logs",
+ logging_steps=10,
+ save_steps=100,
+ save_total_limit=3,
  remove_unused_columns=False,
  push_to_hub=False,
  disable_tqdm=False,
  warmup_ratio=0.03,
  lr_scheduler_type="cosine",
  report_to="tensorboard",
+ bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,
  gradient_checkpointing=True,
  gradient_checkpointing_kwargs={'use_reentrant': False},
+ ddp_find_unused_parameters=False,
+ deepspeed=ds_config_path if n_gpus > 1 else None,
  )
 
  # --- Initialize Trainer ---
+ progress(0.8, desc="Initializing trainer...")
+ trainer = Trainer(
  model=model_to_train,
  args=training_args,
  train_dataset=train_dataset,
  data_collator=data_collator,
  )
 
+ log.append("Trainer initialized for training.")
 
  # --- Start Training ---
+ # Clear cache before starting
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
  try:
+ progress(0.85, desc="Starting training...")
+ log.append("Starting training...")
  train_result = trainer.train()
  progress(0.95, desc="Saving model...")
 
  # Save final model (adapter weights) and training state
 
 
  for key, value in metrics.items():
  log.append(f"{key}: {value}")
 
  except Exception as e:
+ error_msg = f"An error occurred during training: {e}"
  log.append(error_msg)
  return "\n".join(log)
 
  progress(1.0, desc="Training complete!")
+ log.append("Training process complete.")
  return "\n".join(log)
 
  # Define the Gradio interface
  def create_interface():
+ with gr.Blocks(title="Llama 3.2 1B RVQ Fine-tuning") as demo:
+ gr.Markdown("# Llama 3.2 1B RVQ LoRA Fine-tuning")
+ gr.Markdown("Fine-tune a Llama 3.2 1B model with RVQ token embeddings using LoRA")
 
  with gr.Row():
  with gr.Column():
  hf_username = gr.Textbox(label="HuggingFace Username", value="Twelve2five")
+ model_repo = gr.Textbox(label="Model Repository Name", value="llama-3.2-1b-rvq")
  dataset_repo = gr.Textbox(label="Dataset Repository Name", value="podcast-dialogue-rvq-pairs-3items")
 
  with gr.Column():
+ epochs = gr.Number(label="Number of Epochs", value=2, minimum=1, maximum=10)
+ batch_size = gr.Number(label="Batch Size per Device", value=2, minimum=1, maximum=8)
+ grad_accum = gr.Number(label="Gradient Accumulation Steps", value=4, minimum=1, maximum=16)
+ lr = gr.Number(label="Learning Rate", value=2e-4)
 
  start_btn = gr.Button("Start Training")
  output = gr.Textbox(label="Training Log", lines=20)
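
The commit also replaces the hand-written padding collator with DataCollatorForLanguageModeling(mlm=False). A minimal, hypothetical check of how that collator pads ragged sequences, using a stand-in tokenizer rather than the tokenizer the app downloads from the Hub:

# Hypothetical sanity check, not part of the commit: confirm that
# DataCollatorForLanguageModeling(mlm=False) pads ragged input_ids and
# marks the padded positions with -100 in the labels it derives.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer for illustration only
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # mirrors the pad-token fallback added in this commit

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
features = [
    {"input_ids": [11, 12, 13]},
    {"input_ids": [21, 22, 23, 24, 25]},
]
batch = collator(features)
print(batch["input_ids"].shape)             # torch.Size([2, 5]) after padding to the longest sequence
print(batch["labels"][0].tolist())          # [11, 12, 13, -100, -100]; padded positions are ignored by the loss
print(batch["attention_mask"][0].tolist())  # [1, 1, 1, 0, 0]

Since this collator derives labels by cloning input_ids, while the dataset built above already carries a separate labels column taken from pair["target"], it is worth verifying in practice which of the two actually drives the loss.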