Twelve2five committed on
Commit fd09ea6 · verified · 1 Parent(s): f38c379

Update app.py

Files changed (1)
  1. app.py +261 -318
app.py CHANGED
@@ -330,49 +330,49 @@ def train_model(
330
  dataset_repo_name,
331
  epochs=1,
332
  batch_size=1,
333
- grad_accum_steps=4,
334
  learning_rate=1e-4,
335
  progress=gr.Progress()
336
  ):
337
  progress(0, desc="Setting up environment...")
338
  log = []
339
 
340
- # Completely clean up transformers installation
341
- log.append("Completely reinstalling transformers and dependencies...")
342
-
343
- # First uninstall any existing transformers
344
- subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "transformers"])
345
-
346
- # Clean any cached files that might be causing issues
347
- cache_dirs = [
348
- os.path.expanduser("~/.cache/huggingface"),
349
- os.path.expanduser("~/.cache/pip")
350
- ]
351
-
352
- for cache_dir in cache_dirs:
353
- if os.path.exists(cache_dir):
354
- log.append(f"Cleaning cache directory: {cache_dir}")
355
- try:
356
- shutil.rmtree(cache_dir)
357
- except Exception as e:
358
- log.append(f"Warning: Could not clean {cache_dir}: {e}")
359
 
360
- # Install a stable version of transformers known to work with Llama models
361
- subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.35.2", "sentencepiece"])
362
 
363
- # Install other dependencies
364
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
365
- "accelerate", "bitsandbytes==0.41.1", "peft==0.6.1",
366
- "datasets", "huggingface_hub", "deepspeed==0.12.3"])
367
 
368
- # Now import everything after installation to ensure we use the correct versions
369
  try:
370
  from datasets import Dataset
371
  from huggingface_hub import snapshot_download
372
  import torch
373
  import transformers
374
  from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
375
- from transformers import BitsAndBytesConfig, TrainingArguments, Trainer
376
  from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
377
 
378
  log.append(f"Transformers version: {transformers.__version__}")
@@ -393,59 +393,60 @@ def train_model(
393
  n_gpus = torch.cuda.device_count()
394
  log.append(f"Number of GPUs available: {n_gpus}")
395
 
396
- # --- Quantization Configuration ---
397
- bnb_config = BitsAndBytesConfig(
398
- load_in_4bit=True,
399
- bnb_4bit_quant_type="nf4",
400
- bnb_4bit_compute_dtype=torch.bfloat16,
401
- bnb_4bit_use_double_quant=True,
402
- )
403
-
404
- # --- Load Base Model (with quantization) ---
405
  progress(0.1, desc="Loading base model...")
 
406
  try:
407
- # First try to download the repo without loading the model
408
- local_model_path = "./model_files"
409
- if os.path.exists(local_model_path):
410
- shutil.rmtree(local_model_path) # Clean up any previous files
411
-
412
  snapshot_download(
413
  repo_id=hf_model_repo_id,
414
  local_dir=local_model_path,
415
  local_dir_use_symlinks=False
416
  )
417
-
418
  log.append(f"Model files downloaded to {local_model_path}")
419
 
420
- # Check if this is a Llama model by looking at config.json
421
- if os.path.exists(os.path.join(local_model_path, "config.json")):
422
- with open(os.path.join(local_model_path, "config.json"), "r") as f:
423
- config_data = json.load(f)
424
- log.append(f"Model architecture type: {config_data.get('model_type', 'unknown')}")
425
-
426
- # Force model_type to llama
427
- config_data["model_type"] = "llama"
428
- if "architectures" in config_data:
429
- config_data["architectures"] = ["LlamaForCausalLM"]
430
-
431
- with open(os.path.join(local_model_path, "config.json"), "w") as f:
432
- json.dump(config_data, f)
433
- log.append("Updated config.json to use llama model_type")
434
 
435
- # Now try to load with explicit Llama classes
436
- config = LlamaConfig.from_pretrained(
437
- local_model_path,
438
- trust_remote_code=False
439
- )
440
 
441
  log.append(f"Successfully loaded config: {config.model_type}")
442
 
443
- # Load model with specific Llama class
444
  model = LlamaForCausalLM.from_pretrained(
445
  local_model_path,
446
  config=config,
447
  quantization_config=bnb_config,
448
  device_map="auto",
 
449
  torch_dtype=torch.bfloat16,
450
  low_cpu_mem_usage=True
451
  )
@@ -455,20 +456,7 @@ def train_model(
455
  except Exception as e:
456
  error_msg = f"Error loading model: {str(e)}"
457
  log.append(error_msg)
458
-
459
- # Try a fallback approach
460
- try:
461
- log.append("Trying fallback approach with AutoModelForCausalLM...")
462
- model = AutoModelForCausalLM.from_pretrained(
463
- local_model_path,
464
- device_map="auto",
465
- torch_dtype=torch.bfloat16,
466
- low_cpu_mem_usage=True
467
- )
468
- log.append(f"Fallback model loaded successfully")
469
- except Exception as e2:
470
- log.append(f"Fallback approach also failed: {str(e2)}")
471
- return "\n".join(log)
472
 
473
  # --- Prepare for K-bit Training & Apply LoRA ---
474
  progress(0.15, desc="Preparing model for fine-tuning...")
@@ -476,218 +464,138 @@ def train_model(
476
  model = prepare_model_for_kbit_training(model)
477
  log.append("Model prepared for k-bit training")
478
 
 
479
  lora_config = LoraConfig(
480
  task_type=TaskType.CAUSAL_LM,
481
- r=16,
482
- lora_alpha=32,
483
  lora_dropout=0.05,
484
  bias="none",
485
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
 
486
  )
 
 
487
  peft_model = get_peft_model(model, lora_config)
488
- trainable_params = peft_model.print_trainable_parameters()
489
- log.append(f"LoRA applied to model")
490
  model_to_train = peft_model
491
  except Exception as e:
492
  error_msg = f"Error preparing model for training: {str(e)}"
493
  log.append(error_msg)
494
  return "\n".join(log)
495
 
496
- # Cleanup
497
- gc.collect()
498
- if torch.cuda.is_available():
499
- torch.cuda.empty_cache()
500
-
501
- # --- Load Dataset from Hub ---
502
- progress(0.2, desc="Downloading dataset...")
503
- local_download_path = "./downloaded_dataset_files"
504
-
505
  try:
506
- downloaded_repo_root = snapshot_download(
 
 
507
  repo_id=hf_dataset_repo_id,
508
- repo_type="dataset",
509
- local_dir=local_download_path,
510
  local_dir_use_symlinks=False
511
  )
512
- log.append(f"Dataset repository content downloaded to: {downloaded_repo_root}")
513
  except Exception as e:
514
- error_msg = f"Error downloading dataset repository from Hub: {e}"
515
  log.append(error_msg)
516
  return "\n".join(log)
517
 
518
- # --- Find and load the .pt files ---
519
- progress(0.25, desc="Finding dataset files...")
520
- pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs")
521
- all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt"))
522
-
523
- if not all_pair_files:
524
- all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt"))
525
- if not all_pair_files:
526
- error_msg = "No RVQ pair files found in expected directories"
527
- log.append(error_msg)
528
- return "\n".join(log)
529
-
530
- log.append(f"Found {len(all_pair_files)} RVQ pair files.")
531
-
532
- # --- Load data from .pt files ---
533
- progress(0.3, desc="Loading dataset files...")
534
- all_data_pairs = []
535
- for i, file_path in enumerate(all_pair_files):
536
- progress(0.3 + (0.1 * i / len(all_pair_files)), desc=f"Loading file {i+1}/{len(all_pair_files)}")
537
- try:
538
- episode_pairs = torch.load(file_path, map_location='cpu')
539
- all_data_pairs.extend(episode_pairs)
540
- except Exception as e:
541
- log.append(f"Warning: Could not load file {file_path}: {e}")
542
-
543
- if not all_data_pairs:
544
- error_msg = "No valid data pairs were loaded"
545
- log.append(error_msg)
546
- return "\n".join(log)
547
-
548
- log.append(f"Loaded a total of {len(all_data_pairs)} training pairs into memory.")
549
-
550
- # --- Convert to HF Dataset ---
551
- progress(0.45, desc="Converting to Hugging Face Dataset...")
552
- def prepare_for_dataset(batch):
553
- output = {'input_ids': [], 'labels': []}
554
- for item in batch:
555
- output['input_ids'].append(item['input_ids'].cpu().tolist())
556
- output['labels'].append(item['labels'].cpu().tolist())
557
- return output
558
-
559
- chunk_size = 1000
560
- processed_data = {'input_ids': [], 'labels': []}
561
-
562
- total_chunks = len(range(0, len(all_data_pairs), chunk_size))
563
- for i in range(0, len(all_data_pairs), chunk_size):
564
- chunk_idx = i // chunk_size
565
- progress(0.45 + (0.1 * chunk_idx / total_chunks),
566
- desc=f"Processing chunk {chunk_idx+1}/{total_chunks}")
567
- batch = all_data_pairs[i:i + chunk_size]
568
- prepared_batch = prepare_for_dataset(batch)
569
- processed_data['input_ids'].extend(prepared_batch['input_ids'])
570
- processed_data['labels'].extend(prepared_batch['labels'])
571
-
572
- hf_dataset = Dataset.from_dict(processed_data)
573
-
574
- # Transform to get tensors back
575
- hf_dataset.set_transform(lambda batch: {
576
- 'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']],
577
- 'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']]
578
- })
579
-
580
- train_dataset = hf_dataset
581
-
582
- # Cleanup
583
- del all_data_pairs
584
- del processed_data
585
- gc.collect()
586
-
587
- # --- Define Data Collator ---
588
- progress(0.55, desc="Defining data collator...")
589
- def seq2seq_causal_collator(features):
590
- batch = {}
591
- concatenated_input_ids = []
592
- concatenated_labels = []
593
- max_len = 0
594
-
595
- # First pass: Concatenate, create masked labels, find max length
596
- for feature in features:
597
- input_ids = feature['input_ids']
598
- labels = feature['labels']
599
-
600
- if input_ids.dim() > 1: input_ids = input_ids.squeeze()
601
- if labels.dim() > 1: labels = labels.squeeze()
602
-
603
- context_len = input_ids.shape[0]
604
- target_len = labels.shape[0]
605
-
606
- combined_ids = torch.cat([input_ids, labels], dim=0)
607
- concatenated_input_ids.append(combined_ids)
608
-
609
- masked_labels = torch.cat([
610
- torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device),
611
- labels
612
- ], dim=0)
613
- concatenated_labels.append(masked_labels)
614
-
615
- if combined_ids.shape[0] > max_len:
616
- max_len = combined_ids.shape[0]
617
-
618
- # Second pass: Pad to max length
619
- padded_input_ids = []
620
- padded_labels = []
621
- input_pad_token_id = 0
622
- label_pad_token_id = -100
623
-
624
- for i in range(len(features)):
625
- ids = concatenated_input_ids[i]
626
- lbls = concatenated_labels[i]
627
-
628
- padding_len = max_len - ids.shape[0]
629
-
630
- padded_input_ids.append(torch.nn.functional.pad(
631
- ids, (0, padding_len), value=input_pad_token_id
632
- ))
633
- padded_labels.append(torch.nn.functional.pad(
634
- lbls, (0, padding_len), value=label_pad_token_id
635
- ))
636
-
637
- # Stack and create final batch
638
- batch['input_ids'] = torch.stack(padded_input_ids)
639
- batch['labels'] = torch.stack(padded_labels)
640
- batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long()
641
-
642
- return batch
643
-
644
- data_collator = seq2seq_causal_collator
645
-
646
- # --- DeepSpeed Configuration ---
647
- # Create DeepSpeed config file directly in Python instead of loading from a file
648
- progress(0.15, desc="Setting up DeepSpeed configuration...")
649
-
650
- ds_config = {
651
- "fp16": {
652
- "enabled": False
653
- },
654
- "bf16": {
655
- "enabled": True
656
- },
657
- "zero_optimization": {
658
- "stage": 3,
659
- "offload_optimizer": {
660
- "device": "cpu",
661
- "pin_memory": True
662
- },
663
- "offload_param": {
664
- "device": "cpu",
665
- "pin_memory": True
666
- },
667
- "overlap_comm": True,
668
- "contiguous_gradients": True,
669
- "reduce_bucket_size": "auto",
670
- "stage3_prefetch_bucket_size": "auto",
671
- "stage3_param_persistence_threshold": "auto"
672
- },
673
- "gradient_accumulation_steps": grad_accum_steps,
674
- "train_micro_batch_size_per_gpu": batch_size,
675
- "gradient_clipping": 1.0,
676
- "steps_per_print": 10
677
- }
678
-
679
- # Save the config to a file
680
- with open("ds_config.json", "w") as f:
681
- json.dump(ds_config, f, indent=4)
682
-
683
- log.append("DeepSpeed configuration created successfully")
684
-
685
  # --- Training Arguments ---
686
- progress(0.75, desc="Setting up training arguments...")
687
  output_dir = f"./results_{model_repo_name}"
688
  os.makedirs(output_dir, exist_ok=True)
689
 
690
- # Create training arguments without DeepSpeed first
691
  training_args = TrainingArguments(
692
  output_dir=output_dir,
693
  num_train_epochs=float(epochs),
@@ -696,71 +604,90 @@ def train_model(
696
  learning_rate=learning_rate,
697
  weight_decay=0.01,
698
  logging_dir=f"{output_dir}/logs",
699
- logging_steps=10,
700
- save_steps=100,
701
- save_total_limit=3,
702
  remove_unused_columns=False,
703
  push_to_hub=False,
704
  disable_tqdm=False,
705
  warmup_ratio=0.03,
706
  lr_scheduler_type="cosine",
707
  report_to="tensorboard",
708
- bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,
709
  gradient_checkpointing=True,
710
  gradient_checkpointing_kwargs={'use_reentrant': False},
711
 
712
- # For multi-GPU - use a different approach for DeepSpeed
713
- ddp_find_unused_parameters=False,
714
  )
715
 
716
- # Now initialize DeepSpeed separately
717
- if n_gpus > 1:
718
- log.append("Setting up DeepSpeed for multi-GPU training")
719
- try:
720
- import deepspeed
721
- from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
722
-
723
- # Modify the trainer to use DeepSpeed
724
- trainer_kwargs = {
725
- "model": model_to_train,
726
- "args": training_args,
727
- "train_dataset": train_dataset,
728
- "data_collator": data_collator,
729
- "deepspeed": ds_config, # Pass the config as a dict
730
- }
 
731
 
732
- trainer = Trainer(**trainer_kwargs)
733
- log.append("Trainer initialized with DeepSpeed for multi-GPU training")
734
- except Exception as e:
735
- log.append(f"Warning: Could not initialize DeepSpeed: {e}")
736
- # Fallback to standard distributed training
737
- trainer = Trainer(
738
- model=model_to_train,
739
- args=training_args,
740
- train_dataset=train_dataset,
741
- data_collator=data_collator,
742
- )
743
- log.append("Falling back to standard distributed training")
744
- else:
745
- # Single GPU setup
746
- trainer = Trainer(
747
- model=model_to_train,
748
- args=training_args,
749
- train_dataset=train_dataset,
750
- data_collator=data_collator,
751
- )
752
- log.append("Trainer initialized for single GPU training")
753
 
754
- # --- Start Training ---
755
- # Clear cache before starting
756
- gc.collect()
757
- if torch.cuda.is_available():
758
- torch.cuda.empty_cache()
759
 
 
760
  try:
761
- progress(0.85, desc="Starting training...")
762
- log.append("Starting distributed training on multiple GPUs...")
763
  train_result = trainer.train()
 
764
  progress(0.95, desc="Saving model...")
765
 
766
  # Save final model (adapter weights) and training state
@@ -776,21 +703,37 @@ def train_model(
776
 
777
  for key, value in metrics.items():
778
  log.append(f"{key}: {value}")
779
 
780
  except Exception as e:
781
- error_msg = f"An error occurred during training: {e}"
782
  log.append(error_msg)
783
  return "\n".join(log)
784
 
785
  progress(1.0, desc="Training complete!")
786
- log.append("Multi-GPU training process complete.")
787
  return "\n".join(log)
788
 
789
  # Define the Gradio interface
790
  def create_interface():
791
  with gr.Blocks(title="Llama 3 8B RVQ Fine-tuning") as demo:
792
  gr.Markdown("# Llama 3 8B RVQ LoRA Fine-tuning")
793
- gr.Markdown("Fine-tune a Llama 3 8B model with RVQ token embeddings using LoRA on multiple GPUs")
794
 
795
  with gr.Row():
796
  with gr.Column():
@@ -801,7 +744,7 @@ def create_interface():
801
  with gr.Column():
802
  epochs = gr.Number(label="Number of Epochs", value=1, minimum=1, maximum=10)
803
  batch_size = gr.Number(label="Batch Size per Device", value=1, minimum=1, maximum=8)
804
- grad_accum = gr.Number(label="Gradient Accumulation Steps", value=4, minimum=1, maximum=16)
805
  lr = gr.Number(label="Learning Rate", value=1e-4)
806
 
807
  start_btn = gr.Button("Start Training")
 
330
  dataset_repo_name,
331
  epochs=1,
332
  batch_size=1,
333
+ grad_accum_steps=16, # Increased from 8 to 16
334
  learning_rate=1e-4,
335
  progress=gr.Progress()
336
  ):
337
  progress(0, desc="Setting up environment...")
338
  log = []
339
 
340
+ # Aggressive memory cleanup
341
+ gc.collect()
342
+ if torch.cuda.is_available():
343
+ torch.cuda.empty_cache()
344
+ # Reset peak memory stats
345
+ torch.cuda.reset_peak_memory_stats()
346
 
347
+ # Clean up any existing model files to save space
348
+ if os.path.exists("./model_files"):
349
+ try:
350
+ shutil.rmtree("./model_files")
351
+ except Exception as e:
352
+ log.append(f"Warning: Could not remove existing model files: {e}")
353
 
354
+ if os.path.exists("./downloaded_dataset_files"):
355
+ try:
356
+ shutil.rmtree("./downloaded_dataset_files")
357
+ except Exception as e:
358
+ log.append(f"Warning: Could not remove existing dataset files: {e}")
359
+
360
+ # Print GPU info
361
+ if torch.cuda.is_available():
362
+ log.append(f"Available GPUs: {torch.cuda.device_count()}")
363
+ for i in range(torch.cuda.device_count()):
364
+ gpu_name = torch.cuda.get_device_name(i)
365
+ gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3)
366
+ log.append(f"GPU {i}: {gpu_name} with {gpu_memory:.2f} GB")
367
 
368
+ # Import required libraries
369
  try:
370
  from datasets import Dataset
371
  from huggingface_hub import snapshot_download
372
  import torch
373
  import transformers
374
  from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
375
+ from transformers import BitsAndBytesConfig, TrainingArguments, Trainer, AutoTokenizer
376
  from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
377
 
378
  log.append(f"Transformers version: {transformers.__version__}")
 
393
  n_gpus = torch.cuda.device_count()
394
  log.append(f"Number of GPUs available: {n_gpus}")
395
 
396
+ # --- Load Base Model (with extreme quantization) ---
397
  progress(0.1, desc="Loading base model...")
398
+ local_model_path = "./model_files"
399
  try:
400
+ # Download the model files
401
  snapshot_download(
402
  repo_id=hf_model_repo_id,
403
  local_dir=local_model_path,
404
  local_dir_use_symlinks=False
405
  )
 
406
  log.append(f"Model files downloaded to {local_model_path}")
407
 
408
+ # Ensure model_type is set correctly in the config
409
+ config_path = os.path.join(local_model_path, "config.json")
410
+ with open(config_path, "r") as f:
411
+ config_data = json.load(f)
412
 
413
+ model_type = config_data.get("model_type", "")
414
+ log.append(f"Model architecture type: {model_type}")
 
 
 
415
 
416
+ # Force model_type to be "llama" if it's not already
417
+ if model_type != "llama":
418
+ config_data["model_type"] = "llama"
419
+ # Also ensure architectures is set correctly
420
+ config_data["architectures"] = ["LlamaForCausalLM"]
421
+ with open(config_path, "w") as f:
422
+ json.dump(config_data, f, indent=2)
423
+ log.append("Updated config.json to use llama model_type")
424
+
425
+ # Load the config first
426
+ config = LlamaConfig.from_pretrained(local_model_path)
427
  log.append(f"Successfully loaded config: {config.model_type}")
428
 
429
+ # Use 4-bit quantization for extreme memory savings
430
+ bnb_config = BitsAndBytesConfig(
431
+ load_in_4bit=True,
432
+ bnb_4bit_use_double_quant=True,
433
+ bnb_4bit_quant_type="nf4",
434
+ bnb_4bit_compute_dtype=torch.bfloat16
435
+ )
436
+
437
+ # Load tokenizer first (needed for dataset preparation)
438
+ tokenizer = AutoTokenizer.from_pretrained(local_model_path)
439
+
440
+ # Explicit device map to enable CPU offloading
441
+ max_memory = {0: "40GB", "cpu": "64GB"}
442
+
443
+ # Load the model with extreme memory optimization
444
  model = LlamaForCausalLM.from_pretrained(
445
  local_model_path,
446
  config=config,
447
  quantization_config=bnb_config,
448
  device_map="auto",
449
+ max_memory=max_memory,
450
  torch_dtype=torch.bfloat16,
451
  low_cpu_mem_usage=True
452
  )
 
456
  except Exception as e:
457
  error_msg = f"Error loading model: {str(e)}"
458
  log.append(error_msg)
459
+ return "\n".join(log)
460
 
461
  # --- Prepare for K-bit Training & Apply LoRA ---
462
  progress(0.15, desc="Preparing model for fine-tuning...")
 
464
  model = prepare_model_for_kbit_training(model)
465
  log.append("Model prepared for k-bit training")
466
 
467
+ # Use minimal LoRA configuration with fewer parameters
468
  lora_config = LoraConfig(
469
  task_type=TaskType.CAUSAL_LM,
470
+ r=8, # Reduced from 16 to 8
471
+ lora_alpha=16, # Reduced from 32 to 16
472
  lora_dropout=0.05,
473
  bias="none",
474
+ # Target only key modules to reduce memory usage
475
+ target_modules=["q_proj", "v_proj"] # Reduced target modules
476
  )
477
+
478
+ # Apply LoRA
479
  peft_model = get_peft_model(model, lora_config)
 
 
480
  model_to_train = peft_model
481
+ log.append("LoRA applied to model")
482
+
483
+ # Free memory
484
+ gc.collect()
485
+ if torch.cuda.is_available():
486
+ torch.cuda.empty_cache()
487
  except Exception as e:
488
  error_msg = f"Error preparing model for training: {str(e)}"
489
  log.append(error_msg)
490
  return "\n".join(log)
491
 
492
+ # --- Download and Process Dataset ---
493
+ progress(0.2, desc="Loading dataset...")
494
  try:
495
+ # Download the dataset files
496
+ dataset_dir = os.path.join(os.getcwd(), "downloaded_dataset_files")
497
+ snapshot_download(
498
  repo_id=hf_dataset_repo_id,
499
+ local_dir=dataset_dir,
 
500
  local_dir_use_symlinks=False
501
  )
502
+ log.append(f"Dataset repository content downloaded to: {dataset_dir}")
503
+
504
+ # Find all RVQ pair files
505
+ rvq_pair_files = glob.glob(os.path.join(dataset_dir, "*_rvq_pairs.pt"))
506
+ log.append(f"Found {len(rvq_pair_files)} RVQ pair files.")
507
+
508
+ # Load training pairs from the dataset
509
+ training_pairs = []
510
+
511
+ # For memory conservation, use only half the dataset for now
512
+ max_file_count = min(12, len(rvq_pair_files))
513
+
514
+ for i, pair_file in enumerate(rvq_pair_files[:max_file_count]):
515
+ try:
516
+ pairs = torch.load(pair_file)
517
+ training_pairs.extend(pairs)
518
+ except Exception as e:
519
+ log.append(f"Warning: Could not load {pair_file}: {e}")
520
+
521
+ log.append(f"Loaded a total of {len(training_pairs)} training pairs into memory.")
522
+
523
+ # Prepare dataset
524
+ dataset = Dataset.from_dict({
525
+ "input_ids": [pair[0].tolist() for pair in training_pairs],
526
+ "labels": [pair[1].tolist() for pair in training_pairs]
527
+ })
528
+
529
+ # Clear the training_pairs to free memory
530
+ training_pairs = None
531
+ gc.collect()
532
+ torch.cuda.empty_cache()
533
+
534
+ # Use a smaller max_length to reduce memory pressure
535
+ max_length = 512 # Reduced max sequence length
536
+
537
+ # Create data collator that handles padding
538
+ def data_collator(examples):
539
+ # Convert lists back to tensors
540
+ for i in range(len(examples)):
541
+ examples[i]["input_ids"] = torch.tensor(examples[i]["input_ids"], dtype=torch.long)
542
+ examples[i]["labels"] = torch.tensor(examples[i]["labels"], dtype=torch.long)
543
+
544
+ # Get max length in this batch
545
+ batch_max_length = min(
546
+ max(len(example["input_ids"]) for example in examples),
547
+ max_length
548
+ )
549
+
550
+ batch = {
551
+ "input_ids": [],
552
+ "attention_mask": [],
553
+ "labels": []
554
+ }
555
+
556
+ # Prepare sequences
557
+ for example in examples:
558
+ input_ids = example["input_ids"][:batch_max_length]
559
+ labels = example["labels"][:batch_max_length]
560
+
561
+ # Pad sequences
562
+ padding_length = batch_max_length - len(input_ids)
563
+ attention_mask = torch.ones_like(input_ids)
564
+
565
+ if padding_length > 0:
566
+ padding = torch.ones(padding_length, dtype=input_ids.dtype) * tokenizer.pad_token_id
567
+ input_ids = torch.cat([input_ids, padding])
568
+ labels = torch.cat([labels, padding * -100]) # -100 to ignore in loss computation
569
+ attention_mask = torch.cat([attention_mask, torch.zeros(padding_length, dtype=attention_mask.dtype)])
570
+
571
+ batch["input_ids"].append(input_ids)
572
+ batch["attention_mask"].append(attention_mask)
573
+ batch["labels"].append(labels)
574
+
575
+ # Convert lists to tensors
576
+ for key in batch:
577
+ batch[key] = torch.stack(batch[key])
578
+
579
+ return batch
580
+
581
+ # Convert to training dataset
582
+ train_dataset = dataset
583
+
584
+ # Free memory
585
+ del dataset
586
+ gc.collect()
587
+ torch.cuda.empty_cache()
588
  except Exception as e:
589
+ error_msg = f"Error loading dataset: {str(e)}"
590
  log.append(error_msg)
591
  return "\n".join(log)
592
 
 
593
  # --- Training Arguments ---
594
+ progress(0.3, desc="Setting up training arguments...")
595
  output_dir = f"./results_{model_repo_name}"
596
  os.makedirs(output_dir, exist_ok=True)
597
 
598
+ # Super-aggressive memory conservation
599
  training_args = TrainingArguments(
600
  output_dir=output_dir,
601
  num_train_epochs=float(epochs),
 
604
  learning_rate=learning_rate,
605
  weight_decay=0.01,
606
  logging_dir=f"{output_dir}/logs",
607
+ logging_steps=1, # Log frequently to see progress
608
+ save_steps=25, # Save checkpoints more frequently
609
+ save_total_limit=1, # Keep only one checkpoint to save space
610
  remove_unused_columns=False,
611
  push_to_hub=False,
612
  disable_tqdm=False,
613
  warmup_ratio=0.03,
614
  lr_scheduler_type="cosine",
615
  report_to="tensorboard",
616
+ bf16=True,
617
+ fp16=False,
618
+
619
+ # Memory optimization
620
  gradient_checkpointing=True,
621
  gradient_checkpointing_kwargs={'use_reentrant': False},
622
+ max_grad_norm=0.3, # Reduced from default 1.0
623
+ dataloader_pin_memory=False, # Reduce memory pressure
624
+
625
+ # Optimizer settings for memory efficiency
626
+ optim="adamw_torch",
627
+ adam_beta1=0.9,
628
+ adam_beta2=0.999,
629
+ adam_epsilon=1e-8,
630
 
631
+ # Evaluation settings
632
+ do_eval=False,
633
+ evaluation_strategy="no",
634
+
635
+ # Set this for smaller chunks of data processing
636
+ dataloader_num_workers=1,
637
+
638
+ # For memory efficiency when loading datasets
639
+ dataloader_drop_last=True,
640
  )
641
 
642
+ # --- Initialize Trainer ---
643
+ progress(0.4, desc="Initializing trainer...")
644
+
645
+ # Use optimizer that requires less memory
646
+ class MemoryEfficientTrainer(Trainer):
647
+ def create_optimizer(self):
648
+ # Create optimizer with reduced memory footprint
649
+ optimizer = super().create_optimizer()
650
+ # Force optimizer to use CPU offloading for states
651
+ for param_group in optimizer.param_groups:
652
+ for param in param_group['params']:
653
+ if param.requires_grad:
654
+ param.data = param.data.to("cpu")
655
+ if param.grad is not None:
656
+ param.grad.data = param.grad.data.to("cpu")
657
+ return optimizer
658
 
659
+ def training_step(self, *args, **kwargs):
660
+ # Memory cleanup before each training step
661
+ gc.collect()
662
+ torch.cuda.empty_cache()
663
+ return super().training_step(*args, **kwargs)
664
+
665
+ trainer = MemoryEfficientTrainer(
666
+ model=model_to_train,
667
+ args=training_args,
668
+ train_dataset=train_dataset,
669
+ data_collator=data_collator,
670
+ )
671
 
672
+ log.append("Trainer initialized with memory-efficient settings")
673
 
674
+ # --- Start Training ---
675
  try:
676
+ # Final memory cleanup before training
677
+ gc.collect()
678
+ if torch.cuda.is_available():
679
+ torch.cuda.empty_cache()
680
+
681
+ progress(0.5, desc="Starting training...")
682
+ log.append("Starting training with extreme memory optimization...")
683
+
684
+ # Train in smaller chunks to manage memory better
685
+ total_steps = len(train_dataset) // (batch_size * grad_accum_steps)
686
+ log.append(f"Total training steps: {total_steps}")
687
+
688
+ # Train the model
689
  train_result = trainer.train()
690
+
691
  progress(0.95, desc="Saving model...")
692
 
693
  # Save final model (adapter weights) and training state
 
703
 
704
  for key, value in metrics.items():
705
  log.append(f"{key}: {value}")
706
+
707
+ # Print peak memory usage
708
+ if torch.cuda.is_available():
709
+ peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
710
+ log.append(f"Peak GPU memory usage: {peak_memory:.2f} GB")
711
 
712
  except Exception as e:
713
+ error_msg = f"An error occurred during training: {str(e)}"
714
  log.append(error_msg)
715
+
716
+ # Try to save checkpoint even if training failed
717
+ try:
718
+ # Save whatever we have
719
+ log.append("Attempting to save partial checkpoint...")
720
+ emergency_save_path = os.path.join(training_args.output_dir, "emergency_checkpoint")
721
+ trainer.save_model(emergency_save_path)
722
+ log.append(f"Saved emergency checkpoint to {emergency_save_path}")
723
+ except Exception as save_error:
724
+ log.append(f"Could not save emergency checkpoint: {save_error}")
725
+
726
  return "\n".join(log)
727
 
728
  progress(1.0, desc="Training complete!")
729
+ log.append("Training process complete successfully.")
730
  return "\n".join(log)
731
 
732
  # Define the Gradio interface
733
  def create_interface():
734
  with gr.Blocks(title="Llama 3 8B RVQ Fine-tuning") as demo:
735
  gr.Markdown("# Llama 3 8B RVQ LoRA Fine-tuning")
736
+ gr.Markdown("Fine-tune a Llama 3 8B model with RVQ token embeddings using LoRA with extreme memory optimization")
737
 
738
  with gr.Row():
739
  with gr.Column():
 
744
  with gr.Column():
745
  epochs = gr.Number(label="Number of Epochs", value=1, minimum=1, maximum=10)
746
  batch_size = gr.Number(label="Batch Size per Device", value=1, minimum=1, maximum=8)
747
+ grad_accum = gr.Number(label="Gradient Accumulation Steps", value=16, minimum=8, maximum=32)
748
  lr = gr.Number(label="Learning Rate", value=1e-4)
749
 
750
  start_btn = gr.Button("Start Training")