Twelve2five committed (verified)
Commit fdebc65 · Parent: 1c688b1

Update app.py

Files changed (1): app.py (+94, -82)
app.py CHANGED
@@ -27,7 +27,7 @@ import shutil
 
 # --- Configuration ---
 YOUR_HF_USERNAME = "Twelve2five"
-MODEL_REPO_NAME = "llama-3.2-1b-rvq"
+MODEL_REPO_NAME = "llama-3-8b-rvq-resized"
 DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"
 
 hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
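For orientation, the two repository IDs the rest of app.py depends on now resolve as below; a minimal standalone sketch, in which the `hf_dataset_repo_id` line and the `repo_exists` check are assumptions added for illustration rather than lines from the diff:

```python
from huggingface_hub import HfApi

YOUR_HF_USERNAME = "Twelve2five"
MODEL_REPO_NAME = "llama-3-8b-rvq-resized"
DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"

# f-string composition used by app.py; the dataset variant is assumed to mirror it
hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"      # "Twelve2five/llama-3-8b-rvq-resized"
hf_dataset_repo_id = f"{YOUR_HF_USERNAME}/{DATASET_REPO_NAME}"  # "Twelve2five/podcast-dialogue-rvq-pairs-3items"

# Optional sanity check before starting a long training run (illustrative only)
api = HfApi()
print(api.repo_exists(hf_model_repo_id, repo_type="model"))
print(api.repo_exists(hf_dataset_repo_id, repo_type="dataset"))
```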
@@ -329,9 +329,9 @@ def train_model(
     model_repo_name,
     dataset_repo_name,
     epochs=1,
-    batch_size=2,  # Increased batch size since model is smaller
+    batch_size=4,  # Increased for A100
     grad_accum_steps=4,
-    learning_rate=2e-4,  # Slightly higher learning rate for smaller model
+    learning_rate=2e-4,
     progress=gr.Progress()
 ):
     progress(0, desc="Setting up environment...")
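The new defaults change the effective batch size seen by the optimizer, which is the per-device batch multiplied by the gradient-accumulation steps and the GPU count; a quick sketch of that arithmetic (single GPU assumed):

```python
import torch

# New defaults from train_model(); n_gpus is whatever the runtime exposes
per_device_batch_size = 4
grad_accum_steps = 4
n_gpus = max(torch.cuda.device_count(), 1)

effective_batch_size = per_device_batch_size * grad_accum_steps * n_gpus
print(f"Effective batch size: {effective_batch_size}")  # 16 on a single A100
```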
@@ -350,7 +350,7 @@ def train_model(
     except Exception as e:
         log.append(f"Warning: Could not remove existing dataset files: {e}")
 
-    # Print GPU info
+    # Print GPU info - using imported torch, not a local variable
     if torch.cuda.is_available():
         log.append(f"Available GPUs: {torch.cuda.device_count()}")
         for i in range(torch.cuda.device_count()):
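The loop body is not shown in this hunk; a minimal standalone sketch of the kind of per-GPU logging it presumably performs, using only standard torch.cuda calls:

```python
import torch

log = []
if torch.cuda.is_available():
    log.append(f"Available GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        log.append(f"GPU {i}: {props.name}, {props.total_memory / 1024**3:.1f} GB")
else:
    log.append("No CUDA device available.")
print("\n".join(log))
```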
@@ -362,7 +362,7 @@ def train_model(
     try:
         from datasets import Dataset
         from huggingface_hub import snapshot_download
-        import torch
+        # Don't import torch again, since it's already imported
         import transformers
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from transformers import BitsAndBytesConfig, TrainingArguments, Trainer
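Dropping the function-local `import torch` matters because Python treats a name imported anywhere in a function body as local to that whole function, so the earlier `torch.cuda.is_available()` call in train_model would raise UnboundLocalError; a minimal reproduction of the pitfall:

```python
import torch  # module-level import, as at the top of app.py


def broken():
    # UnboundLocalError: the import below makes `torch` local to the
    # whole function, so this reference runs before any assignment.
    print(torch.cuda.is_available())
    import torch  # noqa: F811  (shadows the module-level name)


def fixed():
    # With the redundant local import removed, the module-level torch is used.
    print(torch.cuda.is_available())


fixed()
# broken()  # uncomment to observe the UnboundLocalError
```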
@@ -390,25 +390,13 @@ def train_model(
         # Create DeepSpeed config file
         progress(0.1, desc="Setting up DeepSpeed configuration...")
 
-        # Create a conservative DeepSpeed config for the smaller model
+        # Create a simpler config since we have plenty of memory on A100
         ds_config = {
-            "fp16": {
-                "enabled": "auto",
-                "loss_scale": 0,
-                "loss_scale_window": 1000,
-                "initial_scale_power": 16,
-                "hysteresis": 2,
-                "min_loss_scale": 1
-            },
             "bf16": {
                 "enabled": "auto"
             },
             "zero_optimization": {
-                "stage": 2,  # Lower stage for smaller model
-                "offload_optimizer": {
-                    "device": "cpu",
-                    "pin_memory": True
-                },
+                "stage": 1,  # Lower stage is fine for A100-80GB
                 "contiguous_gradients": True,
                 "overlap_comm": True
             },
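The hunk only shows the top of ds_config; a minimal sketch of how a ZeRO stage-1 config like this is typically written to disk and handed to the Trainer. The file name and the trailing "auto" keys are assumptions for illustration; the later hunks only reference the variable ds_config_path:

```python
import json

ds_config = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 1,
        "contiguous_gradients": True,
        "overlap_comm": True,
    },
    # Remaining keys assumed; commonly left to "auto" so the HF Trainer fills them in
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

ds_config_path = "ds_config.json"  # assumed file name
with open(ds_config_path, "w") as f:
    json.dump(ds_config, f, indent=2)

# Later in train_model(): TrainingArguments(..., deepspeed=ds_config_path if n_gpus > 1 else None)
```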
@@ -432,113 +420,136 @@ def train_model(
         snapshot_download(
             repo_id=hf_model_repo_id,
             local_dir=local_model_path,
+            use_auth_token=False,
+            resume_download=True
         )
         log.append(f"Model files downloaded to {local_model_path}")
 
-        # First, load the config
-        config_path = os.path.join(local_model_path, "config.json")
-        with open(config_path, "r") as f:
-            config_data = json.load(f)
-
-        # Check model architecture
-        model_type = config_data.get("model_type", "").lower()
-        log.append(f"Model architecture type: {model_type}")
-
-        # Set 4-bit quantization config
+        # Create a bnb configuration for loading the model in 4-bit
+        # Not strictly necessary for A100 but keeps memory usage lower
+        progress(0.25, desc="Loading model...")
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
-            bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=False
         )
 
-        # Load the model with 4-bit quantization
+        # Load model and tokenizer
         model = AutoModelForCausalLM.from_pretrained(
             local_model_path,
             quantization_config=bnb_config,
             device_map="auto",
-            trust_remote_code=False,
-            use_cache=False  # No KV caching during training
+            torch_dtype=torch.bfloat16,
         )
-
-        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(local_model_path)
+
+        # Handle tokenizer settings
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-
-        # Log model info
-        if hasattr(model, "config"):
-            log.append(f"Loaded model vocab size: {model.config.vocab_size}")
-        if hasattr(model, "get_input_embeddings"):
-            embedding = model.get_input_embeddings()
-            if hasattr(embedding, "weight"):
-                log.append(f"Input embedding shape: {embedding.weight.shape}")
-
-    except Exception as e:
-        error_msg = f"Error loading model: {str(e)}"
-        log.append(error_msg)
-        return "\n".join(log)
-
-    # --- Prepare for K-bit Training & Apply LoRA ---
-    progress(0.15, desc="Preparing model for fine-tuning...")
-    try:
+
+        log.append(f"Loaded model vocab size: {tokenizer.vocab_size}")
+        log.append(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
+
+        # PEFT Configuration (Smaller LoRA for faster iteration)
         model = prepare_model_for_kbit_training(model)
         log.append("Model prepared for k-bit training")
 
-        # For Llama 3.2 1B the target modules might be slightly different
         lora_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
-            r=8,  # Reduced from 16 due to smaller model
-            lora_alpha=16,  # Reduced from 32
+            r=16,  # Keeping higher rank for A100
+            lora_alpha=32,
             lora_dropout=0.05,
             bias="none",
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]  # Fewer modules for faster training
         )
         peft_model = get_peft_model(model, lora_config)
         trainable_params = peft_model.print_trainable_parameters()
         log.append(f"LoRA applied to model")
         model_to_train = peft_model
+
     except Exception as e:
         error_msg = f"Error preparing model for training: {str(e)}"
         log.append(error_msg)
         return "\n".join(log)
 
-    # --- Load and Prepare Dataset ---
-    progress(0.3, desc="Loading and preparing dataset...")
+    # --- Download and Process Dataset ---
+    progress(0.4, desc="Downloading dataset...")
+
     try:
-        # Download dataset
         dataset_path = "./downloaded_dataset_files"
         snapshot_download(
             repo_id=hf_dataset_repo_id,
             local_dir=dataset_path,
+            use_auth_token=False,
+            resume_download=True
         )
         log.append(f"Dataset repository content downloaded to: {dataset_path}")
 
-        # Find all RVQ pairs files
-        rvq_pair_files = glob.glob(os.path.join(dataset_path, "*_rvq_pairs.pt"))
-        log.append(f"Found {len(rvq_pair_files)} RVQ pair files.")
+        # Load dataset from PT files
+        progress(0.5, desc="Processing dataset...")
 
-        # Load the pytorch files
-        all_pairs = []
+        # Load RVQ pairs
+        pair_files = glob.glob(f"{dataset_path}/*_rvq_pairs.pt")
+        log.append(f"Found {len(pair_files)} RVQ pair files.")
 
-        for file_path in rvq_pair_files:
-            pairs = torch.load(file_path)
+        all_pairs = []
+        for file in pair_files:
+            pairs = torch.load(file)
             all_pairs.extend(pairs)
 
-        # You may want to limit the data size for quicker testing
-        # all_pairs = all_pairs[:1000]  # Uncomment to limit data size
-
         log.append(f"Loaded a total of {len(all_pairs)} training pairs into memory.")
 
-        # Convert to HF dataset format
-        dataset_dict = {
-            "input_ids": [pair["source"] for pair in all_pairs],
-            "labels": [pair["target"] for pair in all_pairs]
-        }
+        # Process pairs into a format suitable for training
+        all_texts = []
+        for pair in all_pairs:
+            # Create instruction format
+            if isinstance(pair, dict):
+                instruction = pair.get("instruction", "")
+                input_text = pair.get("input", "")
+                output = pair.get("output", "")
+
+                # ALPACA format
+                if instruction and input_text:
+                    text = f"### Instruction: {instruction}\n### Input: {input_text}\n### Response: {output}"
+                elif instruction:
+                    text = f"### Instruction: {instruction}\n### Response: {output}"
+                else:
+                    text = output
+            else:
+                # Simple prompt-completion format
+                if isinstance(pair, tuple) and len(pair) == 2:
+                    prompt, completion = pair
+                    text = f"{prompt}{completion}"
+                else:
+                    text = str(pair)
+
+            all_texts.append({"text": text})
+
+        # Create HF dataset
+        train_dataset = Dataset.from_list(all_texts)
+
+        # Function to tokenize the dataset
+        def tokenize_function(examples):
+            return tokenizer(
+                examples["text"],
+                padding=False,
+                truncation=True,
+                max_length=2048,
+                return_tensors=None,
+            )
+
+        # Tokenize the dataset
+        tokenized_dataset = train_dataset.map(
+            tokenize_function,
+            batched=True,
+            remove_columns=["text"],
+            desc="Tokenizing dataset",
+        )
 
-        train_dataset = Dataset.from_dict(dataset_dict)
+        train_dataset = tokenized_dataset
 
-        # Create data collator for padding
+        # Data collator
         from transformers import DataCollatorForLanguageModeling
 
         data_collator = DataCollatorForLanguageModeling(
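Pulled out of the diff context, the new model-loading path boils down to 4-bit NF4 quantization with bf16 compute plus an attention-only LoRA; a self-contained sketch of that combination, in which the model path is a placeholder rather than the value app.py actually uses:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

model_path = "./downloaded_model_files"  # placeholder for app.py's local_model_path

# 4-bit NF4 with bf16 compute, mirroring the new bnb_config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# k-bit preparation, then LoRA on the attention projections only
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
```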
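The new dataset path converts each loaded RVQ pair into a single "text" field (Alpaca-style when the pair is a dict, plain prompt plus completion when it is a tuple) before tokenizing; a runnable sketch with toy pairs and a small placeholder tokenizer standing in for the real one:

```python
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer, not the repo's
tokenizer.pad_token = tokenizer.eos_token

# Toy stand-ins for the pairs loaded from the *_rvq_pairs.pt files
all_pairs = [
    {"instruction": "Continue the dialogue.", "input": "Host: Welcome back!", "output": "Guest: Thanks for having me."},
    ("Host: So, ", "tell me about the project."),
]

all_texts = []
for pair in all_pairs:
    if isinstance(pair, dict):
        instruction = pair.get("instruction", "")
        input_text = pair.get("input", "")
        output = pair.get("output", "")
        if instruction and input_text:
            text = f"### Instruction: {instruction}\n### Input: {input_text}\n### Response: {output}"
        elif instruction:
            text = f"### Instruction: {instruction}\n### Response: {output}"
        else:
            text = output
    elif isinstance(pair, tuple) and len(pair) == 2:
        prompt, completion = pair
        text = f"{prompt}{completion}"
    else:
        text = str(pair)
    all_texts.append({"text": text})

train_dataset = Dataset.from_list(all_texts)

tokenized_dataset = train_dataset.map(
    lambda examples: tokenizer(examples["text"], truncation=True, max_length=2048),
    batched=True,
    remove_columns=["text"],
)

# mlm=False gives causal-LM labels (labels mirror input_ids, padding masked out)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
print(tokenized_dataset[0]["input_ids"][:10])
```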
@@ -556,6 +567,7 @@ def train_model(
         output_dir = f"./results_{model_repo_name}"
         os.makedirs(output_dir, exist_ok=True)
 
+        # Optimize settings for A100
         training_args = TrainingArguments(
             output_dir=output_dir,
             num_train_epochs=float(epochs),
@@ -574,10 +586,10 @@ def train_model(
             lr_scheduler_type="cosine",
             report_to="tensorboard",
             bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,
-            gradient_checkpointing=True,
+            gradient_checkpointing=True,  # Still useful for efficiency
             gradient_checkpointing_kwargs={'use_reentrant': False},
             ddp_find_unused_parameters=False,
-            deepspeed=ds_config_path if n_gpus > 1 else None,
+            deepspeed=ds_config_path if n_gpus > 1 else None,  # Only use DeepSpeed for multi-GPU
         )
 
         # --- Initialize Trainer ---
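Taken together with the earlier hunk, the training arguments end up roughly as below; a sketch with the interface defaults filled in. The arguments between output_dir and lr_scheduler_type are not all visible in the diff, so only the visible ones are reproduced, and the DeepSpeed config path is assumed:

```python
import torch
from transformers import TrainingArguments

n_gpus = torch.cuda.device_count()
ds_config_path = "ds_config.json"  # assumed name; see the DeepSpeed sketch earlier

training_args = TrainingArguments(
    output_dir="./results_llama-3-8b-rvq-resized",
    num_train_epochs=3.0,
    per_device_train_batch_size=4,   # UI default
    gradient_accumulation_steps=2,   # UI default
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    ddp_find_unused_parameters=False,
    deepspeed=ds_config_path if n_gpus > 1 else None,  # DeepSpeed only for multi-GPU runs
)
print(training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)
```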
@@ -639,9 +651,9 @@ def create_interface():
             dataset_repo = gr.Textbox(label="Dataset Repository Name", value="podcast-dialogue-rvq-pairs-3items")
 
         with gr.Column():
-            epochs = gr.Number(label="Number of Epochs", value=2, minimum=1, maximum=10)
-            batch_size = gr.Number(label="Batch Size per Device", value=2, minimum=1, maximum=8)
-            grad_accum = gr.Number(label="Gradient Accumulation Steps", value=4, minimum=1, maximum=16)
+            epochs = gr.Number(label="Number of Epochs", value=3, minimum=1, maximum=10)
+            batch_size = gr.Number(label="Batch Size per Device", value=4, minimum=1, maximum=16)
+            grad_accum = gr.Number(label="Gradient Accumulation Steps", value=2, minimum=1, maximum=16)
             lr = gr.Number(label="Learning Rate", value=2e-4)
 
             start_btn = gr.Button("Start Training")
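For context on how these controls are consumed, a trimmed-down sketch of the Gradio wiring around them; the stand-in train_model, the model-repo textbox default, and the output textbox are illustrative, since create_interface() is only partially visible in the diff:

```python
import gradio as gr


def train_model(model_repo_name, dataset_repo_name, epochs, batch_size, grad_accum, lr,
                progress=gr.Progress()):
    # Stand-in for the real training function in app.py
    return f"Would train {model_repo_name} for {epochs} epoch(s), bs={batch_size}, accum={grad_accum}, lr={lr}"


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            model_repo = gr.Textbox(label="Model Repository Name", value="llama-3-8b-rvq-resized")
            dataset_repo = gr.Textbox(label="Dataset Repository Name", value="podcast-dialogue-rvq-pairs-3items")
        with gr.Column():
            epochs = gr.Number(label="Number of Epochs", value=3, minimum=1, maximum=10)
            batch_size = gr.Number(label="Batch Size per Device", value=4, minimum=1, maximum=16)
            grad_accum = gr.Number(label="Gradient Accumulation Steps", value=2, minimum=1, maximum=16)
            lr = gr.Number(label="Learning Rate", value=2e-4)

    start_btn = gr.Button("Start Training")
    output = gr.Textbox(label="Training Log", lines=10)
    start_btn.click(
        fn=train_model,
        inputs=[model_repo, dataset_repo, epochs, batch_size, grad_accum, lr],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()
```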