leofltt committed
Commit 9e55989 · verified · 1 Parent(s): 96018e7

Update app.py


updated asr model

Files changed (1)
  1. app.py +14 -9
app.py CHANGED
@@ -6,25 +6,30 @@ from datasets import load_dataset
 from transformers import pipeline
 from transformers import BarkModel, BarkProcessor
 
+from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
 
+SAMPLE_RATE = 16000
 
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-# load speech translation checkpoint
-asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
+asr_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
+asr_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
 
-barkmodel = BarkModel.from_pretrained("suno/bark")
-barkprocessor = BarkProcessor.from_pretrained("suno/bark")
+bark_model = BarkModel.from_pretrained("suno/bark")
+bark_processor = BarkProcessor.from_pretrained("suno/bark")
 
 
 def translate(audio):
-    outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "transcribe", "language": "it"})
-    return outputs["text"]
+    inputs = asr_processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
+    generated_ids = asr_model.generate(inputs["input_features"], attention_mask=inputs["attention_mask"],
+                                       forced_bos_token_id=asr_processor.tokenizer.lang_code_to_id["it"])
+    translation = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)
+    return translation
 
 
 def synthesise(text):
-    inputs = barkprocessor(text=[text], voice_preset="v2/it_speaker_4",return_tensors="pt")
-    speech = barkmodel.generate(**inputs, do_sample=True)
+    inputs = bark_processor(text=text, voice_preset="v2/it_speaker_4", return_tensors="pt")
+    speech = bark_model.generate(**inputs, do_sample=True)
     return speech
 
 
@@ -32,7 +37,7 @@ def speech_to_speech_translation(audio):
     translated_text = translate(audio)
     synthesised_speech = synthesise(translated_text)
     synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
-    return 16000, synthesised_speech
+    return SAMPLE_RATE, synthesised_speech
 
 
 title = "Cascaded STST"