init
Browse files- pipeline/kotoba_whisper.py +2 -2
- pipeline/push_pipeline.py +0 -2
- pipeline/test_pipeline.py +1 -2
pipeline/kotoba_whisper.py
CHANGED
@@ -20,7 +20,7 @@ class Punctuator:
|
|
20 |
|
21 |
ja_punctuations = ["!", "?", "、", "。"]
|
22 |
|
23 |
-
def __init__(self, model: str = "
|
24 |
self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)
|
25 |
|
26 |
def punctuate(self, text: str) -> str:
|
@@ -123,7 +123,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
|
|
123 |
}
|
124 |
postprocess_params = {"add_punctuation": add_punctuation, "return_timestamps": True, "return_language": False}
|
125 |
forward_params = {} if generate_kwargs is None else generate_kwargs
|
126 |
-
forward_params.update({"max_new_tokens": max_new_tokens, "return_timestamps": True})
|
127 |
return preprocess_params, forward_params, postprocess_params
|
128 |
|
129 |
def preprocess(self,
|
|
|
20 |
|
21 |
ja_punctuations = ["!", "?", "、", "。"]
|
22 |
|
23 |
+
def __init__(self, model: str = "1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase"):
|
24 |
self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)
|
25 |
|
26 |
def punctuate(self, text: str) -> str:
|
|
|
123 |
}
|
124 |
postprocess_params = {"add_punctuation": add_punctuation, "return_timestamps": True, "return_language": False}
|
125 |
forward_params = {} if generate_kwargs is None else generate_kwargs
|
126 |
+
forward_params.update({"max_new_tokens": max_new_tokens, "return_timestamps": True, "language": "ja", "task": "transcribe"})
|
127 |
return preprocess_params, forward_params, postprocess_params
|
128 |
|
129 |
def preprocess(self,
|
pipeline/push_pipeline.py
CHANGED
@@ -14,8 +14,6 @@ PIPELINE_REGISTRY.register_pipeline(
|
|
14 |
tf_model=TFWhisperForConditionalGeneration
|
15 |
)
|
16 |
pipe = pipeline(task="kotoba-whisper", model="kotoba-tech/kotoba-whisper-v2.0", chunk_length_s=15, batch_size=16)
|
17 |
-
output = pipe(test_audio, add_punctuation=True)
|
18 |
-
pprint(output)
|
19 |
pipe.push_to_hub(model_alias)
|
20 |
|
21 |
|
|
|
14 |
tf_model=TFWhisperForConditionalGeneration
|
15 |
)
|
16 |
pipe = pipeline(task="kotoba-whisper", model="kotoba-tech/kotoba-whisper-v2.0", chunk_length_s=15, batch_size=16)
|
|
|
|
|
17 |
pipe.push_to_hub(model_alias)
|
18 |
|
19 |
|
pipeline/test_pipeline.py
CHANGED
@@ -6,6 +6,5 @@ pipe = pipeline(model="kotoba-tech/kotoba-whisper-v2.2", chunk_length_s=None, ba
|
|
6 |
output = pipe("sample_diarization_japanese.mp3")
|
7 |
pprint(output)
|
8 |
|
9 |
-
|
10 |
-
output = pipe("sample_diarization_japanese.mp3")
|
11 |
pprint(output)
|
|
|
6 |
output = pipe("sample_diarization_japanese.mp3")
|
7 |
pprint(output)
|
8 |
|
9 |
+
output = pipe("sample_diarization_japanese.mp3", add_punctuation=True)
|
|
|
10 |
pprint(output)
|