Twelve2five committed
Commit 26c97a9 · verified · 1 Parent(s): 9295d60

Update app.py

Files changed (1)
  1. app.py +277 -320
app.py CHANGED
@@ -1,23 +1,23 @@
- import os
  import torch
- import glob
- import gc
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     BitsAndBytesConfig,
-     TrainingArguments,
-     Trainer,
-     DataCollatorForLanguageModeling,
-     AutoTokenizer
- )
  from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
  from datasets import Dataset
  from huggingface_hub import snapshot_download
  from tqdm import tqdm
- import gradio as gr
- import math
- from accelerate import Accelerator

  # --- Configuration ---
  YOUR_HF_USERNAME = "Twelve2five"
@@ -27,75 +27,155 @@ DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"
  hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
  hf_dataset_repo_id = f"{YOUR_HF_USERNAME}/{DATASET_REPO_NAME}"

- # Output directories
- OUTPUT_TRAINING_DIR = "./llama3-8b-rvq-qlora-finetuned-run"
- LOGGING_DIR = "./llama3-8b-rvq-qlora-logs-run"
  local_download_path = "./downloaded_dataset_files"

- # Training parameters
- NUM_EPOCHS = 1
- BATCH_SIZE_PER_DEVICE = 1
- GRAD_ACCUMULATION_STEPS = 64
- LEARNING_RATE = 1e-4
- WEIGHT_DECAY = 0.01
- WARMUP_RATIO = 0.03
- LR_SCHEDULER = "cosine"
- OPTIMIZER = "paged_adamw_8bit"
- MAX_SEQ_LENGTH = 256
- MICRO_BATCH_SIZE = 1

- # Multi-GPU configuration
- accelerator = Accelerator()

- # Configure environment for multi-GPU
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"

- # Print GPU information
- print(f"Available GPUs: {torch.cuda.device_count()}")
- for i in range(torch.cuda.device_count()):
-     print(f"GPU {i}: {torch.cuda.get_device_name(i)} with {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")

  def seq2seq_causal_collator(features):
-     """
-     Collator that concatenates context (input_ids) and target (labels)
-     for Causal LM sequence-to-sequence training.
-     Masks the loss for the context part of the sequence.
-     Pads sequences to the maximum length in the batch.
-     """
      batch = {}
      concatenated_input_ids = []
      concatenated_labels = []
      max_len = 0

-     # --- First pass: Concatenate, create masked labels, find max length ---
      for feature in features:
-         # Dataset transform should provide tensors here
          input_ids = feature['input_ids']
          labels = feature['labels']

-         # Ensure tensors are 1D (handle potential extra dims if any)
          if input_ids.dim() > 1: input_ids = input_ids.squeeze()
          if labels.dim() > 1: labels = labels.squeeze()

          context_len = input_ids.shape[0]
          target_len = labels.shape[0]

-         # Concatenate context and target for input
          combined_ids = torch.cat([input_ids, labels], dim=0)
          concatenated_input_ids.append(combined_ids)

-         # Create labels: -100 for context, actual labels for target
          masked_labels = torch.cat([
              torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device),
              labels
          ], dim=0)
          concatenated_labels.append(masked_labels)

-         # Track max length for padding
          if combined_ids.shape[0] > max_len:
              max_len = combined_ids.shape[0]

-     # --- Second pass: Pad to max length ---
      padded_input_ids = []
      padded_labels = []
      input_pad_token_id = 0
@@ -107,7 +187,6 @@ def seq2seq_causal_collator(features):

          padding_len = max_len - ids.shape[0]

-         # Pad on the right side
          padded_input_ids.append(torch.nn.functional.pad(
              ids, (0, padding_len), value=input_pad_token_id
          ))
@@ -115,281 +194,159 @@ def seq2seq_causal_collator(features):
              lbls, (0, padding_len), value=label_pad_token_id
          ))

-     # --- Stack and create final batch ---
      batch['input_ids'] = torch.stack(padded_input_ids)
      batch['labels'] = torch.stack(padded_labels)
-
-     # Create attention mask (1 for real tokens, 0 for padding)
      batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long()

      return batch

- def prepare_for_dataset(batch):
-     output = {'input_ids': [], 'labels': []}
-     for item in batch:
-         output['input_ids'].append(item['input_ids'].cpu().tolist())
-         output['labels'].append(item['labels'].cpu().tolist())
-     return output

- def load_model():
-     print(f"Loading base model architecture from: {hf_model_repo_id}")
-
-     # Get information about GPU with most free memory
-     gpu_id = 0 # Default to first GPU
-     max_free_memory = 0
-
-     for i in range(torch.cuda.device_count()):
-         free_memory = torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i)
-         if free_memory > max_free_memory:
-             max_free_memory = free_memory
-             gpu_id = i
-
-     print(f"Loading model on GPU {gpu_id} with {max_free_memory / 1e9:.2f}GB free memory")
-
-     # Configure quantization
-     bnb_config = BitsAndBytesConfig(
-         load_in_4bit=True,
-         bnb_4bit_use_double_quant=True,
-         bnb_4bit_quant_type="nf4",
-         bnb_4bit_compute_dtype=torch.bfloat16
-     )
-
-     # Load the model
-     model = AutoModelForCausalLM.from_pretrained(
-         hf_model_repo_id,
-         quantization_config=bnb_config,
-         device_map={"": gpu_id},
-         torch_dtype=torch.bfloat16,
-     )
-
-     print(f"Model loaded on device: cuda:{gpu_id}")
-
-     # Load the official Meta tokenizer for LLaMA 3
-     tokenizer = AutoTokenizer.from_pretrained(
-         "meta-llama/Llama-3-8B", # Use the official Meta tokenizer
-         use_auth_token=os.environ.get("HF_TOKEN", None) # In case it's needed
-     )
-
-     if tokenizer is None:
-         # Fallback to another common foundation model tokenizer
-         print("Falling back to another tokenizer as Meta tokenizer requires auth token")
-         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
-
-     print(f"Loaded tokenizer vocabulary size: {len(tokenizer)}")
-
-     # Print information about input embeddings
-     print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
-
-     # Prepare model for k-bit training
-     model = prepare_model_for_kbit_training(model)
-
-     # Define LoRA configuration
-     lora_config = LoraConfig(
-         r=16,
-         lora_alpha=32,
-         target_modules=[
-             "q_proj",
-             "k_proj",
-             "v_proj",
-             "o_proj",
-             "gate_proj",
-             "up_proj",
-             "down_proj",
-         ],
-         lora_dropout=0.05,
-         bias="none",
-         task_type=TaskType.CAUSAL_LM
-     )
-
-     # Apply LoRA to model
-     model = get_peft_model(model, lora_config)
-     model.print_trainable_parameters()
-
-     return model, tokenizer # Return both model and tokenizer

- def load_dataset():
-     # --- Download the dataset repository files ---
-     try:
-         os.makedirs(local_download_path, exist_ok=True)
-         downloaded_repo_root = snapshot_download(
-             repo_id=hf_dataset_repo_id,
-             repo_type="dataset",
-             local_dir=local_download_path,
-             local_dir_use_symlinks=False
-         )
-         print(f"Dataset repository content downloaded to: {downloaded_repo_root}")
-     except Exception as e:
-         print(f"Error downloading dataset: {e}")
-         return None
-
-     # --- Load .pt files into a Hugging Face Dataset object ---
-     pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs")
-     all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt"))
-
-     if not all_pair_files:
-         all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt"))
-         if not all_pair_files:
-             print("No RVQ pair files found!")
-             return None
-
-     print(f"Found {len(all_pair_files)} RVQ pair files.")
-
-     # Load data from .pt files into memory
-     all_data_pairs = []
-     for file_path in tqdm(all_pair_files, desc="Loading pair files"):
-         try:
-             episode_pairs = torch.load(file_path, map_location='cpu')
-             all_data_pairs.extend(episode_pairs)
-         except Exception as e:
-             print(f"Warning: Could not load file {file_path}: {e}")
-
-     if not all_data_pairs:
-         return None
-
-     print(f"Loaded {len(all_data_pairs)} training pairs.")
-
-     # Convert to Hugging Face Dataset
-     chunk_size = 1000
-     processed_data = {'input_ids': [], 'labels': []}
-     for i in tqdm(range(0, len(all_data_pairs), chunk_size), desc="Preparing data"):
-         batch = all_data_pairs[i:i + chunk_size]
-         prepared_batch = prepare_for_dataset(batch)
-         processed_data['input_ids'].extend(prepared_batch['input_ids'])
-         processed_data['labels'].extend(prepared_batch['labels'])
-
-     hf_dataset = Dataset.from_dict(processed_data)
-
-     # Transform to get tensors back
-     hf_dataset.set_transform(lambda batch: {
-         'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']],
-         'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']]
-     })
-
-     # Cleanup
-     del all_data_pairs
-     del processed_data
-     gc.collect()
-
-     return hf_dataset
-
- # Memory cleaning function
- def clean_memory():
-     gc.collect()
-     if torch.cuda.is_available():
-         for i in range(torch.cuda.device_count()):
-             with torch.cuda.device(f'cuda:{i}'):
-                 torch.cuda.empty_cache()
-                 torch.cuda.reset_peak_memory_stats()
-
- def train_model(progress=gr.Progress()):
-     # Clean memory before starting
-     clean_memory()
-
-     # Load model with optimized memory settings
-     model, tokenizer = load_model()
-
-     # Load and prepare dataset
-     progress(0.1, desc="Loading dataset...")
-     train_dataset = load_dataset()
-
-     # Initialize trainer with debug flags
-     progress(0.2, desc="Initializing trainer...")
-
-     try:
-         # Set up training args with simplified settings
-         training_args = TrainingArguments(
-             output_dir="./results",
-             num_train_epochs=1, # Just 1 epoch for testing
-             per_device_train_batch_size=1, # Minimal batch size
-             gradient_accumulation_steps=4, # Reduce memory pressure
-             warmup_steps=2,
-             logging_steps=1, # Log every step
-             save_steps=10000, # Don't save checkpoints during test
-             learning_rate=2e-4,
-             fp16=False, # Disable mixed precision for stability
-             optim="adamw_torch",
-             report_to="none", # Disable wandb/tensorboard reporting
-             max_steps=3, # Just try 3 steps to see if it works
-             logging_first_step=True, # Force log on first step
-         )
-
-         # Create a simple trainer with the tokenizer
-         trainer = Trainer(
-             model=model,
-             args=training_args,
-             train_dataset=train_dataset,
-             data_collator=DataCollatorForLanguageModeling(
-                 tokenizer=tokenizer,
-                 mlm=False
-             )
-         )
-
-         # Run training for just 3 steps
-         progress(0.3, desc="Starting training (this may take 5-15 minutes for first step)...")
-         trainer.train()
-
-         progress(0.9, desc="Initial training successful! You can now run full training.")
-         return "Initial training completed successfully! The system is working. You can now adjust parameters for a full training run."
-
-     except Exception as e:
-         error_msg = str(e)
-         print(f"Training error: {error_msg}")
-
-         # Add memory diagnostics to error message
-         mem_info = "\nMemory status at error time:\n"
-         for i in range(torch.cuda.device_count()):
-             mem_info += f"GPU {i}: {torch.cuda.memory_allocated(i) / 1e9:.2f}GB allocated, {torch.cuda.memory_reserved(i) / 1e9:.2f}GB reserved\n"
-
-         return f"An error occurred during training: {error_msg}\n{mem_info}"
-
- # Create Gradio interface
- def create_ui():
-     with gr.Blocks() as demo:
-         gr.Markdown("# Fine-tune LLaMA 3 8B with QLoRA")
-
-         with gr.Tab("Training"):
-             train_button = gr.Button("Start Fine-tuning")
-             result_text = gr.Textbox(label="Training Results", interactive=False)
-
-             train_button.click(train_model, outputs=result_text)
-
-         with gr.Tab("About"):
-             gr.Markdown("""
-             ## Information
-             This is a Hugging Face Space version of the original Google Colab notebook.
-
-             It fine-tunes a quantized LLaMA 3 8B model using QLoRA on podcast dialogue data.
-
-             ### Model
-             - Base Model: {YOUR_HF_USERNAME}/{MODEL_REPO_NAME}
-             - Using 4-bit quantization with LoRA adapters
-
-             ### Dataset
-             - Custom dataset: {YOUR_HF_USERNAME}/{DATASET_REPO_NAME}
-             - Contains podcast dialogue pairs processed for training
-
-             ### Training Setup
-             - QLoRA fine-tuning
-             - Epochs: {NUM_EPOCHS}
-             - Batch size: {BATCH_SIZE_PER_DEVICE} with {GRAD_ACCUMULATION_STEPS} gradient accumulation steps
-             - Learning rate: {LEARNING_RATE}
-             """.format(
-                 YOUR_HF_USERNAME=YOUR_HF_USERNAME,
-                 MODEL_REPO_NAME=MODEL_REPO_NAME,
-                 DATASET_REPO_NAME=DATASET_REPO_NAME,
-                 NUM_EPOCHS=NUM_EPOCHS,
-                 BATCH_SIZE_PER_DEVICE=BATCH_SIZE_PER_DEVICE,
-                 GRAD_ACCUMULATION_STEPS=GRAD_ACCUMULATION_STEPS,
-                 LEARNING_RATE=LEARNING_RATE
-             ))
-
-     return demo

- # Main entry point
- if __name__ == "__main__":
-     # Install dependencies first if needed
-     # !pip install -q -U transformers accelerate bitsandbytes peft torch datasets huggingface_hub gradio
-
-     # Create and launch the UI
-     demo = create_ui()
-     demo.launch()
+ # -*- coding: utf-8 -*-
+ """
+ Script for fine-tuning Llama-3-8B with RVQ tokens on multiple GPUs
+ """
+
+ # Basic setup and installations
+ !pip install -q -U transformers accelerate bitsandbytes peft torch datasets huggingface_hub deepspeed
+
+ # No need for notebook_login on Hugging Face platform
+ # Authentication is handled automatically
+
  import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
  from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
+ import gc
+ import os
  from datasets import Dataset
  from huggingface_hub import snapshot_download
+ import glob
  from tqdm import tqdm

  # --- Configuration ---
  YOUR_HF_USERNAME = "Twelve2five"
  hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
  hf_dataset_repo_id = f"{YOUR_HF_USERNAME}/{DATASET_REPO_NAME}"

+ # Check if running on multiple GPUs
+ n_gpus = torch.cuda.device_count()
+ print(f"Number of GPUs available: {n_gpus}")
+
+ # --- Quantization Configuration ---
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+ )
+
+ # --- Load Base Model (with quantization) ---
+ try:
+     # For multi-GPU QLoRA, we'll use device_map="auto" and let DeepSpeed handle distribution later
+     model = AutoModelForCausalLM.from_pretrained(
+         hf_model_repo_id,
+         quantization_config=bnb_config,
+         device_map="auto", # Will be overridden by DeepSpeed
+         trust_remote_code=True
+     )
+     print(f"Loaded model vocab size: {model.config.vocab_size}")
+     print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
+ except Exception as e:
+     print(f"Error loading model from Hub: {e}")
+     raise SystemExit("Model loading failed.")
+
+ # --- Prepare for K-bit Training & Apply LoRA ---
+ model = prepare_model_for_kbit_training(model)
+
+ lora_config = LoraConfig(
+     task_type=TaskType.CAUSAL_LM,
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     bias="none",
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ )
+ peft_model = get_peft_model(model, lora_config)
+ peft_model.print_trainable_parameters()
+ model_to_train = peft_model
+
+ # Cleanup
+ gc.collect()
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+
+ # --- Load Dataset from Hub ---
  local_download_path = "./downloaded_dataset_files"

+ try:
+     downloaded_repo_root = snapshot_download(
+         repo_id=hf_dataset_repo_id,
+         repo_type="dataset",
+         local_dir=local_download_path,
+         local_dir_use_symlinks=False
+     )
+     print(f"Dataset repository content downloaded to: {downloaded_repo_root}")
+ except Exception as e:
+     print(f"Error downloading dataset repository from Hub: {e}")
+     raise SystemExit("Dataset download failed.")
+
+ # --- Find and load the .pt files ---
+ pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs")
+ all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt"))
+
+ if not all_pair_files:
+     all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt"))
+     if not all_pair_files:
+         raise FileNotFoundError(f"No RVQ pair files found in expected directories")
+
+ print(f"Found {len(all_pair_files)} RVQ pair files.")
+
+ # --- Load data from .pt files ---
+ all_data_pairs = []
+ for file_path in tqdm(all_pair_files, desc="Loading pair files"):
+     try:
+         episode_pairs = torch.load(file_path, map_location='cpu')
+         all_data_pairs.extend(episode_pairs)
+     except Exception as e:
+         print(f"Warning: Could not load file {file_path}: {e}")
+
+ if not all_data_pairs:
+     raise ValueError("No valid data pairs were loaded")
+
+ print(f"Loaded a total of {len(all_data_pairs)} training pairs into memory.")
+
+ # --- Convert to HF Dataset ---
+ def prepare_for_dataset(batch):
+     output = {'input_ids': [], 'labels': []}
+     for item in batch:
+         output['input_ids'].append(item['input_ids'].cpu().tolist())
+         output['labels'].append(item['labels'].cpu().tolist())
+     return output

+ chunk_size = 1000
+ processed_data = {'input_ids': [], 'labels': []}
+ for i in tqdm(range(0, len(all_data_pairs), chunk_size), desc="Preparing data for Dataset"):
+     batch = all_data_pairs[i:i + chunk_size]
+     prepared_batch = prepare_for_dataset(batch)
+     processed_data['input_ids'].extend(prepared_batch['input_ids'])
+     processed_data['labels'].extend(prepared_batch['labels'])

+ hf_dataset = Dataset.from_dict(processed_data)

+ # Transform to get tensors back
+ hf_dataset.set_transform(lambda batch: {
+     'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']],
+     'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']]
+ })

+ train_dataset = hf_dataset
+
+ # Cleanup
+ del all_data_pairs
+ del processed_data
+ gc.collect()
+
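The set_transform call above converts the stored lists back into tensors on access, so indexing train_dataset yields torch.Tensor columns, which is what the collator below expects. A minimal sanity check (a sketch, not part of this commit, assuming the script has run up to this point):

    sample = train_dataset[0]
    print(type(sample['input_ids']))   # <class 'torch.Tensor'>
    print(sample['input_ids'].dtype)   # torch.int64 (RVQ token ids)
    print(sample['labels'].shape)      # 1-D tensor of target token ids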
+ # --- Define Data Collator ---
  def seq2seq_causal_collator(features):
      batch = {}
      concatenated_input_ids = []
      concatenated_labels = []
      max_len = 0

+     # First pass: Concatenate, create masked labels, find max length
      for feature in features:
          input_ids = feature['input_ids']
          labels = feature['labels']

          if input_ids.dim() > 1: input_ids = input_ids.squeeze()
          if labels.dim() > 1: labels = labels.squeeze()

          context_len = input_ids.shape[0]
          target_len = labels.shape[0]

          combined_ids = torch.cat([input_ids, labels], dim=0)
          concatenated_input_ids.append(combined_ids)

          masked_labels = torch.cat([
              torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device),
              labels
          ], dim=0)
          concatenated_labels.append(masked_labels)

          if combined_ids.shape[0] > max_len:
              max_len = combined_ids.shape[0]

+     # Second pass: Pad to max length
      padded_input_ids = []
      padded_labels = []
      input_pad_token_id = 0

          padding_len = max_len - ids.shape[0]

          padded_input_ids.append(torch.nn.functional.pad(
              ids, (0, padding_len), value=input_pad_token_id
          ))

              lbls, (0, padding_len), value=label_pad_token_id
          ))

+     # Stack and create final batch
      batch['input_ids'] = torch.stack(padded_input_ids)
      batch['labels'] = torch.stack(padded_labels)
      batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long()

      return batch

+ data_collator = seq2seq_causal_collator
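The collator above concatenates each context/target pair into a single sequence, masks the context positions with -100 so only the target tokens contribute to the loss, and right-pads the batch with 0 on the input side. A minimal usage sketch (not part of this commit; it assumes seq2seq_causal_collator from this file is in scope and that label_pad_token_id is -100 in the lines collapsed from this view):

    import torch

    # Two hypothetical RVQ-token pairs of different lengths
    features = [
        {'input_ids': torch.tensor([10, 11, 12]), 'labels': torch.tensor([20, 21])},
        {'input_ids': torch.tensor([10, 11]), 'labels': torch.tensor([20, 21, 22, 23])},
    ]

    batch = seq2seq_causal_collator(features)
    print(batch['input_ids'].shape)    # torch.Size([2, 6]) - padded to the longest pair
    print(batch['labels'][0])          # tensor([-100, -100, -100, 20, 21, -100]) - context and padding masked
    print(batch['attention_mask'][0])  # tensor([1, 1, 1, 1, 1, 0]) - the padded position is zeroed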

+ # --- Define Training Arguments and Initialize Trainer ---
+ from transformers import TrainingArguments, Trainer
+ import math

+ # Output directories
+ OUTPUT_TRAINING_DIR = "./llama3-8b-rvq-qlora-finetuned-run"
+ LOGGING_DIR = "./llama3-8b-rvq-qlora-logs-run"

+ # Training parameters - adjusted for 4x T4 GPUs
+ NUM_EPOCHS = 1
+ # Scale down per-device batch size since we have multiple GPUs now
+ BATCH_SIZE_PER_DEVICE = 1 # Smaller per-device batch size to avoid OOM
+ GRAD_ACCUMULATION_STEPS = 4
+ LEARNING_RATE = 1e-4
+ WEIGHT_DECAY = 0.01
+ WARMUP_RATIO = 0.03
+ LR_SCHEDULER = "cosine"
+ OPTIMIZER = "paged_adamw_8bit"
+
+ # Calculate total steps and warmup steps
+ # Total batch size is now batch_size × num_gpus × grad_accum_steps
+ total_train_batch_size = BATCH_SIZE_PER_DEVICE * n_gpus * GRAD_ACCUMULATION_STEPS
+ num_training_steps = math.ceil((len(train_dataset) * NUM_EPOCHS) / total_train_batch_size)
+ num_warmup_steps = int(num_training_steps * WARMUP_RATIO)
+
+ # Logging/Saving frequency
+ steps_per_epoch = math.ceil(len(train_dataset) / total_train_batch_size)
+ LOGGING_STEPS = max(10, steps_per_epoch // 15)
+ SAVE_STEPS = max(50, steps_per_epoch // 10)
+
+ print(f"Dataset size: {len(train_dataset)}")
+ print(f"Number of GPUs: {n_gpus}")
+ print(f"Batch size per device: {BATCH_SIZE_PER_DEVICE}")
+ print(f"Gradient Accumulation steps: {GRAD_ACCUMULATION_STEPS}")
+ print(f"Total train batch size (effective): {total_train_batch_size}")
+ print(f"Total optimization steps: {num_training_steps}")
+ print(f"Warmup steps: {num_warmup_steps}")
+
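To make the step arithmetic above concrete, a small worked example (a sketch; the dataset size here is a made-up placeholder, the real value is printed at runtime):

    import math

    dataset_len = 10_000                # hypothetical number of RVQ pairs
    n_gpus = 4                          # the comments above target a 4x T4 setup
    per_device_bs, grad_accum = 1, 4    # BATCH_SIZE_PER_DEVICE, GRAD_ACCUMULATION_STEPS

    effective_bs = per_device_bs * n_gpus * grad_accum   # 1 * 4 * 4 = 16
    steps = math.ceil(dataset_len / effective_bs)        # ceil(10000 / 16) = 625 steps for one epoch
    warmup = int(steps * 0.03)                           # 18 warmup steps
    print(effective_bs, steps, warmup)                   # 16 625 18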
+ # Configure for multi-GPU training using DeepSpeed
+ training_args = TrainingArguments(
+     output_dir=OUTPUT_TRAINING_DIR,
+     num_train_epochs=NUM_EPOCHS,
+     per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
+     gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
+     optim=OPTIMIZER,
+     logging_dir=LOGGING_DIR,
+     logging_strategy="steps",
+     logging_steps=LOGGING_STEPS,
+     save_strategy="steps",
+     save_steps=SAVE_STEPS,
+     save_total_limit=2,
+     learning_rate=LEARNING_RATE,
+     weight_decay=WEIGHT_DECAY,
+     warmup_steps=num_warmup_steps,
+     lr_scheduler_type=LR_SCHEDULER,
+     report_to="tensorboard",
+     bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,
+     gradient_checkpointing=True,
+     gradient_checkpointing_kwargs={'use_reentrant': False},
+
+     # Multi-GPU specific settings
+     deepspeed="ds_config.json", # We'll create this file below
+     ddp_find_unused_parameters=False,
+ )
+
+ # --- Create DeepSpeed configuration file ---
+ import json
+
+ # DeepSpeed ZeRO-3 config optimized for T4 GPUs
+ ds_config = {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "zero_optimization": {
+         "stage": 3,
+         "offload_optimizer": {
+             "device": "cpu",
+             "pin_memory": True
+         },
+         "offload_param": {
+             "device": "cpu",
+             "pin_memory": True
+         },
+         "overlap_comm": True,
+         "contiguous_gradients": True,
+         "reduce_bucket_size": "auto",
+         "stage3_prefetch_bucket_size": "auto",
+         "stage3_param_persistence_threshold": "auto",
+         "gather_16bit_weights_on_model_save": True,
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9
+     },
+     "gradient_accumulation_steps": GRAD_ACCUMULATION_STEPS,
+     "gradient_clipping": "auto",
+     "steps_per_print": 10,
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "wall_clock_breakdown": False
+ }
+
+ with open("ds_config.json", "w") as f:
+     json.dump(ds_config, f, indent=4)
+
+ # --- Initialize Trainer ---
+ trainer = Trainer(
+     model=model_to_train,
+     args=training_args,
+     train_dataset=train_dataset,
+     data_collator=data_collator,
+ )
+
+ print("Trainer initialized with DeepSpeed for multi-GPU training.")
+
+ # --- Start Training ---
+ # Clear cache before starting
+ gc.collect()
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+
+ try:
+     print("Starting distributed training on multiple GPUs...")
+     train_result = trainer.train()
+
+     # Save final model (adapter weights) and training state
+     final_save_path = os.path.join(training_args.output_dir, "final_checkpoint")
+     print(f"Saving final model checkpoint to {final_save_path}...")
+     trainer.save_model(final_save_path)
+     trainer.save_state()
+
+     # Log metrics
+     metrics = train_result.metrics
+     trainer.log_metrics("train", metrics)
+     trainer.save_metrics("train", metrics)
+
+ except Exception as e:
+     print(f"An error occurred during training: {e}")
+     raise e
+
+ print("Multi-GPU training process complete.")
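Because TrainingArguments receives deepspeed="ds_config.json", the script is expected to be started under a distributed launcher (for example deepspeed --num_gpus=4 app.py, or an equivalent torchrun invocation) so that each GPU gets its own process; the exact launch wiring depends on how the Space runs the file. After a successful run, the LoRA adapter saved under final_checkpoint can be reattached to the base model roughly as follows (a sketch, not part of the commit; the base repo id is a placeholder for the hf_model_repo_id built in the configuration section):

    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    BASE_REPO = "Twelve2five/<base-model-repo>"  # placeholder for hf_model_repo_id
    ADAPTER_DIR = "./llama3-8b-rvq-qlora-finetuned-run/final_checkpoint"

    base = AutoModelForCausalLM.from_pretrained(BASE_REPO, torch_dtype=torch.bfloat16, device_map="auto")
    model = PeftModel.from_pretrained(base, ADAPTER_DIR)  # attach the trained LoRA adapter
    model.eval()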