ALeLacheur committed (verified)
Commit 8fae7b2
Parent(s): 7d38b5c

Update audio_diffusion_attacks_forhf/src/music_gen.py

audio_diffusion_attacks_forhf/src/music_gen.py CHANGED
@@ -12,7 +12,7 @@ class MusicGenEval:
         model_name="facebook/musicgen-stereo-small"
         self.processor = AutoProcessor.from_pretrained(model_name)
         self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
-        self.model=self.model.to(device='cuda')
+        #Andy commented: self.model=self.model.to(device='cuda')
         self.input_sample_rate=input_sample_rate
         self.audio_steps=audio_steps
         self.mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320],
@@ -27,8 +27,10 @@ class MusicGenEval:
         protected_audio=protected_audio[:, :, :self.audio_steps]
         input_len=original_audio.shape[-1]
 
-        unprotected_gen=self.generate_audio(original_audio)[0].to(device='cuda')
-        protected_gen=self.generate_audio(protected_audio)[0].to(device='cuda')
+        #Andy edited: unprotected_gen=self.generate_audio(original_audio)[0].to(device='cuda')
+        unprotected_gen=self.generate_audio(original_audio)[0]
+        #Andy edited: protected_gen=self.generate_audio(protected_audio)[0].to(device='cuda')
+        protected_gen=self.generate_audio(protected_audio)[0]
 
         eval_dict={}
         # Difference between original and unprotected gen
@@ -48,7 +50,8 @@
     def generate_audio(self, audio):
         torch.manual_seed(0)
 
-        transform = torchaudio.transforms.Resample(self.input_sample_rate, 32000).to(device='cuda')
+        #Andy edited: transform = torchaudio.transforms.Resample(self.input_sample_rate, 32000).to(device='cuda')
+        transform = torchaudio.transforms.Resample(self.input_sample_rate, 32000)
         waveform=transform(audio[0]).detach().cpu()
         # waveform.clamp_(0,1)
         a=torch.min(waveform)
@@ -64,10 +67,12 @@
             return_tensors="pt",
         )
         for d in inputs.data:
-            inputs.data[d]=inputs.data[d].to(device='cuda')
+            #Andy edited: inputs.data[d]=inputs.data[d].to(device='cuda')
+            inputs.data[d]=inputs.data[d]
         audio_values = self.model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=1024)
 
-        transform = torchaudio.transforms.Resample(32000, self.input_sample_rate).to(device='cuda')
+        #Andy edited: transform = torchaudio.transforms.Resample(32000, self.input_sample_rate).to(device='cuda')
+        transform = torchaudio.transforms.Resample(32000, self.input_sample_rate)
         audio_values=transform(audio_values)
         return audio_values
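Every hunk makes the same change: hardcoded .to(device='cuda') calls are removed or commented out so the module can load and run on CPU-only hosts. A minimal sketch of a device-agnostic alternative, assuming the same class layout; the pick_device helper name is illustrative and not part of this commit:

import torch

def pick_device() -> torch.device:
    # Illustrative helper (not in the commit): use CUDA when available,
    # otherwise fall back to CPU so GPU-less machines still work.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Usage sketch under the same assumptions:
#   device = pick_device()
#   self.model = self.model.to(device)
#   inputs.data[d] = inputs.data[d].to(device)

With a helper like this, the model, resampling transforms, and generation inputs would all follow a single device variable instead of assuming CUDA, and the no-op reassignment inputs.data[d]=inputs.data[d] left behind by the edit could be dropped entirely.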