import torch from TTS.api import TTS #Andy edited: import losses import audio_diffusion_attacks_forhf.src.losses from audiotools import AudioSignal import numpy as np import torchaudio import random import string import os class XTTS_Eval: def __init__(self, input_sample_rate, text="The quick brown fox jumps over the lazy dog."): self.model = TTS("tts_models/multilingual/multi-dataset/xtts_v2") self.model=self.model.to(device='cuda') self.text=text self.input_sample_rate=input_sample_rate self.mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mel_fmin=[0, 0, 0, 0, 0, 0, 0], pow=1.0, clamp_eps=1.0e-5, mag_weight=0.0) def eval(self, original_audio, protected_audio): original_audio=original_audio[0] protected_audio=protected_audio[0] unprotected_gen=self.generate_audio(original_audio).to(device='cuda') protected_gen=self.generate_audio(protected_audio).to(device='cuda') match_len=min(original_audio.shape[1], unprotected_gen.shape[1]) if original_audio.shape[1]