leofltt committed on
Commit
99623ea
·
verified ·
1 Parent(s): b2319dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -6,13 +6,12 @@ from datasets import load_dataset
6
  from transformers import pipeline
7
  from transformers import BarkModel, BarkProcessor
8
 
9
- from transformers import AutoProcessor, SeamlessM4Tv2Model
10
-
11
 
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
- asr_model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large")
15
- asr_processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
16
 
17
  asr_model.to(device)
18
 
@@ -24,8 +23,9 @@ bark_model.to(device)
24
 
25
  def translate(audio):
26
  inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt")
27
- output_tokens = asr_model.generate(**inputs, tgt_lang="ita", generate_speech=False)
28
- translation = asr_processor.decode(output_tokens[0].tolist()[0], skip_special_tokens=True)
 
29
  return translation
30
 
31
 
 
6
  from transformers import pipeline
7
  from transformers import BarkModel, BarkProcessor
8
 
9
+ from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
 
10
 
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
 
13
+ asr_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
14
+ asr_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
15
 
16
  asr_model.to(device)
17
 
 
23
 
24
  def translate(audio):
25
  inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt")
26
+ generated_ids = asr_model.generate(inputs["input_features"],attention_mask=inputs["attention_mask"],
27
+ forced_bos_token_id=asr_processor.tokenizer.lang_code_to_id['it'],)
28
+ translation = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)
29
  return translation
30
 
31