Update audio_diffusion_attacks_forhf/src/music_gen.py
audio_diffusion_attacks_forhf/src/music_gen.py
CHANGED
@@ -12,7 +12,7 @@ class MusicGenEval:
         model_name="facebook/musicgen-stereo-small"
         self.processor = AutoProcessor.from_pretrained(model_name)
         self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
-        self.model=self.model.to(device='cuda')
+        #Andy commented: self.model=self.model.to(device='cuda')
         self.input_sample_rate=input_sample_rate
         self.audio_steps=audio_steps
         self.mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320],
@@ -27,8 +27,10 @@ class MusicGenEval:
         protected_audio=protected_audio[:, :, :self.audio_steps]
         input_len=original_audio.shape[-1]
 
-        unprotected_gen=self.generate_audio(original_audio)[0].to(device='cuda')
-        protected_gen=self.generate_audio(protected_audio)[0].to(device='cuda')
+        #Andy edited: unprotected_gen=self.generate_audio(original_audio)[0].to(device='cuda')
+        unprotected_gen=self.generate_audio(original_audio)[0]
+        #Andy edited: protected_gen=self.generate_audio(protected_audio)[0].to(device='cuda')
+        protected_gen=self.generate_audio(protected_audio)[0]
 
         eval_dict={}
         # Difference between original and unprotected gen
@@ -48,7 +50,8 @@ class MusicGenEval:
     def generate_audio(self, audio):
         torch.manual_seed(0)
 
-        transform = torchaudio.transforms.Resample(self.input_sample_rate, 32000).to(device='cuda')
+        #Andy edited: transform = torchaudio.transforms.Resample(self.input_sample_rate, 32000).to(device='cuda')
+        transform = torchaudio.transforms.Resample(self.input_sample_rate, 32000)
         waveform=transform(audio[0]).detach().cpu()
         # waveform.clamp_(0,1)
         a=torch.min(waveform)
@@ -64,10 +67,12 @@ class MusicGenEval:
             return_tensors="pt",
         )
         for d in inputs.data:
-            inputs.data[d]=inputs.data[d].to(device='cuda')
+            #Andy edited: inputs.data[d]=inputs.data[d].to(device='cuda')
+            inputs.data[d]=inputs.data[d]
         audio_values = self.model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=1024)
 
-        transform = torchaudio.transforms.Resample(32000, self.input_sample_rate).to(device='cuda')
+        #Andy edited: transform = torchaudio.transforms.Resample(32000, self.input_sample_rate).to(device='cuda')
+        transform = torchaudio.transforms.Resample(32000, self.input_sample_rate)
         audio_values=transform(audio_values)
         return audio_values
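The commit comments out every hard-coded .to(device='cuda') call so the code can run on CPU-only hardware, preserving the original lines as "#Andy edited:" comments; this also leaves the now no-op assignment inputs.data[d]=inputs.data[d] behind. A more maintainable alternative is to select the device once at startup and reuse it, so the same code runs unchanged on both GPU and CPU machines. The sketch below illustrates that pattern; the text prompt, the input_sample_rate value, and the device variable are illustrative assumptions, not code from this repository.

import torch
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Pick the device once; torch.cuda.is_available() falls back to CPU
# automatically, so nothing has to be commented in or out per machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "facebook/musicgen-stereo-small"
processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device)

# Resample is an nn.Module, so it moves between devices just like the model.
input_sample_rate = 44100  # assumption for this sketch
to_model_rate = torchaudio.transforms.Resample(input_sample_rate, 32000).to(device)

# Illustrative text prompt standing in for the repository's actual processor
# arguments (the diff only shows return_tensors="pt").
inputs = processor(text=["80s pop with driving synths"], padding=True, return_tensors="pt")
# .to(device) is a no-op for tensors already on the CPU, so this loop needs
# no per-machine edits.
inputs = {k: v.to(device) for k, v in inputs.items()}

audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=1024)

# Bring the generated 32 kHz audio back to the caller's sample rate.
from_model_rate = torchaudio.transforms.Resample(32000, input_sample_rate).to(device)
audio_values = from_model_rate(audio_values)

On a CUDA machine this reproduces the original behavior; on a CPU-only Space it simply runs slower instead of raising a device error.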