leofltt committed on
Commit
7219472
·
verified ·
1 Parent(s): c88b4e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -6,14 +6,13 @@ from datasets import load_dataset
6
  from transformers import pipeline
7
  from transformers import BarkModel, BarkProcessor
8
 
9
- from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
10
 
11
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
 
13
- # asr_pipe = pipeline("automatic-speech-recognition", model="facebook/s2t-medium-mustc-multilingual-st", device=device)
14
 
15
- asr_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
16
- asr_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
17
 
18
  asr_model.to(device)
19
 
@@ -25,9 +24,8 @@ bark_model.to(device)
25
 
26
  def translate(audio):
27
  inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt")
28
- generated_ids = asr_model.generate(inputs["input_features"],attention_mask=inputs["attention_mask"],
29
- forced_bos_token_id=asr_processor.tokenizer.lang_code_to_id['it'],)
30
- translation = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)
31
  return translation
32
 
33
 
 
6
  from transformers import pipeline
7
  from transformers import BarkModel, BarkProcessor
8
 
9
+ from transformers import AutoProcessor, SeamlessM4Tv2Model
10
 
 
11
 
12
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
+ asr_model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large")
15
+ asr_processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
16
 
17
  asr_model.to(device)
18
 
 
24
 
25
  def translate(audio):
26
  inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt")
27
+ output_tokens = asr_model.generate(**inputs, tgt_lang="it", generate_speech=False)
28
+ translation = asr_processor.decode(output_tokens[0].tolist()[0], skip_special_tokens=True)
 
29
  return translation
30
 
31