Twelve2five committed on
Commit 30d7ae2 · verified · 1 Parent(s): 26c97a9

Update app.py

Files changed (1)
  1. app.py +420 -341
app.py CHANGED
@@ -1,352 +1,431 @@
- # -*- coding: utf-8 -*-
- """
- Script for fine-tuning Llama-3-8B with RVQ tokens on multiple GPUs
- """
-
- # Basic setup and installations
- !pip install -q -U transformers accelerate bitsandbytes peft torch datasets huggingface_hub deepspeed
-
- # No need for notebook_login on Hugging Face platform
- # Authentication is handled automatically
-
  import torch
- from transformers import AutoModelForCausalLM, BitsAndBytesConfig
- from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
  import gc
- import os
  from datasets import Dataset
  from huggingface_hub import snapshot_download
- import glob
- from tqdm import tqdm
-
- # --- Configuration ---
- YOUR_HF_USERNAME = "Twelve2five"
- MODEL_REPO_NAME = "llama-3-8b-rvq-resized"
- DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"
-
- hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
- hf_dataset_repo_id = f"{YOUR_HF_USERNAME}/{DATASET_REPO_NAME}"
-
- # Check if running on multiple GPUs
- n_gpus = torch.cuda.device_count()
- print(f"Number of GPUs available: {n_gpus}")
-
- # --- Quantization Configuration ---
- bnb_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_quant_type="nf4",
-     bnb_4bit_compute_dtype=torch.bfloat16,
-     bnb_4bit_use_double_quant=True,
- )

- # --- Load Base Model (with quantization) ---
- try:
-     # For multi-GPU QLoRA, we'll use device_map="auto" and let DeepSpeed handle distribution later
-     model = AutoModelForCausalLM.from_pretrained(
-         hf_model_repo_id,
-         quantization_config=bnb_config,
-         device_map="auto",  # Will be overridden by DeepSpeed
-         trust_remote_code=True
      )
-     print(f"Loaded model vocab size: {model.config.vocab_size}")
-     print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
- except Exception as e:
-     print(f"Error loading model from Hub: {e}")
-     raise SystemExit("Model loading failed.")
-
- # --- Prepare for K-bit Training & Apply LoRA ---
- model = prepare_model_for_kbit_training(model)
-
- lora_config = LoraConfig(
-     task_type=TaskType.CAUSAL_LM,
-     r=16,
-     lora_alpha=32,
-     lora_dropout=0.05,
-     bias="none",
-     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
- )
- peft_model = get_peft_model(model, lora_config)
- peft_model.print_trainable_parameters()
- model_to_train = peft_model
-
- # Cleanup
- gc.collect()
- if torch.cuda.is_available():
-     torch.cuda.empty_cache()
-
- # --- Load Dataset from Hub ---
- local_download_path = "./downloaded_dataset_files"
-
- try:
-     downloaded_repo_root = snapshot_download(
-         repo_id=hf_dataset_repo_id,
-         repo_type="dataset",
-         local_dir=local_download_path,
-         local_dir_use_symlinks=False
      )
-     print(f"Dataset repository content downloaded to: {downloaded_repo_root}")
- except Exception as e:
-     print(f"Error downloading dataset repository from Hub: {e}")
-     raise SystemExit("Dataset download failed.")
-
- # --- Find and load the .pt files ---
- pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs")
- all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt"))
-
- if not all_pair_files:
-     all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt"))
-     if not all_pair_files:
-         raise FileNotFoundError(f"No RVQ pair files found in expected directories")
-
- print(f"Found {len(all_pair_files)} RVQ pair files.")
-
- # --- Load data from .pt files ---
- all_data_pairs = []
- for file_path in tqdm(all_pair_files, desc="Loading pair files"):
      try:
-         episode_pairs = torch.load(file_path, map_location='cpu')
-         all_data_pairs.extend(episode_pairs)
      except Exception as e:
-         print(f"Warning: Could not load file {file_path}: {e}")
-
- if not all_data_pairs:
-     raise ValueError("No valid data pairs were loaded")
-
- print(f"Loaded a total of {len(all_data_pairs)} training pairs into memory.")
-
- # --- Convert to HF Dataset ---
- def prepare_for_dataset(batch):
-     output = {'input_ids': [], 'labels': []}
-     for item in batch:
-         output['input_ids'].append(item['input_ids'].cpu().tolist())
-         output['labels'].append(item['labels'].cpu().tolist())
-     return output
-
- chunk_size = 1000
- processed_data = {'input_ids': [], 'labels': []}
- for i in tqdm(range(0, len(all_data_pairs), chunk_size), desc="Preparing data for Dataset"):
-     batch = all_data_pairs[i:i + chunk_size]
-     prepared_batch = prepare_for_dataset(batch)
-     processed_data['input_ids'].extend(prepared_batch['input_ids'])
-     processed_data['labels'].extend(prepared_batch['labels'])
-
- hf_dataset = Dataset.from_dict(processed_data)
-
- # Transform to get tensors back
- hf_dataset.set_transform(lambda batch: {
-     'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']],
-     'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']]
- })
-
- train_dataset = hf_dataset
-
- # Cleanup
- del all_data_pairs
- del processed_data
- gc.collect()
-
- # --- Define Data Collator ---
- def seq2seq_causal_collator(features):
-     batch = {}
-     concatenated_input_ids = []
-     concatenated_labels = []
-     max_len = 0
-
-     # First pass: Concatenate, create masked labels, find max length
-     for feature in features:
-         input_ids = feature['input_ids']
-         labels = feature['labels']
-
-         if input_ids.dim() > 1: input_ids = input_ids.squeeze()
-         if labels.dim() > 1: labels = labels.squeeze()
-
-         context_len = input_ids.shape[0]
-         target_len = labels.shape[0]
-
-         combined_ids = torch.cat([input_ids, labels], dim=0)
-         concatenated_input_ids.append(combined_ids)
-
-         masked_labels = torch.cat([
-             torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device),
-             labels
-         ], dim=0)
-         concatenated_labels.append(masked_labels)
-
-         if combined_ids.shape[0] > max_len:
-             max_len = combined_ids.shape[0]
-
-     # Second pass: Pad to max length
-     padded_input_ids = []
-     padded_labels = []
-     input_pad_token_id = 0
-     label_pad_token_id = -100
-
-     for i in range(len(features)):
-         ids = concatenated_input_ids[i]
-         lbls = concatenated_labels[i]
-
-         padding_len = max_len - ids.shape[0]
-
-         padded_input_ids.append(torch.nn.functional.pad(
-             ids, (0, padding_len), value=input_pad_token_id
-         ))
-         padded_labels.append(torch.nn.functional.pad(
-             lbls, (0, padding_len), value=label_pad_token_id
-         ))
-
-     # Stack and create final batch
-     batch['input_ids'] = torch.stack(padded_input_ids)
-     batch['labels'] = torch.stack(padded_labels)
-     batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long()
-
-     return batch
-
- data_collator = seq2seq_causal_collator
-
- # --- Define Training Arguments and Initialize Trainer ---
- from transformers import TrainingArguments, Trainer
- import math
-
- # Output directories
- OUTPUT_TRAINING_DIR = "./llama3-8b-rvq-qlora-finetuned-run"
- LOGGING_DIR = "./llama3-8b-rvq-qlora-logs-run"
-
- # Training parameters - adjusted for 4x T4 GPUs
- NUM_EPOCHS = 1
- # Scale down per-device batch size since we have multiple GPUs now
- BATCH_SIZE_PER_DEVICE = 1  # Smaller per-device batch size to avoid OOM
- GRAD_ACCUMULATION_STEPS = 4
- LEARNING_RATE = 1e-4
- WEIGHT_DECAY = 0.01
- WARMUP_RATIO = 0.03
- LR_SCHEDULER = "cosine"
- OPTIMIZER = "paged_adamw_8bit"
-
- # Calculate total steps and warmup steps
- # Total batch size is now batch_size × num_gpus × grad_accum_steps
- total_train_batch_size = BATCH_SIZE_PER_DEVICE * n_gpus * GRAD_ACCUMULATION_STEPS
- num_training_steps = math.ceil((len(train_dataset) * NUM_EPOCHS) / total_train_batch_size)
- num_warmup_steps = int(num_training_steps * WARMUP_RATIO)
-
- # Logging/Saving frequency
- steps_per_epoch = math.ceil(len(train_dataset) / total_train_batch_size)
- LOGGING_STEPS = max(10, steps_per_epoch // 15)
- SAVE_STEPS = max(50, steps_per_epoch // 10)
-
- print(f"Dataset size: {len(train_dataset)}")
- print(f"Number of GPUs: {n_gpus}")
- print(f"Batch size per device: {BATCH_SIZE_PER_DEVICE}")
- print(f"Gradient Accumulation steps: {GRAD_ACCUMULATION_STEPS}")
- print(f"Total train batch size (effective): {total_train_batch_size}")
- print(f"Total optimization steps: {num_training_steps}")
- print(f"Warmup steps: {num_warmup_steps}")
-
- # Configure for multi-GPU training using DeepSpeed
- training_args = TrainingArguments(
-     output_dir=OUTPUT_TRAINING_DIR,
-     num_train_epochs=NUM_EPOCHS,
-     per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
-     gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
-     optim=OPTIMIZER,
-     logging_dir=LOGGING_DIR,
-     logging_strategy="steps",
-     logging_steps=LOGGING_STEPS,
-     save_strategy="steps",
-     save_steps=SAVE_STEPS,
-     save_total_limit=2,
-     learning_rate=LEARNING_RATE,
-     weight_decay=WEIGHT_DECAY,
-     warmup_steps=num_warmup_steps,
-     lr_scheduler_type=LR_SCHEDULER,
-     report_to="tensorboard",
-     bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,
-     gradient_checkpointing=True,
-     gradient_checkpointing_kwargs={'use_reentrant': False},
-
-     # Multi-GPU specific settings
-     deepspeed="ds_config.json",  # We'll create this file below
-     ddp_find_unused_parameters=False,
- )
-
- # --- Create DeepSpeed configuration file ---
- import json
-
- # DeepSpeed ZeRO-3 config optimized for T4 GPUs
- ds_config = {
-     "fp16": {
-         "enabled": "auto",
-         "loss_scale": 0,
-         "loss_scale_window": 1000,
-         "initial_scale_power": 16,
-         "hysteresis": 2,
-         "min_loss_scale": 1
-     },
-     "bf16": {
-         "enabled": "auto"
-     },
-     "zero_optimization": {
-         "stage": 3,
-         "offload_optimizer": {
-             "device": "cpu",
-             "pin_memory": True
          },
-         "offload_param": {
-             "device": "cpu",
-             "pin_memory": True
          },
-         "overlap_comm": True,
-         "contiguous_gradients": True,
-         "reduce_bucket_size": "auto",
-         "stage3_prefetch_bucket_size": "auto",
-         "stage3_param_persistence_threshold": "auto",
-         "gather_16bit_weights_on_model_save": True,
-         "stage3_max_live_parameters": 1e9,
-         "stage3_max_reuse_distance": 1e9
-     },
-     "gradient_accumulation_steps": GRAD_ACCUMULATION_STEPS,
-     "gradient_clipping": "auto",
-     "steps_per_print": 10,
-     "train_batch_size": "auto",
-     "train_micro_batch_size_per_gpu": "auto",
-     "wall_clock_breakdown": False
- }
-
- with open("ds_config.json", "w") as f:
-     json.dump(ds_config, f, indent=4)
-
- # --- Initialize Trainer ---
- trainer = Trainer(
-     model=model_to_train,
-     args=training_args,
-     train_dataset=train_dataset,
-     data_collator=data_collator,
- )
-
- print("Trainer initialized with DeepSpeed for multi-GPU training.")
-
- # --- Start Training ---
- # Clear cache before starting
- gc.collect()
- if torch.cuda.is_available():
-     torch.cuda.empty_cache()
-
- try:
-     print("Starting distributed training on multiple GPUs...")
-     train_result = trainer.train()
-
-     # Save final model (adapter weights) and training state
-     final_save_path = os.path.join(training_args.output_dir, "final_checkpoint")
-     print(f"Saving final model checkpoint to {final_save_path}...")
-     trainer.save_model(final_save_path)
-     trainer.save_state()
-
-     # Log metrics
-     metrics = train_result.metrics
-     trainer.log_metrics("train", metrics)
-     trainer.save_metrics("train", metrics)
-
- except Exception as e:
-     print(f"An error occurred during training: {e}")
-     raise e
-
- print("Multi-GPU training process complete.")
+ import gradio as gr
+ import subprocess
+ import sys
+ import os
+ import glob
+ import json
+ import math
  import torch
  import gc
+ from tqdm import tqdm
  from datasets import Dataset
  from huggingface_hub import snapshot_download
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer
+ from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

+ # Function to run the training process
+ def train_model(
+     hf_username,
+     model_repo_name,
+     dataset_repo_name,
+     epochs=1,
+     batch_size=1,
+     grad_accum_steps=4,
+     learning_rate=1e-4,
+     progress=gr.Progress()
+ ):
+     progress(0, desc="Installing dependencies...")
+     # Install required packages if needed
+     try:
+         import transformers
+         import accelerate
+         import bitsandbytes
+         import peft
+         import deepspeed
+     except ImportError:
+         subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-U",
+                                "transformers", "accelerate", "bitsandbytes", "peft",
+                                "torch", "datasets", "huggingface_hub", "deepspeed"])
+
+     # --- Configuration ---
+     progress(0.05, desc="Setting up configuration...")
+     hf_model_repo_id = f"{hf_username}/{model_repo_name}"
+     hf_dataset_repo_id = f"{hf_username}/{dataset_repo_name}"
+
+     log = []
+     log.append(f"Model repo: {hf_model_repo_id}")
+     log.append(f"Dataset repo: {hf_dataset_repo_id}")
+
+     # Check if running on multiple GPUs
+     n_gpus = torch.cuda.device_count()
+     log.append(f"Number of GPUs available: {n_gpus}")
+
+     # --- Quantization Configuration ---
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+         bnb_4bit_use_double_quant=True,
      )
+
+     # --- Load Base Model (with quantization) ---
+     progress(0.1, desc="Loading base model...")
+     try:
+         model = AutoModelForCausalLM.from_pretrained(
+             hf_model_repo_id,
+             quantization_config=bnb_config,
+             device_map="auto",
+             trust_remote_code=True
+         )
+         log.append(f"Loaded model vocab size: {model.config.vocab_size}")
+         log.append(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
+     except Exception as e:
+         error_msg = f"Error loading model from Hub: {e}"
+         log.append(error_msg)
+         return "\n".join(log)
+
+     # --- Prepare for K-bit Training & Apply LoRA ---
+     progress(0.15, desc="Preparing model for fine-tuning...")
+     model = prepare_model_for_kbit_training(model)
+
+     lora_config = LoraConfig(
+         task_type=TaskType.CAUSAL_LM,
+         r=16,
+         lora_alpha=32,
+         lora_dropout=0.05,
+         bias="none",
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      )
+     peft_model = get_peft_model(model, lora_config)
+     # print_trainable_parameters() only prints to stdout and returns None,
+     # so count the trainable parameters directly for the log line.
+     trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
+     log.append(f"Trainable parameters: {trainable_params}")
+     model_to_train = peft_model
+
+     # Cleanup
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     # --- Load Dataset from Hub ---
+     progress(0.2, desc="Downloading dataset...")
+     local_download_path = "./downloaded_dataset_files"
+
      try:
+         downloaded_repo_root = snapshot_download(
+             repo_id=hf_dataset_repo_id,
+             repo_type="dataset",
+             local_dir=local_download_path,
+             local_dir_use_symlinks=False
+         )
+         log.append(f"Dataset repository content downloaded to: {downloaded_repo_root}")
      except Exception as e:
+         error_msg = f"Error downloading dataset repository from Hub: {e}"
+         log.append(error_msg)
+         return "\n".join(log)
+
+     # --- Find and load the .pt files ---
+     progress(0.25, desc="Finding dataset files...")
+     pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs")
+     all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt"))
+
+     if not all_pair_files:
+         all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt"))
+         if not all_pair_files:
+             error_msg = "No RVQ pair files found in expected directories"
+             log.append(error_msg)
+             return "\n".join(log)
+
+     log.append(f"Found {len(all_pair_files)} RVQ pair files.")
+
+     # --- Load data from .pt files ---
+     progress(0.3, desc="Loading dataset files...")
+     all_data_pairs = []
+     for i, file_path in enumerate(all_pair_files):
+         progress(0.3 + (0.1 * i / len(all_pair_files)), desc=f"Loading file {i+1}/{len(all_pair_files)}")
+         try:
+             episode_pairs = torch.load(file_path, map_location='cpu')
+             all_data_pairs.extend(episode_pairs)
+         except Exception as e:
+             log.append(f"Warning: Could not load file {file_path}: {e}")
+
+     if not all_data_pairs:
+         error_msg = "No valid data pairs were loaded"
+         log.append(error_msg)
+         return "\n".join(log)
+
+     log.append(f"Loaded a total of {len(all_data_pairs)} training pairs into memory.")
+
+     # --- Convert to HF Dataset ---
+     progress(0.45, desc="Converting to Hugging Face Dataset...")
+     def prepare_for_dataset(batch):
+         output = {'input_ids': [], 'labels': []}
+         for item in batch:
+             output['input_ids'].append(item['input_ids'].cpu().tolist())
+             output['labels'].append(item['labels'].cpu().tolist())
+         return output
+
+     chunk_size = 1000
+     processed_data = {'input_ids': [], 'labels': []}
+
+     total_chunks = len(range(0, len(all_data_pairs), chunk_size))
+     for i in range(0, len(all_data_pairs), chunk_size):
+         chunk_idx = i // chunk_size
+         progress(0.45 + (0.1 * chunk_idx / total_chunks),
+                  desc=f"Processing chunk {chunk_idx+1}/{total_chunks}")
+         batch = all_data_pairs[i:i + chunk_size]
+         prepared_batch = prepare_for_dataset(batch)
+         processed_data['input_ids'].extend(prepared_batch['input_ids'])
+         processed_data['labels'].extend(prepared_batch['labels'])
+
+     hf_dataset = Dataset.from_dict(processed_data)
+
+     # Transform to get tensors back
+     hf_dataset.set_transform(lambda batch: {
+         'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']],
+         'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']]
+     })
+
+     train_dataset = hf_dataset
+
+     # Cleanup
+     del all_data_pairs
+     del processed_data
+     gc.collect()
+
+     # --- Define Data Collator ---
+     progress(0.55, desc="Defining data collator...")
+     def seq2seq_causal_collator(features):
+         batch = {}
+         concatenated_input_ids = []
+         concatenated_labels = []
+         max_len = 0
+
+         # First pass: Concatenate, create masked labels, find max length
+         for feature in features:
+             input_ids = feature['input_ids']
+             labels = feature['labels']
+
+             if input_ids.dim() > 1: input_ids = input_ids.squeeze()
+             if labels.dim() > 1: labels = labels.squeeze()
+
+             context_len = input_ids.shape[0]
+             target_len = labels.shape[0]
+
+             combined_ids = torch.cat([input_ids, labels], dim=0)
+             concatenated_input_ids.append(combined_ids)
+
+             masked_labels = torch.cat([
+                 torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device),
+                 labels
+             ], dim=0)
+             concatenated_labels.append(masked_labels)
+
+             if combined_ids.shape[0] > max_len:
+                 max_len = combined_ids.shape[0]
+
+         # Second pass: Pad to max length
+         padded_input_ids = []
+         padded_labels = []
+         input_pad_token_id = 0
+         label_pad_token_id = -100
+
+         for i in range(len(features)):
+             ids = concatenated_input_ids[i]
+             lbls = concatenated_labels[i]
+
+             padding_len = max_len - ids.shape[0]
+
+             padded_input_ids.append(torch.nn.functional.pad(
+                 ids, (0, padding_len), value=input_pad_token_id
+             ))
+             padded_labels.append(torch.nn.functional.pad(
+                 lbls, (0, padding_len), value=label_pad_token_id
+             ))
+
+         # Stack and create final batch
+         batch['input_ids'] = torch.stack(padded_input_ids)
+         batch['labels'] = torch.stack(padded_labels)
+         batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long()
+
+         return batch
+
+     data_collator = seq2seq_causal_collator
+
+     # --- Define Training Arguments and Initialize Trainer ---
+     progress(0.65, desc="Setting up training configuration...")
+
+     # Output directories
+     OUTPUT_TRAINING_DIR = "./llama3-8b-rvq-qlora-finetuned-run"
+     LOGGING_DIR = "./llama3-8b-rvq-qlora-logs-run"
+
+     # Training parameters - adjusted for 4x T4 GPUs
+     NUM_EPOCHS = int(epochs)
+     BATCH_SIZE_PER_DEVICE = int(batch_size)  # Smaller per-device batch size to avoid OOM
+     GRAD_ACCUMULATION_STEPS = int(grad_accum_steps)
+     LEARNING_RATE = float(learning_rate)
+     WEIGHT_DECAY = 0.01
+     WARMUP_RATIO = 0.03
+     LR_SCHEDULER = "cosine"
+     OPTIMIZER = "paged_adamw_8bit"
+
+     # Calculate total steps and warmup steps
+     # Total batch size is now batch_size × num_gpus × grad_accum_steps
+     total_train_batch_size = BATCH_SIZE_PER_DEVICE * n_gpus * GRAD_ACCUMULATION_STEPS
+     num_training_steps = math.ceil((len(train_dataset) * NUM_EPOCHS) / total_train_batch_size)
+     num_warmup_steps = int(num_training_steps * WARMUP_RATIO)
+
+     # Logging/Saving frequency
+     steps_per_epoch = math.ceil(len(train_dataset) / total_train_batch_size)
+     LOGGING_STEPS = max(10, steps_per_epoch // 15)
+     SAVE_STEPS = max(50, steps_per_epoch // 10)
+
+     log.append(f"Dataset size: {len(train_dataset)}")
+     log.append(f"Number of GPUs: {n_gpus}")
+     log.append(f"Batch size per device: {BATCH_SIZE_PER_DEVICE}")
+     log.append(f"Gradient Accumulation steps: {GRAD_ACCUMULATION_STEPS}")
+     log.append(f"Total train batch size (effective): {total_train_batch_size}")
+     log.append(f"Total optimization steps: {num_training_steps}")
+     log.append(f"Warmup steps: {num_warmup_steps}")
+
+     # --- Create DeepSpeed configuration file ---
+     progress(0.7, desc="Creating DeepSpeed configuration...")
+     # DeepSpeed ZeRO-3 config optimized for T4 GPUs
+     ds_config = {
+         "fp16": {
+             "enabled": "auto",
+             "loss_scale": 0,
+             "loss_scale_window": 1000,
+             "initial_scale_power": 16,
+             "hysteresis": 2,
+             "min_loss_scale": 1
          },
+         "bf16": {
+             "enabled": "auto"
          },
+         "zero_optimization": {
+             "stage": 3,
+             "offload_optimizer": {
+                 "device": "cpu",
+                 "pin_memory": True
+             },
+             "offload_param": {
+                 "device": "cpu",
+                 "pin_memory": True
+             },
+             "overlap_comm": True,
+             "contiguous_gradients": True,
+             "reduce_bucket_size": "auto",
+             "stage3_prefetch_bucket_size": "auto",
+             "stage3_param_persistence_threshold": "auto",
+             "gather_16bit_weights_on_model_save": True,
+             "stage3_max_live_parameters": 1e9,
+             "stage3_max_reuse_distance": 1e9
+         },
+         "gradient_accumulation_steps": GRAD_ACCUMULATION_STEPS,
+         "gradient_clipping": "auto",
+         "steps_per_print": 10,
+         "train_batch_size": "auto",
+         "train_micro_batch_size_per_gpu": "auto",
+         "wall_clock_breakdown": False
+     }
+
+     with open("ds_config.json", "w") as f:
+         json.dump(ds_config, f, indent=4)
+
+     # Configure for multi-GPU training using DeepSpeed
+     progress(0.75, desc="Setting up training arguments...")
+     training_args = TrainingArguments(
+         output_dir=OUTPUT_TRAINING_DIR,
+         num_train_epochs=NUM_EPOCHS,
+         per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
+         gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
+         optim=OPTIMIZER,
+         logging_dir=LOGGING_DIR,
+         logging_strategy="steps",
+         logging_steps=LOGGING_STEPS,
+         save_strategy="steps",
+         save_steps=SAVE_STEPS,
+         save_total_limit=2,
+         learning_rate=LEARNING_RATE,
+         weight_decay=WEIGHT_DECAY,
+         warmup_steps=num_warmup_steps,
+         lr_scheduler_type=LR_SCHEDULER,
+         report_to="tensorboard",
+         bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,
+         gradient_checkpointing=True,
+         gradient_checkpointing_kwargs={'use_reentrant': False},
+
+         # Multi-GPU specific settings
+         deepspeed="ds_config.json",
+         ddp_find_unused_parameters=False,
+     )
+
+     # --- Initialize Trainer ---
+     progress(0.8, desc="Initializing trainer...")
+     trainer = Trainer(
+         model=model_to_train,
+         args=training_args,
+         train_dataset=train_dataset,
+         data_collator=data_collator,
+     )
+
+     log.append("Trainer initialized with DeepSpeed for multi-GPU training.")
+
+     # --- Start Training ---
+     # Clear cache before starting
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     try:
+         progress(0.85, desc="Starting training...")
+         log.append("Starting distributed training on multiple GPUs...")
+         train_result = trainer.train()
+         progress(0.95, desc="Saving model...")
+
+         # Save final model (adapter weights) and training state
+         final_save_path = os.path.join(training_args.output_dir, "final_checkpoint")
+         log.append(f"Saving final model checkpoint to {final_save_path}...")
+         trainer.save_model(final_save_path)
+         trainer.save_state()
+
+         # Log metrics
+         metrics = train_result.metrics
+         trainer.log_metrics("train", metrics)
+         trainer.save_metrics("train", metrics)
+
+         for key, value in metrics.items():
+             log.append(f"{key}: {value}")
+
+     except Exception as e:
+         error_msg = f"An error occurred during training: {e}"
+         log.append(error_msg)
+         return "\n".join(log)
+
+     progress(1.0, desc="Training complete!")
+     log.append("Multi-GPU training process complete.")
+     return "\n".join(log)
+
+ # Define the Gradio interface
+ def create_interface():
+     with gr.Blocks(title="Llama 3 8B RVQ Fine-tuning") as demo:
+         gr.Markdown("# Llama 3 8B RVQ LoRA Fine-tuning")
+         gr.Markdown("Fine-tune a Llama 3 8B model with RVQ token embeddings using LoRA on multiple GPUs")
+
+         with gr.Row():
+             with gr.Column():
+                 hf_username = gr.Textbox(label="HuggingFace Username", value="Twelve2five")
+                 model_repo = gr.Textbox(label="Model Repository Name", value="llama-3-8b-rvq-resized")
+                 dataset_repo = gr.Textbox(label="Dataset Repository Name", value="podcast-dialogue-rvq-pairs-3items")
+
+             with gr.Column():
+                 epochs = gr.Number(label="Number of Epochs", value=1, minimum=1, maximum=10)
+                 batch_size = gr.Number(label="Batch Size per Device", value=1, minimum=1, maximum=8)
+                 grad_accum = gr.Number(label="Gradient Accumulation Steps", value=4, minimum=1, maximum=16)
+                 lr = gr.Number(label="Learning Rate", value=1e-4)
+
+         start_btn = gr.Button("Start Training")
+         output = gr.Textbox(label="Training Log", lines=20)
+
+         start_btn.click(
+             fn=train_model,
+             inputs=[hf_username, model_repo, dataset_repo, epochs, batch_size, grad_accum, lr],
+             outputs=output
+         )
+
+     return demo
+
+ # Create and launch the interface
+ demo = create_interface()
+ if __name__ == "__main__":
+     demo.launch()
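
Note on the step arithmetic in train_model() above: the effective batch size is per-device batch size × number of GPUs × gradient-accumulation steps, and the total optimization steps and warmup steps are derived from it. A minimal standalone sketch with hypothetical numbers (the real dataset size is only known at runtime; 4 GPUs is the count the script's comments assume):

    import math

    dataset_size = 1200          # hypothetical example; the app logs the real value
    n_gpus = 4                   # the script's comments target 4x T4 GPUs
    batch_per_device = 1
    grad_accum_steps = 4
    epochs = 1
    warmup_ratio = 0.03

    total_batch = batch_per_device * n_gpus * grad_accum_steps    # 16
    steps = math.ceil(dataset_size * epochs / total_batch)        # 75
    warmup_steps = int(steps * warmup_ratio)                      # 2
    print(total_batch, steps, warmup_steps)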