gdnartea committed on
Commit
0003cc7
·
verified ·
1 Parent(s): b83d714

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, VitsForConditionalGeneration, VitsProcessor
4
  from nemo.collections.asr.models import ASRModel
5
 
6
 
@@ -27,9 +27,9 @@ proc_model = AutoModelForCausalLM.from_pretrained(
27
  trust_remote_code=True,
28
  )
29
 
30
- # Load the TTS model and processor
31
- tts_processor = VitsProcessor.from_pretrained("facebook/mms-tts-eng")
32
- tts_model = VitsForConditionalGeneration.from_pretrained("facebook/mms-tts-eng")
33
 
34
 
35
  def process_speech(speech):
@@ -39,14 +39,12 @@ def process_speech(speech):
39
  # Process the text
40
  inputs = proc_tokenizer.encode(transcription + proc_tokenizer.eos_token, return_tensors='pt')
41
  outputs = proc_model.generate(inputs, max_length=100, temperature=0.7, pad_token_id=proc_tokenizer.eos_token_id)
42
- processed_text = proc_tokenizer.decode(outputs[0], skip_special_tokens=True)
43
-
 
44
  # Convert the processed text to speech
45
- inputs = tts_processor(processed_text, return_tensors="pt")
46
  with torch.no_grad():
47
- logits = tts_model(inputs["input_ids"]).logits
48
- predicted_ids = torch.argmax(logits, dim=-1)
49
- audio = tts_processor.decode(predicted_ids)
50
 
51
  return audio
52
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, VitsModel
4
  from nemo.collections.asr.models import ASRModel
5
 
6
 
 
27
  trust_remote_code=True,
28
  )
29
 
30
+ # Load the TTS model
31
+ tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
32
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
33
 
34
 
35
  def process_speech(speech):
 
39
  # Process the text
40
  inputs = proc_tokenizer.encode(transcription + proc_tokenizer.eos_token, return_tensors='pt')
41
  outputs = proc_model.generate(inputs, max_length=100, temperature=0.7, pad_token_id=proc_tokenizer.eos_token_id)
42
+ text = proc_tokenizer.decode(outputs[0], skip_special_tokens=True)
43
+ processed_text = tts_tokenizer(text, return_tensors="pt")
44
+
45
  # Convert the processed text to speech
 
46
  with torch.no_grad():
47
+ audio = tts_model(**inputs).waveform
 
 
48
 
49
  return audio
50