asahi417 commited on
Commit
f9f9f36
·
1 Parent(s): 6252154
pipeline/kotoba_whisper.py CHANGED
@@ -20,7 +20,7 @@ class Punctuator:
20
 
21
  ja_punctuations = ["!", "?", "、", "。"]
22
 
23
- def __init__(self, model: str = "pcs_47lang"):
24
  self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)
25
 
26
  def punctuate(self, text: str) -> str:
@@ -123,7 +123,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
123
  }
124
  postprocess_params = {"add_punctuation": add_punctuation, "return_timestamps": True, "return_language": False}
125
  forward_params = {} if generate_kwargs is None else generate_kwargs
126
- forward_params.update({"max_new_tokens": max_new_tokens, "return_timestamps": True})
127
  return preprocess_params, forward_params, postprocess_params
128
 
129
  def preprocess(self,
 
20
 
21
  ja_punctuations = ["!", "?", "、", "。"]
22
 
23
+ def __init__(self, model: str = "1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase"):
24
  self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)
25
 
26
  def punctuate(self, text: str) -> str:
 
123
  }
124
  postprocess_params = {"add_punctuation": add_punctuation, "return_timestamps": True, "return_language": False}
125
  forward_params = {} if generate_kwargs is None else generate_kwargs
126
+ forward_params.update({"max_new_tokens": max_new_tokens, "return_timestamps": True, "language": "ja", "task": "transcribe"})
127
  return preprocess_params, forward_params, postprocess_params
128
 
129
  def preprocess(self,
pipeline/push_pipeline.py CHANGED
@@ -14,8 +14,6 @@ PIPELINE_REGISTRY.register_pipeline(
14
  tf_model=TFWhisperForConditionalGeneration
15
  )
16
  pipe = pipeline(task="kotoba-whisper", model="kotoba-tech/kotoba-whisper-v2.0", chunk_length_s=15, batch_size=16)
17
- output = pipe(test_audio, add_punctuation=True)
18
- pprint(output)
19
  pipe.push_to_hub(model_alias)
20
 
21
 
 
14
  tf_model=TFWhisperForConditionalGeneration
15
  )
16
  pipe = pipeline(task="kotoba-whisper", model="kotoba-tech/kotoba-whisper-v2.0", chunk_length_s=15, batch_size=16)
 
 
17
  pipe.push_to_hub(model_alias)
18
 
19
 
pipeline/test_pipeline.py CHANGED
@@ -6,6 +6,5 @@ pipe = pipeline(model="kotoba-tech/kotoba-whisper-v2.2", chunk_length_s=None, ba
6
  output = pipe("sample_diarization_japanese.mp3")
7
  pprint(output)
8
 
9
- pipe = pipeline(model="kotoba-tech/kotoba-whisper-v2.2", chunk_length_s=None, batch_size=16, trust_remote_code=True, return_unique_speaker=False)
10
- output = pipe("sample_diarization_japanese.mp3")
11
  pprint(output)
 
6
  output = pipe("sample_diarization_japanese.mp3")
7
  pprint(output)
8
 
9
+ output = pipe("sample_diarization_japanese.mp3", add_punctu)
 
10
  pprint(output)