ALeLacheur commited on
Commit
678fd0b
·
1 Parent(s): b6809ba

Changed to 16hz and added speaker embedding

Browse files
Files changed (1) hide show
  1. app.py +27 -2
app.py CHANGED
@@ -4,6 +4,7 @@ import voicebox.src.attacks.offline.perturbation.voicebox.voicebox as vb #To acc
4
  #import voicebox.src.attacks.online.voicebox_streamer as streamer #To access VoiceBoxStreamer class
5
  import numpy as np
6
  from voicebox.src.constants import PPG_PRETRAINED_PATH
 
7
 
8
  #Set voicebox default parameters
9
  LOOKAHEAD = 5
@@ -28,7 +29,19 @@ voicebox_kwargs={'win_length': 256,
28
  'projection_norm': float('inf'),
29
  'conditioning_dim': 512}
30
 
31
- #Load pretrained model:
 
 
 
 
 
 
 
 
 
 
 
 
32
  model = vb.VoiceBox(**voicebox_kwargs)
33
  model.load_state_dict(torch.load('voicebox/pretrained/voicebox/voicebox_final.pt', map_location=torch.device('cpu')), strict=True)
34
  model.eval()
@@ -41,6 +54,12 @@ def float32_to_int16(waveform):
41
  waveform = waveform.ravel()
42
  return waveform
43
 
 
 
 
 
 
 
44
  #Define predict function:
45
  def predict(inp):
46
  #How to transform audio from string to tensor
@@ -51,9 +70,15 @@ def predict(inp):
51
  waveform = transform_to_16hz(waveform)
52
  sample_rate = 16000
53
 
 
 
 
 
 
 
54
  #Run model without changing weights
55
  with torch.no_grad():
56
- waveform = model(waveform)
57
 
58
  #Transform output audio into gradio-readable format
59
  waveform = waveform.numpy()
 
4
  #import voicebox.src.attacks.online.voicebox_streamer as streamer #To access VoiceBoxStreamer class
5
  import numpy as np
6
  from voicebox.src.constants import PPG_PRETRAINED_PATH
7
+ from voicebox.src.models import ResNetSE34V2
8
 
9
  #Set voicebox default parameters
10
  LOOKAHEAD = 5
 
29
  'projection_norm': float('inf'),
30
  'conditioning_dim': 512}
31
 
32
+ '''
33
+ #Set streamer default parameters:
34
+ config_path = 'voicebox/pretrained/voicebox/voicebox_final.yaml'
35
+ with open(config_path) as f:
36
+ config = yaml.safe_load(f)
37
+
38
+ #Load pretrained model (streamer):
39
+ model = streamer.VoiceBoxStreamer(**config)
40
+ model.load_state_dict(torch.load('voicebox/pretrained/voicebox/voicebox_final.pt', map_location=torch.device('cpu')), strict=True)
41
+ model.eval()
42
+ '''
43
+
44
+ #Load pretrained model (VoiceBox):
45
  model = vb.VoiceBox(**voicebox_kwargs)
46
  model.load_state_dict(torch.load('voicebox/pretrained/voicebox/voicebox_final.pt', map_location=torch.device('cpu')), strict=True)
47
  model.eval()
 
54
  waveform = waveform.ravel()
55
  return waveform
56
 
57
+ def get_embedding(recording):
58
+ resnet = ResNetSE34V2(nOut=512, encoder_type='ASP')
59
+ recording = recording.view(1, -1)
60
+ embedding = resnet(recording)
61
+ return embedding
62
+
63
  #Define predict function:
64
  def predict(inp):
65
  #How to transform audio from string to tensor
 
70
  waveform = transform_to_16hz(waveform)
71
  sample_rate = 16000
72
 
73
+ #Get speaker embedding
74
+ condition_tensor = get_embedding(waveform)
75
+ condition_tensor = condition_tensor.reshape(1, 1, -1)
76
+ n_frames = waveform.shape[1]
77
+ condition_tensor = condition_tensor.repeat(1, n_frames, 1)
78
+
79
  #Run model without changing weights
80
  with torch.no_grad():
81
+ waveform = model(x=waveform, y=condition_tensor)
82
 
83
  #Transform output audio into gradio-readable format
84
  waveform = waveform.numpy()