Update app.py
app.py
CHANGED
@@ -27,7 +27,7 @@ import shutil
 
 # --- Configuration ---
 YOUR_HF_USERNAME = "Twelve2five"
-MODEL_REPO_NAME = "llama-3
+MODEL_REPO_NAME = "llama-3-8b-rvq-resized"
 DATASET_REPO_NAME = "podcast-dialogue-rvq-pairs-3items"
 
 hf_model_repo_id = f"{YOUR_HF_USERNAME}/{MODEL_REPO_NAME}"
@@ -329,9 +329,9 @@ def train_model(
     model_repo_name,
     dataset_repo_name,
     epochs=1,
-    batch_size=
+    batch_size=4, # Increased for A100
     grad_accum_steps=4,
-    learning_rate=2e-4,
+    learning_rate=2e-4,
     progress=gr.Progress()
 ):
     progress(0, desc="Setting up environment...")
@@ -350,7 +350,7 @@ def train_model(
     except Exception as e:
         log.append(f"Warning: Could not remove existing dataset files: {e}")
 
-    # Print GPU info
+    # Print GPU info - using imported torch, not a local variable
     if torch.cuda.is_available():
         log.append(f"Available GPUs: {torch.cuda.device_count()}")
         for i in range(torch.cuda.device_count()):
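The per-GPU logging body falls outside this hunk; a minimal sketch of what such a loop typically logs (the exact line is an assumption, not shown in the diff):

    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            # Assumed logging line; the real loop body lies beyond the hunk boundary
            log.append(f"GPU {i}: {torch.cuda.get_device_name(i)}")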
@@ -362,7 +362,7 @@ def train_model(
     try:
         from datasets import Dataset
         from huggingface_hub import snapshot_download
-        import torch
+        # Don't import torch again, since it's already imported
        import transformers
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from transformers import BitsAndBytesConfig, TrainingArguments, Trainer
@@ -390,25 +390,13 @@ def train_model(
     # Create DeepSpeed config file
     progress(0.1, desc="Setting up DeepSpeed configuration...")
 
-    # Create a
+    # Create a simpler config since we have plenty of memory on A100
     ds_config = {
-        "fp16": {
-            "enabled": "auto",
-            "loss_scale": 0,
-            "loss_scale_window": 1000,
-            "initial_scale_power": 16,
-            "hysteresis": 2,
-            "min_loss_scale": 1
-        },
         "bf16": {
             "enabled": "auto"
         },
         "zero_optimization": {
-            "stage":
-            "offload_optimizer": {
-                "device": "cpu",
-                "pin_memory": True
-            },
+            "stage": 1, # Lower stage is fine for A100-80GB
             "contiguous_gradients": True,
             "overlap_comm": True
         },
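The hunk above only edits the in-memory ds_config dict; before TrainingArguments can reference it, the script still has to serialize it to the JSON file behind ds_config_path (used later via deepspeed=ds_config_path). A minimal sketch, assuming a local filename since the actual path is not shown in the hunk:

import json

ds_config_path = "./ds_config.json"  # assumed filename; the script defines the real path
with open(ds_config_path, "w") as f:
    json.dump(ds_config, f, indent=2)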
@@ -432,113 +420,136 @@ def train_model(
         snapshot_download(
             repo_id=hf_model_repo_id,
             local_dir=local_model_path,
+            use_auth_token=False,
+            resume_download=True
         )
         log.append(f"Model files downloaded to {local_model_path}")
 
-        #
-
-
-        config_data = json.load(f)
-
-        # Check model architecture
-        model_type = config_data.get("model_type", "").lower()
-        log.append(f"Model architecture type: {model_type}")
-
-        # Set 4-bit quantization config
+        # Create a bnb configuration for loading the model in 4-bit
+        # Not strictly necessary for A100 but keeps memory usage lower
+        progress(0.25, desc="Loading model...")
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
-            bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=False
         )
 
-        # Load
+        # Load model and tokenizer
         model = AutoModelForCausalLM.from_pretrained(
             local_model_path,
             quantization_config=bnb_config,
             device_map="auto",
-
-            use_cache=False # No KV caching during training
+            torch_dtype=torch.bfloat16,
         )
-
-        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(local_model_path)
+
+        # Handle tokenizer settings
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-
-
-
-
-
-        embedding = model.get_input_embeddings()
-        if hasattr(embedding, "weight"):
-            log.append(f"Input embedding shape: {embedding.weight.shape}")
-
-    except Exception as e:
-        error_msg = f"Error loading model: {str(e)}"
-        log.append(error_msg)
-        return "\n".join(log)
-
-    # --- Prepare for K-bit Training & Apply LoRA ---
-    progress(0.15, desc="Preparing model for fine-tuning...")
-    try:
+
+        log.append(f"Loaded model vocab size: {tokenizer.vocab_size}")
+        log.append(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
+
+        # PEFT Configuration (Smaller LoRA for faster iteration)
         model = prepare_model_for_kbit_training(model)
         log.append("Model prepared for k-bit training")
 
-        # For Llama 3.2 1B the target modules might be slightly different
         lora_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
-            r=
-            lora_alpha=
+            r=16, # Keeping higher rank for A100
+            lora_alpha=32,
             lora_dropout=0.05,
             bias="none",
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"] # Fewer modules for faster training
         )
         peft_model = get_peft_model(model, lora_config)
         trainable_params = peft_model.print_trainable_parameters()
         log.append(f"LoRA applied to model")
         model_to_train = peft_model
+
     except Exception as e:
         error_msg = f"Error preparing model for training: {str(e)}"
         log.append(error_msg)
         return "\n".join(log)
 
-    # ---
-    progress(0.
+    # --- Download and Process Dataset ---
+    progress(0.4, desc="Downloading dataset...")
+
     try:
-        # Download dataset
         dataset_path = "./downloaded_dataset_files"
         snapshot_download(
             repo_id=hf_dataset_repo_id,
             local_dir=dataset_path,
+            use_auth_token=False,
+            resume_download=True
         )
         log.append(f"Dataset repository content downloaded to: {dataset_path}")
 
-        #
-
-        log.append(f"Found {len(rvq_pair_files)} RVQ pair files.")
+        # Load dataset from PT files
+        progress(0.5, desc="Processing dataset...")
 
-        # Load
-
+        # Load RVQ pairs
+        pair_files = glob.glob(f"{dataset_path}/*_rvq_pairs.pt")
+        log.append(f"Found {len(pair_files)} RVQ pair files.")
 
-
-
+        all_pairs = []
+        for file in pair_files:
+            pairs = torch.load(file)
             all_pairs.extend(pairs)
 
-        # You may want to limit the data size for quicker testing
-        # all_pairs = all_pairs[:1000] # Uncomment to limit data size
-
         log.append(f"Loaded a total of {len(all_pairs)} training pairs into memory.")
 
-        #
-
-
-
-
+        # Process pairs into a format suitable for training
+        all_texts = []
+        for pair in all_pairs:
+            # Create instruction format
+            if isinstance(pair, dict):
+                instruction = pair.get("instruction", "")
+                input_text = pair.get("input", "")
+                output = pair.get("output", "")
+
+                # ALPACA format
+                if instruction and input_text:
+                    text = f"### Instruction: {instruction}\n### Input: {input_text}\n### Response: {output}"
+                elif instruction:
+                    text = f"### Instruction: {instruction}\n### Response: {output}"
+                else:
+                    text = output
+            else:
+                # Simple prompt-completion format
+                if isinstance(pair, tuple) and len(pair) == 2:
+                    prompt, completion = pair
+                    text = f"{prompt}{completion}"
+                else:
+                    text = str(pair)
+
+            all_texts.append({"text": text})
+
+        # Create HF dataset
+        train_dataset = Dataset.from_list(all_texts)
+
+        # Function to tokenize the dataset
+        def tokenize_function(examples):
+            return tokenizer(
+                examples["text"],
+                padding=False,
+                truncation=True,
+                max_length=2048,
+                return_tensors=None,
+            )
+
+        # Tokenize the dataset
+        tokenized_dataset = train_dataset.map(
+            tokenize_function,
+            batched=True,
+            remove_columns=["text"],
+            desc="Tokenizing dataset",
+        )
 
-        train_dataset =
+        train_dataset = tokenized_dataset
 
-        #
+        # Data collator
         from transformers import DataCollatorForLanguageModeling
 
         data_collator = DataCollatorForLanguageModeling(
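The hunk ends before the DataCollatorForLanguageModeling arguments are visible. For causal-LM fine-tuning the collator is normally built with mlm=False so the labels mirror the input ids; a sketch under that assumption:

from transformers import DataCollatorForLanguageModeling

# Assumed arguments; the actual call is cut off at the hunk boundary above
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # causal LM objective, no masking
)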
@@ -556,6 +567,7 @@ def train_model(
     output_dir = f"./results_{model_repo_name}"
     os.makedirs(output_dir, exist_ok=True)
 
+    # Optimize settings for A100
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=float(epochs),
@@ -574,10 +586,10 @@ def train_model(
         lr_scheduler_type="cosine",
         report_to="tensorboard",
         bf16=True if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else False,
-        gradient_checkpointing=True,
+        gradient_checkpointing=True, # Still useful for efficiency
         gradient_checkpointing_kwargs={'use_reentrant': False},
         ddp_find_unused_parameters=False,
-        deepspeed=ds_config_path if n_gpus > 1 else None,
+        deepspeed=ds_config_path if n_gpus > 1 else None, # Only use DeepSpeed for multi-GPU
     )
 
     # --- Initialize Trainer ---
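With the signature defaults changed earlier (batch_size=4, grad_accum_steps=4), the effective batch per optimizer step works out as below; n_gpus=1 reflects the single-A100 assumption in the comments.

per_device_batch_size = 4
grad_accum_steps = 4
n_gpus = 1  # single A100 assumed; DeepSpeed only kicks in for n_gpus > 1
effective_batch_size = per_device_batch_size * grad_accum_steps * n_gpus  # 16 sequences per step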
@@ -639,9 +651,9 @@ def create_interface():
            dataset_repo = gr.Textbox(label="Dataset Repository Name", value="podcast-dialogue-rvq-pairs-3items")
 
        with gr.Column():
-            epochs = gr.Number(label="Number of Epochs", value=
-            batch_size = gr.Number(label="Batch Size per Device", value=
-            grad_accum = gr.Number(label="Gradient Accumulation Steps", value=
+            epochs = gr.Number(label="Number of Epochs", value=3, minimum=1, maximum=10)
+            batch_size = gr.Number(label="Batch Size per Device", value=4, minimum=1, maximum=16)
+            grad_accum = gr.Number(label="Gradient Accumulation Steps", value=2, minimum=1, maximum=16)
            lr = gr.Number(label="Learning Rate", value=2e-4)
 
        start_btn = gr.Button("Start Training")
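The diff stops before the button is wired to train_model; a typical hookup for the components above would look like the sketch below (the model_repo textbox and the output component are assumptions, as they sit outside the changed lines):

# Hypothetical wiring; the real click handler is not part of this diff
output_log = gr.Textbox(label="Training Log", lines=20)
start_btn.click(
    fn=train_model,
    inputs=[model_repo, dataset_repo, epochs, batch_size, grad_accum, lr],
    outputs=output_log,
)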