Update app.py
app.py CHANGED
@@ -330,49 +330,49 @@ def train_model(
     dataset_repo_name,
     epochs=1,
     batch_size=1,
-    grad_accum_steps=8,
+    grad_accum_steps=16,  # Increased from 8 to 16
     learning_rate=1e-4,
     progress=gr.Progress()
 ):
     progress(0, desc="Setting up environment...")
     log = []

-    # Clean any cached files that might be causing issues
-    cache_dirs = [
-        os.path.expanduser("~/.cache/huggingface"),
-        os.path.expanduser("~/.cache/pip")
-    ]
-
-    for cache_dir in cache_dirs:
-        if os.path.exists(cache_dir):
-            log.append(f"Cleaning cache directory: {cache_dir}")
-            try:
-                shutil.rmtree(cache_dir)
-            except Exception as e:
-                log.append(f"Warning: Could not clean {cache_dir}: {e}")
+    # Aggressive memory cleanup
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        # Reset peak memory stats
+        torch.cuda.reset_peak_memory_stats()
+
+    # Clean up any existing model files to save space
+    if os.path.exists("./model_files"):
+        try:
+            shutil.rmtree("./model_files")
+        except Exception as e:
+            log.append(f"Warning: Could not remove existing model files: {e}")
+
+    if os.path.exists("./downloaded_dataset_files"):
+        try:
+            shutil.rmtree("./downloaded_dataset_files")
+        except Exception as e:
+            log.append(f"Warning: Could not remove existing dataset files: {e}")
+
+    # Print GPU info
+    if torch.cuda.is_available():
+        log.append(f"Available GPUs: {torch.cuda.device_count()}")
+        for i in range(torch.cuda.device_count()):
+            gpu_name = torch.cuda.get_device_name(i)
+            gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3)
+            log.append(f"GPU {i}: {gpu_name} with {gpu_memory:.2f} GB")

+    # Import required libraries
     try:
         from datasets import Dataset
         from huggingface_hub import snapshot_download
         import torch
         import transformers
         from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
-        from transformers import BitsAndBytesConfig, TrainingArguments, Trainer
+        from transformers import BitsAndBytesConfig, TrainingArguments, Trainer, AutoTokenizer
         from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

         log.append(f"Transformers version: {transformers.__version__}")
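The cleanup-and-report pattern this hunk introduces stands on its own and is easy to test outside the Space. A minimal sketch, assuming only PyTorch is installed; the helper names are illustrative, not part of app.py:

    import gc
    import torch

    def free_gpu_memory():
        # Release unreachable Python objects, then return cached CUDA blocks to the driver
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Zero the peak-memory counters so later readings reflect this run only
            torch.cuda.reset_peak_memory_stats()

    def describe_gpus():
        # Same inventory the new code logs at startup
        return [
            f"GPU {i}: {torch.cuda.get_device_name(i)} with "
            f"{torch.cuda.get_device_properties(i).total_memory / (1024**3):.2f} GB"
            for i in range(torch.cuda.device_count())
        ]

Resetting the peak counters up front is what makes a torch.cuda.max_memory_allocated() reading at the end of the run meaningful.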
@@ -393,59 +393,60 @@ def train_model(
         n_gpus = torch.cuda.device_count()
         log.append(f"Number of GPUs available: {n_gpus}")

-        # ---
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16,
-            bnb_4bit_use_double_quant=True,
-        )
-
-        # --- Load Base Model (with quantization) ---
+        # --- Load Base Model (with extreme quantization) ---
         progress(0.1, desc="Loading base model...")
+        local_model_path = "./model_files"
         try:
-            local_model_path = "./model_files"
-            if os.path.exists(local_model_path):
-                shutil.rmtree(local_model_path)  # Clean up any previous files
-
+            # Download the model files
             snapshot_download(
                 repo_id=hf_model_repo_id,
                 local_dir=local_model_path,
                 local_dir_use_symlinks=False
             )
             log.append(f"Model files downloaded to {local_model_path}")

-            with open(os.path.join(local_model_path, "config.json"), "r") as f:
-                config_data = json.load(f)
-            log.append(f"Model architecture type: {config_data.get('model_type', 'unknown')}")
-
-            # Force model_type to llama
-            config_data["model_type"] = "llama"
-            if "architectures" in config_data:
-                config_data["architectures"] = ["LlamaForCausalLM"]
-
-            with open(os.path.join(local_model_path, "config.json"), "w") as f:
-                json.dump(config_data, f)
-            log.append("Updated config.json to use llama model_type")
-
-            config = LlamaConfig.from_pretrained(
-                local_model_path,
-                trust_remote_code=False
-            )
+            # Ensure model_type is set correctly in the config
+            config_path = os.path.join(local_model_path, "config.json")
+            with open(config_path, "r") as f:
+                config_data = json.load(f)
+
+            model_type = config_data.get("model_type", "")
+            log.append(f"Model architecture type: {model_type}")
+
+            # Force model_type to be "llama" if it's not already
+            if model_type != "llama":
+                config_data["model_type"] = "llama"
+                # Also ensure architectures is set correctly
+                config_data["architectures"] = ["LlamaForCausalLM"]
+                with open(config_path, "w") as f:
+                    json.dump(config_data, f, indent=2)
+                log.append("Updated config.json to use llama model_type")
+
+            # Load the config first
+            config = LlamaConfig.from_pretrained(local_model_path)
             log.append(f"Successfully loaded config: {config.model_type}")

+            # Use 4-bit quantization for extreme memory savings
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16
+            )
+
+            # Load tokenizer first (needed for dataset preparation)
+            tokenizer = AutoTokenizer.from_pretrained(local_model_path)
+
+            # Explicit device map to enable CPU offloading
+            max_memory = {0: "40GB", "cpu": "64GB"}
+
+            # Load the model with extreme memory optimization
             model = LlamaForCausalLM.from_pretrained(
                 local_model_path,
                 config=config,
                 quantization_config=bnb_config,
                 device_map="auto",
+                max_memory=max_memory,
                 torch_dtype=torch.bfloat16,
                 low_cpu_mem_usage=True
             )
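The load path above can be exercised in isolation. A minimal sketch of the same nf4 load with CPU offload, assuming a local checkpoint directory; the path is a placeholder, and the pad-token fallback is a precaution added here (Llama tokenizers often ship without one), not something this commit does:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_path = "./model_files"  # placeholder

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # store weights as 4-bit NF4
        bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,  # dequantize to bf16 for matmuls
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        # Without this, tokenizer.pad_token_id is None and padding code downstream fails
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",                      # let accelerate place layers
        max_memory={0: "40GB", "cpu": "64GB"},  # cap GPU 0; overflow goes to CPU RAM
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )

The max_memory dict is what turns device_map="auto" into genuine CPU offloading: layers that do not fit under the 40 GB cap stay in CPU RAM and are streamed in as needed.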
@@ -455,20 +456,7 @@ def train_model(
         except Exception as e:
             error_msg = f"Error loading model: {str(e)}"
             log.append(error_msg)
-
-            # Try a fallback approach
-            try:
-                log.append("Trying fallback approach with AutoModelForCausalLM...")
-                model = AutoModelForCausalLM.from_pretrained(
-                    local_model_path,
-                    device_map="auto",
-                    torch_dtype=torch.bfloat16,
-                    low_cpu_mem_usage=True
-                )
-                log.append(f"Fallback model loaded successfully")
-            except Exception as e2:
-                log.append(f"Fallback approach also failed: {str(e2)}")
-                return "\n".join(log)
+            return "\n".join(log)

         # --- Prepare for K-bit Training & Apply LoRA ---
         progress(0.15, desc="Preparing model for fine-tuning...")
@@ -476,218 +464,138 @@ def train_model(
             model = prepare_model_for_kbit_training(model)
             log.append("Model prepared for k-bit training")

+            # Use minimal LoRA configuration with fewer parameters
             lora_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
-                r=16,
-                lora_alpha=32,
+                r=8,  # Reduced from 16 to 8
+                lora_alpha=16,  # Reduced from 32 to 16
                 lora_dropout=0.05,
                 bias="none",
+                # Target only key modules to reduce memory usage
+                target_modules=["q_proj", "v_proj"]  # Reduced target modules
             )
+
+            # Apply LoRA
             peft_model = get_peft_model(model, lora_config)
-            trainable_params = peft_model.print_trainable_parameters()
-            log.append(f"LoRA applied to model")
             model_to_train = peft_model
+            log.append("LoRA applied to model")
+
+            # Free memory
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
         except Exception as e:
             error_msg = f"Error preparing model for training: {str(e)}"
             log.append(error_msg)
             return "\n".join(log)

-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        # --- Load Dataset from Hub ---
-        progress(0.2, desc="Downloading dataset...")
-        local_download_path = "./downloaded_dataset_files"
-
+        # --- Download and Process Dataset ---
+        progress(0.2, desc="Loading dataset...")
         try:
-            downloaded_repo_root = snapshot_download(
+            # Download the dataset files
+            dataset_dir = os.path.join(os.getcwd(), "downloaded_dataset_files")
+            snapshot_download(
                 repo_id=hf_dataset_repo_id,
-                local_dir=local_download_path,
+                local_dir=dataset_dir,
                 local_dir_use_symlinks=False
             )
-            log.append(f"Dataset repository content downloaded to: {downloaded_repo_root}")
+            log.append(f"Dataset repository content downloaded to: {dataset_dir}")
+
+            # Find all RVQ pair files
+            rvq_pair_files = glob.glob(os.path.join(dataset_dir, "*_rvq_pairs.pt"))
+            log.append(f"Found {len(rvq_pair_files)} RVQ pair files.")
+
+            # Load training pairs from the dataset
+            training_pairs = []
+
+            # For memory conservation, use only half the dataset for now
+            max_file_count = min(12, len(rvq_pair_files))
+
+            for i, pair_file in enumerate(rvq_pair_files[:max_file_count]):
+                try:
+                    pairs = torch.load(pair_file)
+                    training_pairs.extend(pairs)
+                except Exception as e:
+                    log.append(f"Warning: Could not load {pair_file}: {e}")
+
+            log.append(f"Loaded a total of {len(training_pairs)} training pairs into memory.")
+
+            # Prepare dataset
+            dataset = Dataset.from_dict({
+                "input_ids": [pair[0].tolist() for pair in training_pairs],
+                "labels": [pair[1].tolist() for pair in training_pairs]
+            })
+
+            # Clear the training_pairs to free memory
+            training_pairs = None
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            # Use a smaller max_length to reduce memory pressure
+            max_length = 512  # Reduced max sequence length
+
+            # Create data collator that handles padding
+            def data_collator(examples):
+                # Convert lists back to tensors
+                for i in range(len(examples)):
+                    examples[i]["input_ids"] = torch.tensor(examples[i]["input_ids"], dtype=torch.long)
+                    examples[i]["labels"] = torch.tensor(examples[i]["labels"], dtype=torch.long)
+
+                # Get max length in this batch
+                batch_max_length = min(
+                    max(len(example["input_ids"]) for example in examples),
+                    max_length
+                )
+
+                batch = {
+                    "input_ids": [],
+                    "attention_mask": [],
+                    "labels": []
+                }
+
+                # Prepare sequences
+                for example in examples:
+                    input_ids = example["input_ids"][:batch_max_length]
+                    labels = example["labels"][:batch_max_length]
+
+                    # Pad sequences
+                    padding_length = batch_max_length - len(input_ids)
+                    attention_mask = torch.ones_like(input_ids)
+
+                    if padding_length > 0:
+                        padding = torch.ones(padding_length, dtype=input_ids.dtype) * tokenizer.pad_token_id
+                        input_ids = torch.cat([input_ids, padding])
+                        # Pad labels with the literal -100 so padded positions are ignored in the loss
+                        labels = torch.cat([labels, torch.full((padding_length,), -100, dtype=labels.dtype)])
+                        attention_mask = torch.cat([attention_mask, torch.zeros(padding_length, dtype=attention_mask.dtype)])
+
+                    batch["input_ids"].append(input_ids)
+                    batch["attention_mask"].append(attention_mask)
+                    batch["labels"].append(labels)
+
+                # Convert lists to tensors
+                for key in batch:
+                    batch[key] = torch.stack(batch[key])
+
+                return batch
+
+            # Convert to training dataset
+            train_dataset = dataset
+
+            # Free memory
+            del dataset
+            gc.collect()
+            torch.cuda.empty_cache()
         except Exception as e:
-            error_msg = f"Error
+            error_msg = f"Error loading dataset: {str(e)}"
             log.append(error_msg)
             return "\n".join(log)

-        # --- Find and load the .pt files ---
-        progress(0.25, desc="Finding dataset files...")
-        pairs_dir = os.path.join(downloaded_repo_root, "final_rvq_pairs")
-        all_pair_files = glob.glob(os.path.join(pairs_dir, "*_rvq_pairs.pt"))
-
-        if not all_pair_files:
-            all_pair_files = glob.glob(os.path.join(downloaded_repo_root, "*_rvq_pairs.pt"))
-            if not all_pair_files:
-                error_msg = "No RVQ pair files found in expected directories"
-                log.append(error_msg)
-                return "\n".join(log)
-
-        log.append(f"Found {len(all_pair_files)} RVQ pair files.")
-
-        # --- Load data from .pt files ---
-        progress(0.3, desc="Loading dataset files...")
-        all_data_pairs = []
-        for i, file_path in enumerate(all_pair_files):
-            progress(0.3 + (0.1 * i / len(all_pair_files)), desc=f"Loading file {i+1}/{len(all_pair_files)}")
-            try:
-                episode_pairs = torch.load(file_path, map_location='cpu')
-                all_data_pairs.extend(episode_pairs)
-            except Exception as e:
-                log.append(f"Warning: Could not load file {file_path}: {e}")
-
-        if not all_data_pairs:
-            error_msg = "No valid data pairs were loaded"
-            log.append(error_msg)
-            return "\n".join(log)
-
-        log.append(f"Loaded a total of {len(all_data_pairs)} training pairs into memory.")
-
-        # --- Convert to HF Dataset ---
-        progress(0.45, desc="Converting to Hugging Face Dataset...")
-        def prepare_for_dataset(batch):
-            output = {'input_ids': [], 'labels': []}
-            for item in batch:
-                output['input_ids'].append(item['input_ids'].cpu().tolist())
-                output['labels'].append(item['labels'].cpu().tolist())
-            return output
-
-        chunk_size = 1000
-        processed_data = {'input_ids': [], 'labels': []}
-
-        total_chunks = len(range(0, len(all_data_pairs), chunk_size))
-        for i in range(0, len(all_data_pairs), chunk_size):
-            chunk_idx = i // chunk_size
-            progress(0.45 + (0.1 * chunk_idx / total_chunks),
-                     desc=f"Processing chunk {chunk_idx+1}/{total_chunks}")
-            batch = all_data_pairs[i:i + chunk_size]
-            prepared_batch = prepare_for_dataset(batch)
-            processed_data['input_ids'].extend(prepared_batch['input_ids'])
-            processed_data['labels'].extend(prepared_batch['labels'])
-
-        hf_dataset = Dataset.from_dict(processed_data)
-
-        # Transform to get tensors back
-        hf_dataset.set_transform(lambda batch: {
-            'input_ids': [torch.tensor(ids, dtype=torch.long) for ids in batch['input_ids']],
-            'labels': [torch.tensor(lbls, dtype=torch.long) for lbls in batch['labels']]
-        })
-
-        train_dataset = hf_dataset
-
-        # Cleanup
-        del all_data_pairs
-        del processed_data
-        gc.collect()
-
-        # --- Define Data Collator ---
-        progress(0.55, desc="Defining data collator...")
-        def seq2seq_causal_collator(features):
-            batch = {}
-            concatenated_input_ids = []
-            concatenated_labels = []
-            max_len = 0
-
-            # First pass: Concatenate, create masked labels, find max length
-            for feature in features:
-                input_ids = feature['input_ids']
-                labels = feature['labels']
-
-                if input_ids.dim() > 1: input_ids = input_ids.squeeze()
-                if labels.dim() > 1: labels = labels.squeeze()
-
-                context_len = input_ids.shape[0]
-                target_len = labels.shape[0]
-
-                combined_ids = torch.cat([input_ids, labels], dim=0)
-                concatenated_input_ids.append(combined_ids)
-
-                masked_labels = torch.cat([
-                    torch.full((context_len,), -100, dtype=torch.long, device=input_ids.device),
-                    labels
-                ], dim=0)
-                concatenated_labels.append(masked_labels)
-
-                if combined_ids.shape[0] > max_len:
-                    max_len = combined_ids.shape[0]
-
-            # Second pass: Pad to max length
-            padded_input_ids = []
-            padded_labels = []
-            input_pad_token_id = 0
-            label_pad_token_id = -100
-
-            for i in range(len(features)):
-                ids = concatenated_input_ids[i]
-                lbls = concatenated_labels[i]
-
-                padding_len = max_len - ids.shape[0]
-
-                padded_input_ids.append(torch.nn.functional.pad(
-                    ids, (0, padding_len), value=input_pad_token_id
-                ))
-                padded_labels.append(torch.nn.functional.pad(
-                    lbls, (0, padding_len), value=label_pad_token_id
-                ))
-
-            # Stack and create final batch
-            batch['input_ids'] = torch.stack(padded_input_ids)
-            batch['labels'] = torch.stack(padded_labels)
-            batch['attention_mask'] = batch['input_ids'].ne(input_pad_token_id).long()
-
-            return batch
-
-        data_collator = seq2seq_causal_collator
-
-        # --- DeepSpeed Configuration ---
-        # Create DeepSpeed config file directly in Python instead of loading from a file
-        progress(0.15, desc="Setting up DeepSpeed configuration...")
-
-        ds_config = {
-            "fp16": {
-                "enabled": False
-            },
-            "bf16": {
-                "enabled": True
-            },
-            "zero_optimization": {
-                "stage": 3,
-                "offload_optimizer": {
-                    "device": "cpu",
-                    "pin_memory": True
-                },
-                "offload_param": {
-                    "device": "cpu",
-                    "pin_memory": True
-                },
-                "overlap_comm": True,
-                "contiguous_gradients": True,
-                "reduce_bucket_size": "auto",
-                "stage3_prefetch_bucket_size": "auto",
-                "stage3_param_persistence_threshold": "auto"
-            },
-            "gradient_accumulation_steps": grad_accum_steps,
-            "train_micro_batch_size_per_gpu": batch_size,
-            "gradient_clipping": 1.0,
-            "steps_per_print": 10
-        }
-
-        # Save the config to a file
-        with open("ds_config.json", "w") as f:
-            json.dump(ds_config, f, indent=4)
-
-        log.append("DeepSpeed configuration created successfully")
-
         # --- Training Arguments ---
-        progress(0.
+        progress(0.3, desc="Setting up training arguments...")
         output_dir = f"./results_{model_repo_name}"
         os.makedirs(output_dir, exist_ok=True)

+        # Super-aggressive memory conservation
         training_args = TrainingArguments(
             output_dir=output_dir,
             num_train_epochs=float(epochs),
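The collator added in this hunk truncates every sequence to the 512-token cap, pads input_ids with the tokenizer's pad id, pads labels with -100 so the cross-entropy loss skips those positions, and zeroes the attention mask over the padding. A self-contained toy run of the same logic; the two sequences and pad id 0 are invented for illustration:

    import torch

    PAD_ID = 0        # stand-in for tokenizer.pad_token_id
    MAX_LENGTH = 512  # same cap as the collator above

    def collate(examples):
        max_len = min(max(len(e["input_ids"]) for e in examples), MAX_LENGTH)
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        for e in examples:
            ids = torch.tensor(e["input_ids"][:max_len], dtype=torch.long)
            lbl = torch.tensor(e["labels"][:max_len], dtype=torch.long)
            pad = max_len - len(ids)
            batch["input_ids"].append(torch.cat([ids, torch.full((pad,), PAD_ID)]))
            batch["labels"].append(torch.cat([lbl, torch.full((pad,), -100)]))  # ignored by the loss
            batch["attention_mask"].append(
                torch.cat([torch.ones(len(ids), dtype=torch.long),
                           torch.zeros(pad, dtype=torch.long)]))
        return {k: torch.stack(v) for k, v in batch.items()}

    out = collate([{"input_ids": [5, 6, 7], "labels": [8, 9, 10]},
                   {"input_ids": [5], "labels": [8]}])
    print(out["input_ids"].tolist())       # [[5, 6, 7], [5, 0, 0]]
    print(out["labels"].tolist())          # [[8, 9, 10], [8, -100, -100]]
    print(out["attention_mask"].tolist())  # [[1, 1, 1], [1, 0, 0]]

Note the contrast with the deleted seq2seq_causal_collator, which concatenated context and target into one sequence and masked the context span with -100; the new collator assumes input_ids and labels are already aligned token for token.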
@@ -696,71 +604,90 @@ def train_model(
             learning_rate=learning_rate,
             weight_decay=0.01,
             logging_dir=f"{output_dir}/logs",
-            logging_steps=
-            save_steps=
-            save_total_limit=
+            logging_steps=1,  # Log frequently to see progress
+            save_steps=25,  # Save checkpoints more frequently
+            save_total_limit=1,  # Keep only one checkpoint to save space
             remove_unused_columns=False,
             push_to_hub=False,
             disable_tqdm=False,
             warmup_ratio=0.03,
             lr_scheduler_type="cosine",
             report_to="tensorboard",
-            bf16=True
+            bf16=True,
+            fp16=False,
+
+            # Memory optimization
             gradient_checkpointing=True,
             gradient_checkpointing_kwargs={'use_reentrant': False},
+            max_grad_norm=0.3,  # Reduced from default 1.0
+            dataloader_pin_memory=False,  # Reduce memory pressure
+
+            # Optimizer settings for memory efficiency
+            optim="adamw_torch",
+            adam_beta1=0.9,
+            adam_beta2=0.999,
+            adam_epsilon=1e-8,

+            # Evaluation settings
+            do_eval=False,
+            evaluation_strategy="no",
+
+            # Set this for smaller chunks of data processing
+            dataloader_num_workers=1,
+
+            # For memory efficiency when loading datasets
+            dataloader_drop_last=True,
         )

-        else:
-            # Single GPU setup
-            trainer = Trainer(
-                model=model_to_train,
-                args=training_args,
-                train_dataset=train_dataset,
-                data_collator=data_collator,
-            )
-            log.append("Trainer initialized for single GPU training")
-
-        # Clear cache before starting
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        # --- Initialize Trainer ---
+        progress(0.4, desc="Initializing trainer...")
+
+        # Use optimizer that requires less memory
+        class MemoryEfficientTrainer(Trainer):
+            def create_optimizer(self):
+                # Create optimizer with reduced memory footprint
+                optimizer = super().create_optimizer()
+                # Force optimizer to use CPU offloading for states
+                for param_group in optimizer.param_groups:
+                    for param in param_group['params']:
+                        if param.requires_grad:
+                            param.data = param.data.to("cpu")
+                            if param.grad is not None:
+                                param.grad.data = param.grad.data.to("cpu")
+                return optimizer
+
+            def training_step(self, *args, **kwargs):
+                # Memory cleanup before each training step
+                gc.collect()
+                torch.cuda.empty_cache()
+                return super().training_step(*args, **kwargs)
+
+        trainer = MemoryEfficientTrainer(
+            model=model_to_train,
+            args=training_args,
+            train_dataset=train_dataset,
+            data_collator=data_collator,
+        )
+
+        log.append("Trainer initialized with memory-efficient settings")

+        # --- Start Training ---
         try:
+            # Final memory cleanup before training
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            progress(0.5, desc="Starting training...")
+            log.append("Starting training with extreme memory optimization...")
+
+            # Train in smaller chunks to manage memory better
+            total_steps = len(train_dataset) // (batch_size * grad_accum_steps)
+            log.append(f"Total training steps: {total_steps}")
+
+            # Train the model
             train_result = trainer.train()
+
             progress(0.95, desc="Saving model...")

             # Save final model (adapter weights) and training state
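Two reference points for this hunk. First, the step count it logs follows directly from the dataset size: with, say, 1000 pairs, batch_size=1, and grad_accum_steps=16, len(train_dataset) // (1 * 16) gives 62 optimizer steps per epoch. Second, while MemoryEfficientTrainer moves parameter tensors to CPU inside create_optimizer, the more usual way to shrink optimizer state in this stack is a paged 8-bit optimizer, which TrainingArguments can select directly. A sketch of that alternative, assuming bitsandbytes is installed; it is not what this commit does:

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="./results",           # placeholder
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        bf16=True,
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",         # 8-bit optimizer states, paged out under memory pressure
    )

Paged optimizers keep the AdamW moments in 8-bit and spill them to pinned CPU memory when the GPU runs short, which avoids hand-moving param.data across devices.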
@@ -776,21 +703,37 @@ def train_model(

             for key, value in metrics.items():
                 log.append(f"{key}: {value}")
+
+            # Print peak memory usage
+            if torch.cuda.is_available():
+                peak_memory = torch.cuda.max_memory_allocated() / (1024**3)
+                log.append(f"Peak GPU memory usage: {peak_memory:.2f} GB")

         except Exception as e:
-            error_msg = f"An error occurred during training: {e}"
+            error_msg = f"An error occurred during training: {str(e)}"
             log.append(error_msg)
+
+            # Try to save checkpoint even if training failed
+            try:
+                # Save whatever we have
+                log.append("Attempting to save partial checkpoint...")
+                emergency_save_path = os.path.join(training_args.output_dir, "emergency_checkpoint")
+                trainer.save_model(emergency_save_path)
+                log.append(f"Saved emergency checkpoint to {emergency_save_path}")
+            except Exception as save_error:
+                log.append(f"Could not save emergency checkpoint: {save_error}")
+
             return "\n".join(log)

         progress(1.0, desc="Training complete!")
-        log.append("
+        log.append("Training process completed successfully.")
         return "\n".join(log)

 # Define the Gradio interface
 def create_interface():
     with gr.Blocks(title="Llama 3 8B RVQ Fine-tuning") as demo:
         gr.Markdown("# Llama 3 8B RVQ LoRA Fine-tuning")
-        gr.Markdown("Fine-tune a Llama 3 8B model with RVQ token embeddings using LoRA")
+        gr.Markdown("Fine-tune a Llama 3 8B model with RVQ token embeddings using LoRA with extreme memory optimization")

         with gr.Row():
             with gr.Column():
@@ -801,7 +744,7 @@ def create_interface():
             with gr.Column():
                 epochs = gr.Number(label="Number of Epochs", value=1, minimum=1, maximum=10)
                 batch_size = gr.Number(label="Batch Size per Device", value=1, minimum=1, maximum=8)
-                grad_accum = gr.Number(label="Gradient Accumulation Steps", value=8)
+                grad_accum = gr.Number(label="Gradient Accumulation Steps", value=16, minimum=8, maximum=32)
                 lr = gr.Number(label="Learning Rate", value=1e-4)

         start_btn = gr.Button("Start Training")
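One wiring detail worth remembering when reading this interface code: gr.Number hands its value to callbacks as a float, so a click handler typically casts before calling train_model. A hedged sketch; the handler name and the leading repository arguments are stand-ins, since the click binding itself is outside this diff:

    def on_start(model_repo, dataset_repo, epochs, batch_size, grad_accum, lr):
        # Cast Gradio's float-valued Number inputs to the ints train_model expects
        return train_model(
            model_repo,
            dataset_repo,
            epochs=int(epochs),
            batch_size=int(batch_size),
            grad_accum_steps=int(grad_accum),
            learning_rate=float(lr),
        )

Without the casts, expressions like len(train_dataset) // (batch_size * grad_accum_steps) still run but produce floats, and any code expecting integer step counts can misbehave.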