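Update eval.py: load the CTC model and processor explicitly, pass the LM decoder to the ASR pipeline only when use_lm is set, and add the model id and LM tag to dataset_id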
Browse files
eval.py
CHANGED
|
@@ -6,8 +6,8 @@ from typing import Dict
|
|
| 6 |
import torch
|
| 7 |
from datasets import Audio, Dataset, load_dataset, load_metric
|
| 8 |
|
| 9 |
-
from transformers import AutoFeatureExtractor, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
|
| 10 |
-
from pyctcdecode import BeamSearchDecoderCTC
|
| 11 |
|
| 12 |
|
| 13 |
def log_results(result: Dataset, args: Dict[str, str]):
|
|
@@ -16,7 +16,7 @@ def log_results(result: Dataset, args: Dict[str, str]):
|
|
| 16 |
log_outputs = args.log_outputs
|
| 17 |
lm = "withLM" if args.use_lm else "noLM"
|
| 18 |
model_id = args.model_id.replace("/", "_").replace(".", "")
|
| 19 |
-
dataset_id = "_".join(args.dataset.split("/") + [
|
| 20 |
|
| 21 |
# load metric
|
| 22 |
wer = load_metric("wer")
|
|
@@ -112,11 +112,27 @@ def main(args):
|
|
| 112 |
args.device = 0 if torch.cuda.is_available() else -1
|
| 113 |
# asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
|
| 120 |
|
| 121 |
# map function to decode audio
|
| 122 |
def map_to_pred(batch):
|
|
|
|
| 6 |
import torch
|
| 7 |
from datasets import Audio, Dataset, load_dataset, load_metric
|
| 8 |
|
| 9 |
+
from transformers import AutoFeatureExtractor, AutoModelForCTC, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
|
| 10 |
+
# from pyctcdecode import BeamSearchDecoderCTC
|
| 11 |
|
| 12 |
|
| 13 |
def log_results(result: Dataset, args: Dict[str, str]):
|
|
|
|
| 16 |
log_outputs = args.log_outputs
|
| 17 |
lm = "withLM" if args.use_lm else "noLM"
|
| 18 |
model_id = args.model_id.replace("/", "_").replace(".", "")
|
| 19 |
+
dataset_id = "_".join([model_id] + args.dataset.split("/") + [args.config, args.split, lm])
|
| 20 |
|
| 21 |
# load metric
|
| 22 |
wer = load_metric("wer")
|
|
|
|
| 112 |
args.device = 0 if torch.cuda.is_available() else -1
|
| 113 |
# asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
|
| 114 |
|
| 115 |
+
model_instance = AutoModelForCTC.from_pretrained(args.model_id)
|
| 116 |
+
if args.use_lm:
|
| 117 |
+
processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
|
| 118 |
+
decoder = processor.decoder
|
| 119 |
+
else:
|
| 120 |
+
processor = Wav2Vec2Processor.from_pretrained(args.model_id)
|
| 121 |
+
decoder = None
|
| 122 |
+
asr = pipeline(
|
| 123 |
+
"automatic-speech-recognition",
|
| 124 |
+
model=model_instance,
|
| 125 |
+
tokenizer=processor.tokenizer,
|
| 126 |
+
feature_extractor=processor.feature_extractor,
|
| 127 |
+
decoder=decoder,
|
| 128 |
+
device=args.device
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# feature_extractor_dict, _ = Wav2Vec2FeatureExtractor.get_feature_extractor_dict(args.model_id)
|
| 132 |
+
# feature_extractor_dict["processor_class"] = "Wav2Vec2Processor" if not args.use_lm else "Wav2Vec2ProcessorWithLM"
|
| 133 |
+
# feature_extractor = Wav2Vec2FeatureExtractor.from_dict(feature_extractor_dict)
|
| 134 |
|
| 135 |
+
# asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
|
| 136 |
|
| 137 |
# map function to decode audio
|
| 138 |
def map_to_pred(batch):
|