# File size: 4,708 Bytes
# 45e1a77
import json
import os
from dataclasses import dataclass, field, fields
from typing import List, Optional

import torch
from transformers import Trainer, TrainingArguments
@dataclass
class AudioTrainingConfig:
    """Hyper-parameters and runtime settings for fine-tuning an audio model.

    Constructing an instance creates ``output_dir`` and ``logging_dir`` and
    detects the runtime device (see ``__post_init__``).  ``get_training_args``
    converts the configuration into a ``transformers.TrainingArguments``;
    ``save_config`` / ``load_config`` round-trip it through a JSON file.
    """

    # Model configuration
    model_name: str = "wav2vec2"
    hidden_size: int = 1024
    num_attention_heads: int = 16
    num_hidden_layers: int = 24
    # Training parameters
    output_dir: str = field(default="./results")
    num_train_epochs: int = 5
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    gradient_accumulation_steps: int = 4
    learning_rate: float = 3e-5
    warmup_ratio: float = 0.1
    # Optimization
    fp16: bool = True
    bf16: bool = False
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    # Logging & Evaluation
    logging_dir: str = field(default="./logs")
    logging_steps: int = 100
    eval_steps: int = 500
    save_steps: int = 500
    save_strategy: str = "epoch"
    evaluation_strategy: str = "epoch"
    # Performance
    dataloader_num_workers: int = 4
    group_by_length: bool = True
    remove_unused_columns: bool = True
    label_smoothing_factor: float = 0.1
    # Advanced features
    use_mps_device: bool = field(
        default=False,
        metadata={"help": "Whether to use Apple M1/M2 GPU acceleration"},
    )
    mixed_precision: str = field(
        default="fp16",
        metadata={"help": "Mixed precision mode: 'no', 'fp16', 'bf16'"},
    )
    # Model selection.  These are appended after all pre-existing fields so
    # positional construction keeps working.  Per the TrainingArguments docs,
    # metric_for_best_model must name a metric returned by evaluation
    # (e.g. from the Trainer's compute_metrics), with or without "eval_".
    load_best_model_at_end: bool = True
    metric_for_best_model: str = "accuracy"
    greater_is_better: bool = True

    # Attributes computed in __post_init__ rather than configured by the user.
    # Unannotated, so dataclasses does not treat this as a field.  Files written
    # by earlier versions of save_config contain these keys.
    _DERIVED_ATTRS = ("device", "n_gpu")

    def __post_init__(self):
        """Create output directories and detect the runtime device.

        Sets ``self.device`` to "cuda", "mps" or "cpu" and ``self.n_gpu`` to the
        matching device count.  On CPU, ``fp16`` and ``bf16`` are forced to False.
        ``device``/``n_gpu`` are informational only: ``get_training_args`` does
        not forward them (or ``use_mps_device``) to ``TrainingArguments``.
        """
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.logging_dir, exist_ok=True)

        if torch.cuda.is_available():
            self.device = "cuda"
            self.n_gpu = torch.cuda.device_count()
        elif self.use_mps_device and self._mps_available():
            # NOTE(review): fp16 stays enabled on MPS; whether TrainingArguments
            # accepts fp16 on MPS depends on the transformers version — confirm.
            self.device = "mps"
            self.n_gpu = 1
        else:
            self.device = "cpu"
            self.n_gpu = 0
            # Mixed precision is disabled on CPU.  This overwrites the user's
            # choice, and the overwritten value is what save_config persists.
            self.fp16 = False
            self.bf16 = False

    @staticmethod
    def _mps_available() -> bool:
        """Return True if this torch build exposes an available MPS backend."""
        # torch.backends.mps only exists in newer torch releases (1.12+).
        mps_backend = getattr(torch.backends, "mps", None)
        return mps_backend is not None and mps_backend.is_available()

    def get_training_args(self) -> "TrainingArguments":
        """Build a ``transformers.TrainingArguments`` from this configuration.

        The evaluation-strategy keyword is resolved at runtime.  Newer
        transformers releases call it ``eval_strategy`` and no longer accept
        ``evaluation_strategy``; older releases only know the latter.  The
        value always comes from ``self.evaluation_strategy``.
        """
        accepted = {f.name for f in fields(TrainingArguments)}
        strategy_key = (
            "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
        )
        return TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=self.num_train_epochs,
            per_device_train_batch_size=self.per_device_train_batch_size,
            per_device_eval_batch_size=self.per_device_eval_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            warmup_ratio=self.warmup_ratio,
            logging_dir=self.logging_dir,
            logging_steps=self.logging_steps,
            save_strategy=self.save_strategy,
            eval_steps=self.eval_steps,
            save_steps=self.save_steps,
            load_best_model_at_end=self.load_best_model_at_end,
            metric_for_best_model=self.metric_for_best_model,
            greater_is_better=self.greater_is_better,
            # A precision is enabled only when its flag is set AND it is the
            # selected mixed_precision mode.
            fp16=self.fp16 and self.mixed_precision == "fp16",
            bf16=self.bf16 and self.mixed_precision == "bf16",
            dataloader_num_workers=self.dataloader_num_workers,
            group_by_length=self.group_by_length,
            remove_unused_columns=self.remove_unused_columns,
            label_smoothing_factor=self.label_smoothing_factor,
            gradient_checkpointing=self.gradient_checkpointing,
            optim=self.optim,
            weight_decay=self.weight_decay,
            max_grad_norm=self.max_grad_norm,
            **{strategy_key: self.evaluation_strategy},
        )

    def save_config(self, filepath: str = "training_config.json"):
        """Save the declared configuration fields to a JSON file.

        Only dataclass fields are written.  The derived ``device``/``n_gpu``
        attributes are excluded so the file can be fed back to ``load_config``.
        """
        config_dict = {f.name: getattr(self, f.name) for f in fields(self)}
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=2)

    @classmethod
    def load_config(cls, filepath: str = "training_config.json") -> "AudioTrainingConfig":
        """Load a configuration previously written by ``save_config``.

        Derived attributes (``device``, ``n_gpu``) found in files written by
        earlier versions are ignored; ``__post_init__`` recomputes them for the
        current machine.  Any other unknown key still raises ``TypeError``.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        for key in cls._DERIVED_ATTRS:
            config_dict.pop(key, None)
        return cls(**config_dict)
def main():
    """Build the default config, write it and its TrainingArguments to JSON, and report.

    Writes ``training_config.json`` (via ``AudioTrainingConfig.save_config``)
    and ``training_args.json`` (the serialised ``TrainingArguments``) to the
    current directory, then prints the detected device and precision.
    """
    config_path = "training_config.json"
    # TrainingArguments has no save_to_json() (that is a TrainerState method),
    # so serialise it with to_json_string().  A .json name is used because
    # Trainer checkpoints use "training_args.bin" for a torch.save'd object.
    args_path = "training_args.json"

    config = AudioTrainingConfig()
    config.save_config(config_path)

    training_args = config.get_training_args()
    with open(args_path, "w", encoding="utf-8") as f:
        f.write(training_args.to_json_string())

    # config.mixed_precision alone is misleading on CPU, where __post_init__
    # clears fp16/bf16; report what TrainingArguments will actually use.
    if training_args.fp16:
        effective_precision = "fp16"
    elif training_args.bf16:
        effective_precision = "bf16"
    else:
        effective_precision = "no"

    print(f"Training will use device: {config.device} with {config.n_gpu} GPUs")
    print(f"Mixed precision: {config.mixed_precision} (effective: {effective_precision})")
    print(f"Configuration saved to: {config_path} and {args_path}")


if __name__ == "__main__":
    main()