File size: 4,708 Bytes
45e1a77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from dataclasses import dataclass, field
from typing import Optional, List
import os
import json
from transformers import TrainingArguments, Trainer
import torch

@dataclass
class AudioTrainingConfig:
    """Settings for an audio-model training run.

    The dataclass fields hold the user-facing configuration. Two extra
    instance attributes, ``device`` and ``n_gpu``, are derived from the
    available hardware in ``__post_init__``; they are not dataclass fields
    and therefore cannot be passed to the constructor.

    Constructing an instance has side effects: ``output_dir`` and
    ``logging_dir`` are created on disk, and ``fp16``/``bf16`` are forced to
    ``False`` when neither CUDA nor (opted-in) MPS is available.
    """

    # Model configuration
    model_name: str = "wav2vec2"
    hidden_size: int = 1024
    num_attention_heads: int = 16
    num_hidden_layers: int = 24

    # Training parameters
    output_dir: str = field(default="./results")
    num_train_epochs: int = 5
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    gradient_accumulation_steps: int = 4
    learning_rate: float = 3e-5
    warmup_ratio: float = 0.1

    # Optimization
    fp16: bool = True
    bf16: bool = False
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0

    # Logging & Evaluation
    logging_dir: str = field(default="./logs")
    logging_steps: int = 100
    eval_steps: int = 500
    save_steps: int = 500
    save_strategy: str = "epoch"
    evaluation_strategy: str = "epoch"

    # Performance
    dataloader_num_workers: int = 4
    group_by_length: bool = True
    remove_unused_columns: bool = True
    label_smoothing_factor: float = 0.1

    # Advanced features
    use_mps_device: bool = field(
        default=False,
        metadata={"help": "Whether to use Apple M1/M2 GPU acceleration"}
    )
    mixed_precision: str = field(
        default="fp16",
        metadata={"help": "Mixed precision mode: 'no', 'fp16', 'bf16'"}
    )

    def __post_init__(self):
        """Create the output directories and derive ``device`` / ``n_gpu``."""
        # Create output directories if they don't exist
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.logging_dir, exist_ok=True)

        # Some torch builds have no ``torch.backends.mps``; look it up
        # defensively and only query it when the user opted in to MPS.
        mps_backend = getattr(torch.backends, "mps", None)

        # Adjust settings based on hardware
        if torch.cuda.is_available():
            self.device = "cuda"
            self.n_gpu = torch.cuda.device_count()
        elif (
            self.use_mps_device
            and mps_backend is not None
            and mps_backend.is_available()
        ):
            # NOTE(review): fp16 is left enabled on this path; confirm the
            # installed transformers version accepts fp16 on MPS.
            self.device = "mps"
            self.n_gpu = 1
        else:
            # CPU fallback: half-precision training is switched off.
            self.device = "cpu"
            self.n_gpu = 0
            self.fp16 = False
            self.bf16 = False

    def get_training_args(self) -> "TrainingArguments":
        """Translate this configuration into a ``TrainingArguments`` object.

        ``fp16`` (resp. ``bf16``) is enabled only when the matching boolean
        flag is set *and* ``mixed_precision`` names the same mode.
        ``load_best_model_at_end``, ``metric_for_best_model="accuracy"`` and
        ``greater_is_better`` are fixed here and not configurable.

        Returns:
            A ``transformers.TrainingArguments`` instance.
        """
        # Newer transformers releases name this option ``eval_strategy``;
        # older ones only know ``evaluation_strategy``. Pick whichever the
        # installed TrainingArguments declares, defaulting to the old name.
        declared = getattr(TrainingArguments, "__dataclass_fields__", {})
        if "eval_strategy" in declared:
            eval_strategy_key = "eval_strategy"
        else:
            eval_strategy_key = "evaluation_strategy"

        return TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=self.num_train_epochs,
            per_device_train_batch_size=self.per_device_train_batch_size,
            per_device_eval_batch_size=self.per_device_eval_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            warmup_ratio=self.warmup_ratio,
            logging_dir=self.logging_dir,
            logging_steps=self.logging_steps,
            save_strategy=self.save_strategy,
            eval_steps=self.eval_steps,
            save_steps=self.save_steps,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            fp16=self.fp16 and self.mixed_precision == "fp16",
            bf16=self.bf16 and self.mixed_precision == "bf16",
            dataloader_num_workers=self.dataloader_num_workers,
            group_by_length=self.group_by_length,
            remove_unused_columns=self.remove_unused_columns,
            label_smoothing_factor=self.label_smoothing_factor,
            gradient_checkpointing=self.gradient_checkpointing,
            optim=self.optim,
            weight_decay=self.weight_decay,
            max_grad_norm=self.max_grad_norm,
            **{eval_strategy_key: self.evaluation_strategy},
        )

    def save_config(self, filepath: str = "training_config.json"):
        """Save configuration to JSON file.

        Every public instance attribute is written, including the derived
        ``device`` and ``n_gpu`` values (useful as a record of the hardware
        the configuration was created on).

        Args:
            filepath: Destination path of the JSON file.
        """
        config_dict = {
            key: value
            for key, value in vars(self).items()
            if not key.startswith('_')
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=2)

    @classmethod
    def load_config(cls, filepath: str = "training_config.json") -> 'AudioTrainingConfig':
        """Load configuration from JSON file.

        Args:
            filepath: Path of a JSON file, typically written by
                ``save_config``.

        Returns:
            A new ``AudioTrainingConfig`` built from the stored values.

        Raises:
            TypeError: If the file contains a key that is not a constructor
                parameter (other than the derived ``device`` / ``n_gpu``).
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        # ``save_config`` also stores attributes computed in __post_init__.
        # They are not constructor parameters, so passing them on would raise
        # TypeError; drop them and let __post_init__ recompute them for the
        # current machine.
        for derived_key in ("device", "n_gpu"):
            config_dict.pop(derived_key, None)
        return cls(**config_dict)

def main():
    """Build the default configuration and persist it to the working directory.

    Writes ``training_config.json`` (the dataclass settings as JSON) and
    ``training_args.bin`` (the serialized ``TrainingArguments`` object), then
    prints a short summary of the selected device and precision mode.
    """
    # Initialize configuration
    config = AudioTrainingConfig()

    # Save both formats
    config.save_config("training_config.json")
    training_args = config.get_training_args()
    # TrainingArguments offers no ``save_to_json`` method, so the previous
    # call would fail; serialize the object with torch.save instead, which
    # matches the ``.bin`` file name used here.
    torch.save(training_args, "training_args.bin")

    print(f"Training will use device: {config.device} with {config.n_gpu} GPUs")
    print(f"Mixed precision: {config.mixed_precision}")
    print("Configuration saved to: training_config.json and training_args.bin")

# Run the configuration export only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()