Twelve2five committed
Commit 0cfd18e · verified · 1 Parent(s): 30d7ae2

Update app.py

Files changed (1):
  1. app.py  +370 -14
app.py CHANGED
@@ -1,19 +1,327 @@
- import gradio as gr
- import subprocess
- import sys
  import os
- import glob
- import json
- import math
  import torch
  import gc
- from tqdm import tqdm
  from datasets import Dataset
  from huggingface_hub import snapshot_download
- from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer
- from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

- # Function to run the training process
  def train_model(
      hf_username,
      model_repo_name,
@@ -61,8 +369,22 @@ def train_model(
      # --- Load Base Model (with quantization) ---
      progress(0.1, desc="Loading base model...")
      try:
          model = AutoModelForCausalLM.from_pretrained(
              hf_model_repo_id,
              quantization_config=bnb_config,
              device_map="auto",
              trust_remote_code=True
@@ -72,12 +394,46 @@ def train_model(
      except Exception as e:
          error_msg = f"Error loading model from Hub: {e}"
          log.append(error_msg)
-         return "\n".join(log)

-     # --- Prepare for K-bit Training & Apply LoRA ---
-     progress(0.15, desc="Preparing model for fine-tuning...")
      model = prepare_model_for_kbit_training(model)

      lora_config = LoraConfig(
          task_type=TaskType.CAUSAL_LM,
          r=16,
@@ -428,4 +784,4 @@ def create_interface():
  # Create and launch the interface
  demo = create_interface()
  if __name__ == "__main__":
-     demo.launch()
Updated app.py (new side of the diff; added lines are marked with +):

  import os
  import torch
+ import glob
  import gc
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling,
+     LlamaConfig
+ )
+ from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
  from datasets import Dataset
  from huggingface_hub import snapshot_download
+ from tqdm import tqdm
+ import gradio as gr
+ import math
+ from accelerate import Accelerator
+ import subprocess
+ import sys
+ import json
+
+ # --- Configuration ---
+ YOUR_HF_USERNAME = "Twelve2five"
+ MODEL_REPO_NAME = "llama-3-8b-rvq-resized"
+ DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"
+
+ hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
+ hf_dataset_repo_id = f"{YOUR_HF_USERNAME}/{DATASET_REPO_NAME}"
+
+ # Output directories
+ OUTPUT_TRAINING_DIR = "./llama3-8b-rvq-qlora-finetuned-run"
+ LOGGING_DIR = "./llama3-8b-rvq-qlora-logs-run"
+ local_download_path = "./downloaded_dataset_files"
+
+ # Training parameters
+ NUM_EPOCHS = 1
+ BATCH_SIZE_PER_DEVICE = 1
+ GRAD_ACCUMULATION_STEPS = 64
+ LEARNING_RATE = 1e-4
+ WEIGHT_DECAY = 0.01
+ WARMUP_RATIO = 0.03
+ LR_SCHEDULER = "cosine"
+ OPTIMIZER = "paged_adamw_8bit"
+ MAX_SEQ_LENGTH = 256
+ MICRO_BATCH_SIZE = 1
+
+ # Multi-GPU configuration
+ accelerator = Accelerator()
+
+ # Configure environment for multi-GPU
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
+
+ # Print GPU information
+ print(f"Available GPUs: {torch.cuda.device_count()}")
+ for i in range(torch.cuda.device_count()):
+     print(f"GPU {i}: {torch.cuda.get_device_name(i)} with {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
+
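For reference, the constants above imply that each optimizer step accumulates BATCH_SIZE_PER_DEVICE * GRAD_ACCUMULATION_STEPS = 64 sequences per device (and that again times the process count if Accelerate runs data-parallel):

# Effective batch size per device, per optimizer step (illustrative, not in the commit)
effective_batch = BATCH_SIZE_PER_DEVICE * GRAD_ACCUMULATION_STEPS  # 1 * 64 = 64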
+ def seq2seq_causal_collator(features):
+     """
+     Collator that concatenates context (input_ids) and target (labels)
+     for Causal LM sequence-to-sequence training.
+     Masks the loss for the context part of the sequence.
+     Pads sequences to the maximum length in the batch.
+     """
+     batch = {}
+     concatenated_input_ids = []
+     concatenated_labels = []
+     max_len = 0
+
+     # --- First pass: Concatenate, create masked labels, find max length ---
+     for feature in features:
+         # Dataset transform should provide tensors here
+         input_ids = feature['input_ids']
+         labels = feature['labels']
+
+         # Ensure tensors are 1D (handle potential extra dims if any)
+         if input_ids.dim() > 1: input_ids = input_ids.squeeze()
+         if labels.dim() > 1: labels = labels.squeeze()
+
+         context_len = input_ids.shape[0]
+         target_len = labels.shape[0]
+
+         # Concatenate context and target for input
+         combined_ids = torch.cat([input_ids, labels], dim=0)
+         concatenated_input_ids.append(combined_ids)
+
+         # Create labels: -100 for context, actual labels for target
+         masked_labels = torch.cat([
+             torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device),
+             labels
+         ], dim=0)
+         concatenated_labels.append(masked_labels)
+
+         # Track max length for padding
+         if combined_ids.shape[0] > max_len:
+             max_len = combined_ids.shape[0]
+
+     # --- Second pass: Pad to max length ---
+     padded_input_ids = []
+     padded_labels = []
+     input_pad_token_id = 0
+     label_pad_token_id = -100
+
+     for i in range(len(features)):
+         ids = concatenated_input_ids[i]
+         lbls = concatenated_labels[i]
+
+         padding_len = max_len - ids.shape[0]
+
+         # Pad on the right side
+         padded_input_ids.append(torch.nn.functional.pad(
+             ids, (0, padding_len), value=input_pad_token_id
+         ))
+         padded_labels.append(torch.nn.functional.pad(
+             lbls, (0, padding_len), value=label_pad_token_id
+         ))
+
+     # --- Stack and create final batch ---
+     batch['input_ids'] = torch.stack(padded_input_ids)
+     batch['labels'] = torch.stack(padded_labels)
+
+     # Create attention mask (1 for real tokens, 0 for padding)
+     batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long()
+
+     return batch
+
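To make the collator's behaviour concrete, here is a small illustration (not part of the commit) with two toy pairs of different lengths, assuming the function above is in scope:

toy = [
    {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([4, 5])},
    {"input_ids": torch.tensor([6]), "labels": torch.tensor([7, 8, 9])},
]
out = seq2seq_causal_collator(toy)
# out["input_ids"]      -> [[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]]                (context + target, right-padded with 0)
# out["labels"]         -> [[-100, -100, -100, 4, 5], [-100, 7, 8, 9, -100]] (context and padding masked out)
# out["attention_mask"] -> [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]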
+ def prepare_for_dataset(batch):
+     output = {'input_ids': [], 'labels': []}
+     for item in batch:
+         output['input_ids'].append(item['input_ids'].cpu().tolist())
+         output['labels'].append(item['labels'].cpu().tolist())
+     return output
+
+ def load_model():
+     print(f"Loading base model architecture from: {hf_model_repo_id}")
+
+     # Get information about GPU with most free memory
+     gpu_id = 0  # Default to first GPU
+     max_free_memory = 0
+
+     for i in range(torch.cuda.device_count()):
+         free_memory = torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i)
+         if free_memory > max_free_memory:
+             max_free_memory = free_memory
+             gpu_id = i
+
+     print(f"Loading model on GPU {gpu_id} with {max_free_memory / 1e9:.2f}GB free memory")
+
+     # Configure quantization
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+
+     # Load the model
+     try:
+         # First update transformers to make sure we have latest version
+         subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"])
+
+         # Now try loading with explicit config class to avoid auto-detection issues
+         from transformers import LlamaConfig
+
+         # Load config first
+         config = LlamaConfig.from_pretrained(
+             hf_model_repo_id,
+             trust_remote_code=True
+         )
+
+         # Then load model with explicit config
+         model = AutoModelForCausalLM.from_pretrained(
+             hf_model_repo_id,
+             config=config,
+             quantization_config=bnb_config,
+             device_map="auto",
+             trust_remote_code=True
+         )
+         print(f"Loaded model vocab size: {model.config.vocab_size}")
+         print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
+     except Exception as e:
+         print(f"Error loading model from Hub: {e}")
+         # Try with a fallback method
+         try:
+             print("Attempting alternative loading method...")
+             # Try loading without auto detection
+             model = AutoModelForCausalLM.from_pretrained(
+                 hf_model_repo_id,
+                 quantization_config=bnb_config,
+                 device_map="auto",
+                 trust_remote_code=True,
+                 torch_dtype=torch.bfloat16,
+                 # Add these to help with the loading
+                 revision="main",
+                 low_cpu_mem_usage=True,
+             )
+             print("Alternative loading successful!")
+             print(f"Loaded model vocab size: {model.config.vocab_size}")
+         except Exception as e2:
+             print(f"Alternative loading also failed: {e2}")
+             return None, None
+
+     # Load the official Meta tokenizer for LLaMA 3; fall back if it is gated behind an auth token
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(
+             "meta-llama/Llama-3-8B",  # Use the official Meta tokenizer
+             use_auth_token=os.environ.get("HF_TOKEN", None)  # In case it's needed
+         )
+     except Exception:
+         # Fallback to another common foundation model tokenizer
+         print("Falling back to another tokenizer as Meta tokenizer requires auth token")
+         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+     print(f"Loaded tokenizer vocabulary size: {len(tokenizer)}")
+
+     # Print information about input embeddings
+     print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
+
+     # Prepare model for k-bit training
+     model = prepare_model_for_kbit_training(model)
+
+     # Define LoRA configuration
+     lora_config = LoraConfig(
+         r=16,
+         lora_alpha=32,
+         target_modules=[
+             "q_proj",
+             "k_proj",
+             "v_proj",
+             "o_proj",
+             "gate_proj",
+             "up_proj",
+             "down_proj",
+         ],
+         lora_dropout=0.05,
+         bias="none",
+         task_type=TaskType.CAUSAL_LM
+     )
+
+     # Apply LoRA to model
+     model = get_peft_model(model, lora_config)
+     model.print_trainable_parameters()
+
+     return model, tokenizer  # Return both model and tokenizer
+
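Because the RVQ-resized embedding matrix may not line up with whichever text tokenizer load_model() ends up returning, a quick post-load check (illustrative only, not in the commit) could be:

model, tokenizer = load_model()
if model is not None:
    emb_rows = model.get_input_embeddings().weight.shape[0]
    # The resized RVQ vocabulary can legitimately differ from the text tokenizer's size.
    print(f"Embedding rows: {emb_rows} | tokenizer vocab: {len(tokenizer)}")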
+ def load_dataset():
+     # --- Download the dataset repository files ---
+     try:
+         os.makedirs(local_download_path, exist_ok=True)
+         downloaded_repo_root = snapshot_download(
+             repo_id=hf_dataset_repo_id,
+             repo_type="dataset",
+             local_dir=local_download_path,
+             local_dir_use_symlinks=False
+         )
+         print(f"Dataset repository content downloaded to: {downloaded_repo_root}")
+     except Exception as e:
+         print(f"Error downloading dataset: {e}")
+         return None
+
+     # --- Load .pt files into a Hugging Face Dataset object ---
+     pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs")
+     all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt"))
+
+     if not all_pair_files:
+         all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt"))
+         if not all_pair_files:
+             print("No RVQ pair files found!")
+             return None
+
+     print(f"Found {len(all_pair_files)} RVQ pair files.")
+
+     # Load data from .pt files into memory
+     all_data_pairs = []
+     for file_path in tqdm(all_pair_files, desc="Loading pair files"):
+         try:
+             episode_pairs = torch.load(file_path, map_location='cpu')
+             all_data_pairs.extend(episode_pairs)
+         except Exception as e:
+             print(f"Warning: Could not load file {file_path}: {e}")
+
+     if not all_data_pairs:
+         return None
+
+     print(f"Loaded {len(all_data_pairs)} training pairs.")
+
+     # Convert to Hugging Face Dataset
+     chunk_size = 1000
+     processed_data = {'input_ids': [], 'labels': []}
+     for i in tqdm(range(0, len(all_data_pairs), chunk_size), desc="Preparing data"):
+         batch = all_data_pairs[i:i + chunk_size]
+         prepared_batch = prepare_for_dataset(batch)
+         processed_data['input_ids'].extend(prepared_batch['input_ids'])
+         processed_data['labels'].extend(prepared_batch['labels'])
+
+     hf_dataset = Dataset.from_dict(processed_data)
+
+     # Transform to get tensors back
+     hf_dataset.set_transform(lambda batch: {
+         'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']],
+         'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']]
+     })
+
+     # Cleanup
+     del all_data_pairs
+     del processed_data
+     gc.collect()
+
+     return hf_dataset
+
+ # Memory cleaning function
+ def clean_memory():
+     gc.collect()
+     if torch.cuda.is_available():
+         for i in range(torch.cuda.device_count()):
+             with torch.cuda.device(f'cuda:{i}'):
+                 torch.cuda.empty_cache()
+                 torch.cuda.reset_peak_memory_stats()
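The remainder of train_model() is outside the hunks shown here, so the actual Trainer wiring is not visible in this diff. Purely as an illustration (not the commit's code; the argument values are taken from the constants and helpers above, and the real settings may differ), the pieces could fit together roughly like this:

def run_training_sketch():
    model, tokenizer = load_model()
    train_dataset = load_dataset()

    training_args = TrainingArguments(
        output_dir=OUTPUT_TRAINING_DIR,
        logging_dir=LOGGING_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
        gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type=LR_SCHEDULER,
        optim=OPTIMIZER,
        bf16=True,
        logging_steps=10,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=seq2seq_causal_collator,  # custom collator defined above
    )
    trainer.train()
    clean_memory()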
 
 
  def train_model(
      hf_username,
      model_repo_name,

      # --- Load Base Model (with quantization) ---
      progress(0.1, desc="Loading base model...")
      try:
+         # First update transformers to make sure we have latest version
+         subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "transformers"])
+
+         # Now try loading with explicit config class to avoid auto-detection issues
+         from transformers import LlamaConfig
+
+         # Load config first
+         config = LlamaConfig.from_pretrained(
+             hf_model_repo_id,
+             trust_remote_code=True
+         )
+
+         # Then load model with explicit config
          model = AutoModelForCausalLM.from_pretrained(
              hf_model_repo_id,
+             config=config,
              quantization_config=bnb_config,
              device_map="auto",
              trust_remote_code=True

      except Exception as e:
          error_msg = f"Error loading model from Hub: {e}"
          log.append(error_msg)
+         # Try with a fallback method
+         try:
+             log.append("Attempting alternative loading method...")
+             # Try loading without auto detection
+             model = AutoModelForCausalLM.from_pretrained(
+                 hf_model_repo_id,
+                 quantization_config=bnb_config,
+                 device_map="auto",
+                 trust_remote_code=True,
+                 torch_dtype=torch.bfloat16,
+                 # Add these to help with the loading
+                 revision="main",
+                 low_cpu_mem_usage=True,
+             )
+             log.append("Alternative loading successful!")
+             log.append(f"Loaded model vocab size: {model.config.vocab_size}")
+         except Exception as e2:
+             log.append(f"Alternative loading also failed: {e2}")
+             return "\n".join(log)
+
+     # Load the official Meta tokenizer for LLaMA 3; fall back if it is gated behind an auth token
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(
+             "meta-llama/Llama-3-8B",  # Use the official Meta tokenizer
+             use_auth_token=os.environ.get("HF_TOKEN", None)  # In case it's needed
+         )
+     except Exception:
+         # Fallback to another common foundation model tokenizer
+         print("Falling back to another tokenizer as Meta tokenizer requires auth token")
+         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+     print(f"Loaded tokenizer vocabulary size: {len(tokenizer)}")
+
+     # Print information about input embeddings
+     print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")

+     # Prepare model for k-bit training
      model = prepare_model_for_kbit_training(model)

+     # Define LoRA configuration
      lora_config = LoraConfig(
          task_type=TaskType.CAUSAL_LM,
          r=16,

  # Create and launch the interface
  demo = create_interface()
  if __name__ == "__main__":
+     demo.launch()