from dataclasses import dataclass, field, asdict
import os
import json

import torch
from transformers import TrainingArguments


@dataclass
class AudioTrainingConfig:
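    """Training configuration for fine-tuning an audio model with the Hugging Face Trainer."""

    # Model architecture (informational; not forwarded to TrainingArguments)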
    model_name: str = "wav2vec2"
    hidden_size: int = 1024
    num_attention_heads: int = 16
    num_hidden_layers: int = 24

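    # Training schedule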
    output_dir: str = field(default="./results")
    num_train_epochs: int = 5
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    gradient_accumulation_steps: int = 4
    learning_rate: float = 3e-5
    warmup_ratio: float = 0.1

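    # Optimization and numerical precision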
    fp16: bool = True
    bf16: bool = False
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0

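    # Logging, evaluation, and checkpointing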
    logging_dir: str = field(default="./logs")
    logging_steps: int = 100
    # eval_steps/save_steps only take effect when the corresponding strategy is
    # "steps"; with the "epoch" strategies below the Trainer ignores them.
    eval_steps: int = 500
    save_steps: int = 500
    save_strategy: str = "epoch"
    evaluation_strategy: str = "epoch"

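    # Data loading and regularization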
    dataloader_num_workers: int = 4
    group_by_length: bool = True
    remove_unused_columns: bool = True
    label_smoothing_factor: float = 0.1

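    # Device selection and mixed-precision mode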
    use_mps_device: bool = field(
        default=False,
        metadata={"help": "Whether to use Apple M1/M2 GPU acceleration"},
    )
    mixed_precision: str = field(
        default="fp16",
        metadata={"help": "Mixed precision mode: 'no', 'fp16', 'bf16'"},
    )

    def __post_init__(self):
        # Make sure the output and logging directories exist up front.
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.logging_dir, exist_ok=True)

        # Pick the compute device and disable mixed precision when falling back to CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
            self.n_gpu = torch.cuda.device_count()
        elif torch.backends.mps.is_available() and self.use_mps_device:
            self.device = "mps"
            self.n_gpu = 1
        else:
            self.device = "cpu"
            self.n_gpu = 0
            self.fp16 = False
            self.bf16 = False

    def get_training_args(self) -> TrainingArguments:
        """Build a TrainingArguments instance from this configuration."""
        return TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=self.num_train_epochs,
            per_device_train_batch_size=self.per_device_train_batch_size,
            per_device_eval_batch_size=self.per_device_eval_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            warmup_ratio=self.warmup_ratio,
            logging_dir=self.logging_dir,
            logging_steps=self.logging_steps,
            # load_best_model_at_end requires save_strategy == evaluation_strategy.
            save_strategy=self.save_strategy,
            evaluation_strategy=self.evaluation_strategy,
            eval_steps=self.eval_steps,
            save_steps=self.save_steps,
            load_best_model_at_end=True,
            # "accuracy" must be reported by the Trainer's compute_metrics callback.
            metric_for_best_model="accuracy",
            greater_is_better=True,
            # Enable only the precision mode that matches the selected flag.
            fp16=self.fp16 and self.mixed_precision == "fp16",
            bf16=self.bf16 and self.mixed_precision == "bf16",
            dataloader_num_workers=self.dataloader_num_workers,
            group_by_length=self.group_by_length,
            remove_unused_columns=self.remove_unused_columns,
            label_smoothing_factor=self.label_smoothing_factor,
            gradient_checkpointing=self.gradient_checkpointing,
            optim=self.optim,
            weight_decay=self.weight_decay,
            max_grad_norm=self.max_grad_norm,
        )

    def save_config(self, filepath: str = "training_config.json"):
        """Save the declared configuration fields to a JSON file."""
        # asdict() serializes only the dataclass fields, so runtime attributes
        # set in __post_init__ (device, n_gpu) do not break load_config().
        config_dict = asdict(self)
        with open(filepath, "w") as f:
            json.dump(config_dict, f, indent=2)

    @classmethod
    def load_config(cls, filepath: str = "training_config.json") -> "AudioTrainingConfig":
        """Load a configuration from a JSON file."""
        with open(filepath, "r") as f:
            config_dict = json.load(f)
        return cls(**config_dict)


def main():
    config = AudioTrainingConfig()

    # Persist both the high-level config and the derived TrainingArguments.
    config.save_config("training_config.json")
    training_args = config.get_training_args()
    # TrainingArguments has no save_to_json(); serialize it via to_json_string().
    with open("training_args.json", "w") as f:
        f.write(training_args.to_json_string())

    print(f"Training will use device: {config.device} with {config.n_gpu} GPUs")
    print(f"Mixed precision: {config.mixed_precision}")
    print("Configuration saved to: training_config.json and training_args.json")


if __name__ == "__main__":
    main()
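
# Usage sketch (an assumption, not part of the original module): the TrainingArguments
# produced by get_training_args() are typically passed to a transformers.Trainer
# together with a model, datasets, and a compute_metrics callback that reports
# "accuracy". The checkpoint name and dataset objects below are hypothetical.
#
#   from transformers import Trainer, Wav2Vec2ForSequenceClassification
#
#   config = AudioTrainingConfig()
#   model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base")
#   trainer = Trainer(
#       model=model,
#       args=config.get_training_args(),
#       train_dataset=train_dataset,      # hypothetical preprocessed datasets
#       eval_dataset=eval_dataset,
#       compute_metrics=compute_metrics,  # must return {"accuracy": ...}
#   )
#   trainer.train()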