import numpy as np
import torch
from transformers import pipeline
from transformers import VitsModel, VitsTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper performs transcription and any-to-English speech translation;
# the task is selected at generation time via generate_kwargs.
pipe = pipeline(
    "automatic-speech-recognition", model="openai/whisper-base", device=device,
)

model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
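
# Sanity check (assumption: VitsConfig exposes sampling_rate, 16 kHz for
# MMS-TTS); the hard-coded rate in the return value below relies on this.
assert model.config.sampling_rate == 16000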

# Scale factor for converting the float waveform to 16-bit PCM
# (np.iinfo(np.int16).max == 32767)
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max

def speech_to_speech_translation(filepath):
  # Translate the source speech to English text with Whisper
  translation = pipe(filepath, max_new_tokens=256, generate_kwargs={"task": "translate"})["text"]

  # Tokenise the translated text for the TTS model
  inputs = tokenizer(translation, return_tensors="pt")
  input_ids = inputs["input_ids"]

  # Synthesise speech from the translated text without tracking gradients
  model.eval()
  with torch.inference_mode():
    outputs = model(input_ids)

  speech = outputs["waveform"]
  synthesised_speech = speech / torch.max(torch.abs(speech))  # normalise to [-1, 1]
  synthesised_speech = (synthesised_speech * max_range).numpy().astype(target_dtype)

  return (16000, synthesised_speech.squeeze()), translation
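

# --- Usage sketch ---
# A minimal example of running the pipeline end to end and saving the result
# with scipy; "sample.wav" and "translated_speech.wav" are illustrative
# filenames, not part of the original script.
if __name__ == "__main__":
  from scipy.io import wavfile

  (sampling_rate, audio), translation = speech_to_speech_translation("sample.wav")
  print(f"Translation: {translation}")
  wavfile.write("translated_speech.wav", sampling_rate, audio)  # int16 PCM WAV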