warisqr7 committed on
Commit
9a56940
·
verified ·
1 Parent(s): 93fa313

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +23 -6
custom_interface.py CHANGED
@@ -121,6 +121,28 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
121
  text_lab = self.hparams.label_encoder.decode_torch(index)
122
  return out_prob, score, index, text_lab
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def classify_file(self, path):
125
  """Classifies the given audiofile into the given set of labels.
126
 
@@ -141,12 +163,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
141
  List with the text labels corresponding to the indexes.
142
  (label encoder should be provided).
143
  """
144
- waveform = self.load_audio(path)
145
- # Fake a batch:
146
- batch = waveform.unsqueeze(0)
147
- rel_length = torch.tensor([1.0])
148
- outputs = self.encode_batch(batch, rel_length)
149
- outputs = self.mods.output_mlp(outputs).squeeze(1)
150
  out_prob = self.hparams.softmax(outputs)
151
  score, index = torch.max(out_prob, dim=-1)
152
  text_lab = self.hparams.label_encoder.decode_torch(index)
 
121
  text_lab = self.hparams.label_encoder.decode_torch(index)
122
  return out_prob, score, index, text_lab
123
 
124
+ def embed_file(self, path):
125
+ """Returns embedding (last layer output) for the given audiofile.
126
+
127
+ Arguments
128
+ ---------
129
+ path : str
130
+ Path to the audio file to embed.
131
+
132
+ Returns
133
+ -------
134
+ embed
135
+ The output of the final MLP layer, before softmax is applied ([batch, embed_dim])
136
+ """
137
+ waveform = self.load_audio(path)
138
+ # Fake a batch:
139
+ batch = waveform.unsqueeze(0)
140
+ rel_length = torch.tensor([1.0])
141
+ outputs = self.encode_batch(batch, rel_length)
142
+ outputs = self.mods.output_mlp(outputs).squeeze(1)
143
+ return outputs
144
+
145
+
146
  def classify_file(self, path):
147
  """Classifies the given audiofile into the given set of labels.
148
 
 
163
  List with the text labels corresponding to the indexes.
164
  (label encoder should be provided).
165
  """
166
+ outputs = self.embed_file(path)
 
 
 
 
 
167
  out_prob = self.hparams.softmax(outputs)
168
  score, index = torch.max(out_prob, dim=-1)
169
  text_lab = self.hparams.label_encoder.decode_torch(index)