"""Fine-tune Wav2Vec2 for binary audio classification with the Hugging Face Trainer."""
import json

import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments

from src.model.architectures.wav2vec2 import Wav2Vec2ForAudioClassification
from src.data.preprocessing.feature_extraction import load_and_process_audio

def load_config(config_path):
    with open(config_path, 'r') as f:
        return json.load(f)

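# --- Hypothetical dataset sketch (not part of the original script) -----------
# Minimal example of a dataset the Trainer below could consume. It assumes that
# load_and_process_audio(path) returns a 1-D array/tensor of model-ready input
# values for a single clip; adjust the field names and the call signature to
# match the project's actual feature_extraction helper before using this.
# NOTE: variable-length clips will also need a padding data_collator in Trainer.
class AudioClassificationDataset(Dataset):
    def __init__(self, file_paths, labels):
        assert len(file_paths) == len(labels), "expected one label per audio file"
        self.file_paths = file_paths
        self.labels = labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Assumed helper signature: load_and_process_audio(path) -> features.
        input_values = load_and_process_audio(self.file_paths[idx])
        return {
            "input_values": torch.as_tensor(input_values, dtype=torch.float32),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }
# ------------------------------------------------------------------------------
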
def main():
    # Load configurations
    model_config = load_config('configs/model/base_config.json')
    training_config = load_config('configs/training/base_config.json')
    
    # Initialize model
    # NOTE: 'wav2vec2-base' is assumed to resolve to a local checkpoint/config
    # directory; the public Hub identifier is 'facebook/wav2vec2-base'.
    model = Wav2Vec2ForAudioClassification.from_pretrained(
        'wav2vec2-base',
        num_labels=2,
        **model_config
    )

    # Training arguments
    # Assumes 'training_parameters' and 'optimization' define disjoint keys;
    # a duplicate key would raise a TypeError when both dicts are unpacked.
    training_args = TrainingArguments(
        output_dir="results/checkpoints",
        **training_config['training_parameters'],
        **training_config['optimization']
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=None,  # TODO: plug in a training set, e.g. the AudioClassificationDataset sketch above
        eval_dataset=None,   # TODO: plug in an evaluation set
    )

    # Train
    trainer.train()
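
    # Optionally persist the final weights alongside the checkpoints, e.g.:
    # trainer.save_model("results/checkpoints/final")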

if __name__ == "__main__":
    main()