Update README.md
Browse files
README.md
CHANGED
@@ -31,20 +31,21 @@ pip install -r requirements.txt
|
|
31 |
>>> from ced_model.feature_extraction_ced import CedFeatureExtractor
|
32 |
>>> from ced_model.modeling_ced import CedForAudioClassification
|
33 |
|
34 |
-
>>>
|
35 |
-
>>> feature_extractor = CedFeatureExtractor.from_pretrained(
|
36 |
-
>>> model = CedForAudioClassification.from_pretrained(
|
37 |
|
38 |
>>> import torchaudio
|
39 |
>>> audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
|
40 |
-
|
41 |
>>> inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
|
|
|
|
|
42 |
>>> with torch.no_grad():
|
43 |
...     logits = model(**inputs).logits
|
44 |
|
45 |
-
>>>
|
46 |
-
>>>
|
47 |
-
>>> model.config.id2label[predicted_class_ids]
|
48 |
'Finger snapping'
|
49 |
```
|
50 |
|
|
|
31 |
>>> from ced_model.feature_extraction_ced import CedFeatureExtractor
|
32 |
>>> from ced_model.modeling_ced import CedForAudioClassification
|
33 |
|
34 |
+
>>> model_name = "mispeech/ced-base"
|
35 |
+
>>> feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
|
36 |
+
>>> model = CedForAudioClassification.from_pretrained(model_name)
|
37 |
|
38 |
>>> import torchaudio
|
39 |
>>> audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
|
40 |
+
>>> assert sampling_rate == 16000
|
41 |
>>> inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
|
42 |
+
|
43 |
+
>>> import torch
|
44 |
>>> with torch.no_grad():
|
45 |
...     logits = model(**inputs).logits
|
46 |
|
47 |
+
>>> predicted_class_id = torch.argmax(logits, dim=-1).item()
|
48 |
+
>>> model.config.id2label[predicted_class_id]
|
|
|
49 |
'Finger snapping'
|
50 |
```
|
51 |
|