Update custom_interface.py
Browse files- custom_interface.py +23 -6
custom_interface.py
CHANGED
@@ -121,6 +121,28 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
121 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
122 |
return out_prob, score, index, text_lab
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
def classify_file(self, path):
|
125 |
"""Classifies the given audiofile into the given set of labels.
|
126 |
|
@@ -141,12 +163,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
141 |
List with the text labels corresponding to the indexes.
|
142 |
(label encoder should be provided).
|
143 |
"""
|
144 |
-
|
145 |
-
# Fake a batch:
|
146 |
-
batch = waveform.unsqueeze(0)
|
147 |
-
rel_length = torch.tensor([1.0])
|
148 |
-
outputs = self.encode_batch(batch, rel_length)
|
149 |
-
outputs = self.mods.output_mlp(outputs).squeeze(1)
|
150 |
out_prob = self.hparams.softmax(outputs)
|
151 |
score, index = torch.max(out_prob, dim=-1)
|
152 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
|
|
121 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
122 |
return out_prob, score, index, text_lab
|
123 |
|
124 |
+
def embed_file(self, path):
|
125 |
+
"""Returns embedding (last layer output) for the given audiofile.
|
126 |
+
|
127 |
+
Arguments
|
128 |
+
---------
|
129 |
+
path : str
|
130 |
+
Path to audio file to classify.
|
131 |
+
|
132 |
+
Returns
|
133 |
+
-------
|
134 |
+
embed
|
135 |
+
The log posterior probabilities of each class ([batch, embed_dim])
|
136 |
+
"""
|
137 |
+
waveform = self.load_audio(path)
|
138 |
+
# Fake a batch:
|
139 |
+
batch = waveform.unsqueeze(0)
|
140 |
+
rel_length = torch.tensor([1.0])
|
141 |
+
outputs = self.encode_batch(batch, rel_length)
|
142 |
+
outputs = self.mods.output_mlp(outputs).squeeze(1)
|
143 |
+
return outputs
|
144 |
+
|
145 |
+
|
146 |
def classify_file(self, path):
|
147 |
"""Classifies the given audiofile into the given set of labels.
|
148 |
|
|
|
163 |
List with the text labels corresponding to the indexes.
|
164 |
(label encoder should be provided).
|
165 |
"""
|
166 |
+
outputs = self.embed_file(path)
|
|
|
|
|
|
|
|
|
|
|
167 |
out_prob = self.hparams.softmax(outputs)
|
168 |
score, index = torch.max(out_prob, dim=-1)
|
169 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|