Twelve2five committed on
Commit 19e1ed6 · verified · 1 Parent(s): 1ac2f31

Create app.py

Files changed (1)
  1. app.py +364 -0
app.py ADDED
@@ -0,0 +1,364 @@
import os
import torch
import glob
import gc
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset
from huggingface_hub import snapshot_download
from tqdm import tqdm
import gradio as gr
import math

# --- Configuration ---
YOUR_HF_USERNAME = "Twelve2five"
MODEL_REPO_NAME = "llama-3-8b-rvq-resized"
DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"

hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
hf_dataset_repo_id = f"{YOUR_HF_USERNAME}/{DATASET_REPO_NAME}"

# Output directories
OUTPUT_TRAINING_DIR = "./llama3-8b-rvq-qlora-finetuned-run"
LOGGING_DIR = "./llama3-8b-rvq-qlora-logs-run"
local_download_path = "./downloaded_dataset_files"

# Training parameters
NUM_EPOCHS = 1
BATCH_SIZE_PER_DEVICE = 2
GRAD_ACCUMULATION_STEPS = 4
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.03
LR_SCHEDULER = "cosine"
OPTIMIZER = "paged_adamw_8bit"


def seq2seq_causal_collator(features):
    """
    Collator that concatenates context (input_ids) and target (labels)
    for Causal LM sequence-to-sequence training.
    Masks the loss for the context part of the sequence.
    Pads sequences to the maximum length in the batch.
    """
    batch = {}
    concatenated_input_ids = []
    concatenated_labels = []
    max_len = 0

    # --- First pass: Concatenate, create masked labels, find max length ---
    for feature in features:
        # Dataset transform should provide tensors here
        input_ids = feature['input_ids']
        labels = feature['labels']

        # Ensure tensors are 1D (handle potential extra dims if any)
        if input_ids.dim() > 1:
            input_ids = input_ids.squeeze()
        if labels.dim() > 1:
            labels = labels.squeeze()

        context_len = input_ids.shape[0]
        target_len = labels.shape[0]

        # Concatenate context and target for input
        combined_ids = torch.cat([input_ids, labels], dim=0)
        concatenated_input_ids.append(combined_ids)

        # Create labels: -100 for context, actual labels for target
        masked_labels = torch.cat([
            torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device),
            labels
        ], dim=0)
        concatenated_labels.append(masked_labels)

        # Track max length for padding
        if combined_ids.shape[0] > max_len:
            max_len = combined_ids.shape[0]

    # --- Second pass: Pad to max length ---
    padded_input_ids = []
    padded_labels = []
    input_pad_token_id = 0
    label_pad_token_id = -100

    for i in range(len(features)):
        ids = concatenated_input_ids[i]
        lbls = concatenated_labels[i]

        padding_len = max_len - ids.shape[0]

        # Pad on the right side
        padded_input_ids.append(torch.nn.functional.pad(
            ids, (0, padding_len), value=input_pad_token_id
        ))
        padded_labels.append(torch.nn.functional.pad(
            lbls, (0, padding_len), value=label_pad_token_id
        ))

    # --- Stack and create final batch ---
    batch['input_ids'] = torch.stack(padded_input_ids)
    batch['labels'] = torch.stack(padded_labels)

    # Create attention mask (1 for real tokens, 0 for padding)
    batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long()

    return batch

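# Illustrative only (an assumption about the data shapes, not called anywhere in
# this Space): a tiny smoke test of the collator above, assuming each feature
# carries 1-D LongTensors under 'input_ids' (context) and 'labels' (target), as
# the dataset transform below provides.
def _collator_smoke_test():
    toy_features = [
        {'input_ids': torch.tensor([5, 6, 7]), 'labels': torch.tensor([8, 9])},
        {'input_ids': torch.tensor([5, 6]), 'labels': torch.tensor([9])},
    ]
    toy_batch = seq2seq_causal_collator(toy_features)
    # Both rows are padded to the longest combined length (5); context positions
    # are -100 in the labels, so only target tokens contribute to the LM loss.
    assert toy_batch['input_ids'].shape == (2, 5)
    assert toy_batch['labels'][0].tolist() == [-100, -100, -100, 8, 9]
    assert toy_batch['attention_mask'][1].tolist() == [1, 1, 1, 0, 0]
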
def prepare_for_dataset(batch):
    output = {'input_ids': [], 'labels': []}
    for item in batch:
        output['input_ids'].append(item['input_ids'].cpu().tolist())
        output['labels'].append(item['labels'].cpu().tolist())
    return output

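# Assumed schema (not spelled out in this file): each pair loaded from the .pt
# files is a dict with 1-D LongTensors of RVQ token ids under 'input_ids'
# (context turn) and 'labels' (target turn); the helper above just moves them to
# CPU and converts them to plain lists so they can be stored in a Hugging Face
# Dataset via Dataset.from_dict.
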
def load_model():
    # For HF Spaces, we use the system CUDA if available
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading base model architecture from: {hf_model_repo_id}")
    print(f"Using device: {DEVICE}")

    # --- Quantization Configuration ---
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # --- Load Base Model (with quantization) ---
    try:
        model = AutoModelForCausalLM.from_pretrained(
            hf_model_repo_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )
        print(f"Loaded model vocab size: {model.config.vocab_size}")
        print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    # --- Prepare for K-bit Training & Apply LoRA ---
    model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    )

    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()

    # Cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return peft_model


def load_dataset():
    # --- Download the dataset repository files ---
    try:
        os.makedirs(local_download_path, exist_ok=True)
        downloaded_repo_root = snapshot_download(
            repo_id=hf_dataset_repo_id,
            repo_type="dataset",
            local_dir=local_download_path,
            local_dir_use_symlinks=False
        )
        print(f"Dataset repository content downloaded to: {downloaded_repo_root}")
    except Exception as e:
        print(f"Error downloading dataset: {e}")
        return None

    # --- Load .pt files into a Hugging Face Dataset object ---
    pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs")
    all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt"))

    if not all_pair_files:
        all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt"))
        if not all_pair_files:
            print("No RVQ pair files found!")
            return None

    print(f"Found {len(all_pair_files)} RVQ pair files.")

    # Load data from .pt files into memory
    all_data_pairs = []
    for file_path in tqdm(all_pair_files, desc="Loading pair files"):
        try:
            episode_pairs = torch.load(file_path, map_location='cpu')
            all_data_pairs.extend(episode_pairs)
        except Exception as e:
            print(f"Warning: Could not load file {file_path}: {e}")

    if not all_data_pairs:
        return None

    print(f"Loaded {len(all_data_pairs)} training pairs.")

    # Convert to Hugging Face Dataset
    chunk_size = 1000
    processed_data = {'input_ids': [], 'labels': []}
    for i in tqdm(range(0, len(all_data_pairs), chunk_size), desc="Preparing data"):
        batch = all_data_pairs[i:i + chunk_size]
        prepared_batch = prepare_for_dataset(batch)
        processed_data['input_ids'].extend(prepared_batch['input_ids'])
        processed_data['labels'].extend(prepared_batch['labels'])

    hf_dataset = Dataset.from_dict(processed_data)

    # Transform to get tensors back
    hf_dataset.set_transform(lambda batch: {
        'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']],
        'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']]
    })

    # Cleanup
    del all_data_pairs
    del processed_data
    gc.collect()

    return hf_dataset


def train_model(progress=gr.Progress()):
    # Create directories
    os.makedirs(OUTPUT_TRAINING_DIR, exist_ok=True)
    os.makedirs(LOGGING_DIR, exist_ok=True)

    progress(0, desc="Loading model...")
    model_to_train = load_model()
    if model_to_train is None:
        return "Failed to load model."

    progress(0.2, desc="Loading dataset...")
    train_dataset = load_dataset()
    if train_dataset is None:
        return "Failed to load dataset."

    progress(0.4, desc="Setting up trainer...")
    # Calculate steps and warmup
    total_train_batch_size = BATCH_SIZE_PER_DEVICE * GRAD_ACCUMULATION_STEPS
    num_training_steps = math.ceil((len(train_dataset) * NUM_EPOCHS) / total_train_batch_size)
    num_warmup_steps = int(num_training_steps * WARMUP_RATIO)

    # Logging frequency
    steps_per_epoch = math.ceil(len(train_dataset) / total_train_batch_size)
    LOGGING_STEPS = max(10, steps_per_epoch // 15)
    SAVE_STEPS = max(50, steps_per_epoch // 10)

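    # Worked example (illustrative; the dataset size is an assumption): with
    # 4,000 training pairs, the effective batch is 2 * 4 = 8, so
    # steps_per_epoch = ceil(4000 / 8) = 500, num_training_steps = 500 for one
    # epoch, num_warmup_steps = int(500 * 0.03) = 15,
    # LOGGING_STEPS = max(10, 500 // 15) = 33, and SAVE_STEPS = max(50, 500 // 10) = 50.
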
    training_args = TrainingArguments(
        output_dir=OUTPUT_TRAINING_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
        gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
        optim=OPTIMIZER,
        logging_dir=LOGGING_DIR,
        logging_strategy="steps",
        logging_steps=LOGGING_STEPS,
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        save_total_limit=2,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=num_warmup_steps,
        lr_scheduler_type=LR_SCHEDULER,
        report_to="tensorboard",
        fp16=False,
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
    )

    trainer = Trainer(
        model=model_to_train,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=seq2seq_causal_collator,
    )

    progress(0.5, desc="Starting training...")
    # Clear cache before starting
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        train_result = trainer.train()

        progress(0.9, desc="Saving model...")
        # Save final model and training state
        final_save_path = os.path.join(training_args.output_dir, "final_checkpoint")
        trainer.save_model(final_save_path)
        trainer.save_state()

        # Log metrics
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        progress(1.0, desc="Training complete!")
        return f"Training completed successfully. Model saved to {final_save_path}"

    except Exception as e:
        return f"An error occurred during training: {str(e)}"


# Create Gradio interface
def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# Fine-tune LLaMA 3 8B with QLoRA")

        with gr.Tab("Training"):
            train_button = gr.Button("Start Fine-tuning")
            result_text = gr.Textbox(label="Training Results", interactive=False)

            train_button.click(train_model, outputs=result_text)

        with gr.Tab("About"):
            gr.Markdown("""
## Information
This is a Hugging Face Space version of the original Google Colab notebook.

It fine-tunes a quantized LLaMA 3 8B model using QLoRA on podcast dialogue data.

### Model
- Base Model: {YOUR_HF_USERNAME}/{MODEL_REPO_NAME}
- Using 4-bit quantization with LoRA adapters

### Dataset
- Custom dataset: {YOUR_HF_USERNAME}/{DATASET_REPO_NAME}
- Contains podcast dialogue pairs processed for training

### Training Setup
- QLoRA fine-tuning
- Epochs: {NUM_EPOCHS}
- Batch size: {BATCH_SIZE_PER_DEVICE} with {GRAD_ACCUMULATION_STEPS} gradient accumulation steps
- Learning rate: {LEARNING_RATE}
""".format(
                YOUR_HF_USERNAME=YOUR_HF_USERNAME,
                MODEL_REPO_NAME=MODEL_REPO_NAME,
                DATASET_REPO_NAME=DATASET_REPO_NAME,
                NUM_EPOCHS=NUM_EPOCHS,
                BATCH_SIZE_PER_DEVICE=BATCH_SIZE_PER_DEVICE,
                GRAD_ACCUMULATION_STEPS=GRAD_ACCUMULATION_STEPS,
                LEARNING_RATE=LEARNING_RATE
            ))

    return demo


# Main entry point
if __name__ == "__main__":
    # Dependencies (normally provided by the Space's requirements.txt):
    # transformers, accelerate, bitsandbytes, peft, torch, datasets, huggingface_hub, gradio

    # Create and launch the UI
    demo = create_ui()
    demo.launch()
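    # Note (illustrative alternative, not part of the original entry point): for a
    # headless run without the browser UI, one could call train_model() directly
    # here instead of launching Gradio, e.g. print(train_model()).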