|
import os |
|
import torch |
|
from transformers import Trainer, TrainingArguments |
|
from src.model.architectures.wav2vec2 import Wav2Vec2ForAudioClassification |
|
from src.data.preprocessing.feature_extraction import load_and_process_audio |
|
import json |
|
|
|
def load_config(config_path):
    """Load and return the parsed contents of a JSON configuration file.

    Args:
        config_path: Path (``str`` or path-like) to the JSON file. Relative
            paths are resolved against the current working directory.

    Returns:
        The deserialized JSON value. Callers in this module index/unpack the
        result like a dict, so a JSON object is expected at the top level.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8 (the JSON interchange encoding, RFC 8259)
    # instead of relying on the platform's locale-dependent default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
def main(
    *,
    train_dataset=None,
    eval_dataset=None,
    model_config_path="configs/model/base_config.json",
    training_config_path="configs/training/base_config.json",
    pretrained_model_name_or_path="wav2vec2-base",
    num_labels=2,
    output_dir="results/checkpoints",
):
    """Fine-tune a ``Wav2Vec2ForAudioClassification`` model with ``Trainer``.

    Every argument is keyword-only and defaults to the value this script
    previously hard-coded. ``main()`` with no arguments therefore reads the
    same config files and checkpoint as before. It now fails fast because no
    training data is supplied.

    Args:
        train_dataset: Dataset handed to ``Trainer`` as ``train_dataset``.
            Required: ``Trainer.train()`` cannot run without one.
        eval_dataset: Optional dataset handed to ``Trainer`` as
            ``eval_dataset``.
        model_config_path: JSON file whose top-level keys are forwarded as
            keyword arguments to ``from_pretrained``.
        training_config_path: JSON file that must contain the
            ``"training_parameters"`` and ``"optimization"`` objects. Both are
            unpacked into ``TrainingArguments``.
        pretrained_model_name_or_path: Checkpoint identifier or local
            directory passed to ``from_pretrained``.
        num_labels: Passed to ``from_pretrained`` as ``num_labels``.
        output_dir: ``TrainingArguments.output_dir`` (checkpoint location).

    Returns:
        Whatever ``trainer.train()`` returns (a ``TrainOutput`` in Hugging
        Face Transformers).

    Raises:
        ValueError: If ``train_dataset`` is None.
        TypeError: If the two training-config sections share a key, or either
            defines ``output_dir``, because they are unpacked into one call.
    """
    # Fail fast: Trainer.train() raises ValueError when train_dataset is None,
    # but only after the configs and pretrained weights have been loaded.
    # TODO: build the datasets (load_and_process_audio is imported at the top
    # of this file but not used anywhere yet) and pass them in.
    if train_dataset is None:
        raise ValueError(
            "main() requires a train_dataset: the audio data pipeline is not "
            "wired up yet, so Trainer.train() would fail without one."
        )

    # Relative config paths resolve against the current working directory.
    model_config = load_config(model_config_path)
    training_config = load_config(training_config_path)

    # NOTE(review): the Hugging Face Hub id for this checkpoint is
    # "facebook/wav2vec2-base"; a bare "wav2vec2-base" likely resolves only if
    # a local directory of that name exists. Confirm which is intended.
    model = Wav2Vec2ForAudioClassification.from_pretrained(
        pretrained_model_name_or_path,
        num_labels=num_labels,
        **model_config,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        **training_config["training_parameters"],
        **training_config["optimization"],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    return trainer.train()
|
|
|
if __name__ == "__main__":
    # Run training only when this file is executed as a script;
    # importing the module defines the functions without side effects.
    main()