"""Train a Wav2Vec2 audio classifier with the Hugging Face ``Trainer``."""

import json
import os

import torch
from transformers import Trainer, TrainingArguments

from src.data.preprocessing.feature_extraction import load_and_process_audio
from src.model.architectures.wav2vec2 import Wav2Vec2ForAudioClassification

# Config locations are resolved relative to the current working directory.
MODEL_CONFIG_PATH = "configs/model/base_config.json"
TRAINING_CONFIG_PATH = "configs/training/base_config.json"
# NOTE(review): the public Hub id is usually "facebook/wav2vec2-base";
# confirm this short name resolves for the project's model class.
PRETRAINED_MODEL = "wav2vec2-base"
# Defaults below may be overridden by the corresponding config files.
DEFAULT_NUM_LABELS = 2
DEFAULT_OUTPUT_DIR = "results/checkpoints"


def load_config(config_path):
    """Load and return the JSON document stored at *config_path*.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _merge_sections(config, *section_names):
    """Merge the named sub-dicts of *config* into one new dict.

    Missing (or null) sections are treated as empty. Raises ValueError when
    the same key appears in more than one section; unpacking both sections
    into a single call would otherwise fail with an opaque
    "got multiple values for keyword argument" TypeError.
    """
    merged = {}
    for name in section_names:
        section = config.get(name) or {}
        duplicates = merged.keys() & section.keys()
        if duplicates:
            raise ValueError(
                f"keys {sorted(duplicates)} are defined in more than one of "
                f"the config sections {list(section_names)}"
            )
        merged.update(section)
    return merged


def main(train_dataset=None, eval_dataset=None):
    """Load configs, build the model and ``Trainer``, and run training.

    Args:
        train_dataset: Dataset passed to ``Trainer`` for training. Required;
            ``Trainer.train()`` cannot run without one.
        eval_dataset: Optional dataset passed to ``Trainer`` for evaluation.

    Raises:
        ValueError: if *train_dataset* is None, or if the training config's
            ``training_parameters`` and ``optimization`` sections share a key.
    """
    # Fail fast: without a training set, Trainer.train() raises anyway, but
    # only after the (expensive) pretrained model has been loaded.
    if train_dataset is None:
        raise ValueError(
            "main() requires a train_dataset; pass one in, e.g. built with "
            "load_and_process_audio"
        )

    model_config = load_config(MODEL_CONFIG_PATH)
    training_config = load_config(TRAINING_CONFIG_PATH)

    # Config values take precedence over the built-in default, instead of
    # colliding with it as a duplicate keyword argument.
    model_kwargs = {"num_labels": DEFAULT_NUM_LABELS, **model_config}
    model = Wav2Vec2ForAudioClassification.from_pretrained(
        PRETRAINED_MODEL, **model_kwargs
    )

    training_kwargs = {
        "output_dir": DEFAULT_OUTPUT_DIR,
        **_merge_sections(training_config, "training_parameters", "optimization"),
    }
    training_args = TrainingArguments(**training_kwargs)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()