yasserrmd commited on
Commit
2cc9d7b
·
verified ·
1 Parent(s): c2b01a1

Update generate_audio.py

Browse files
Files changed (1) hide show
  1. generate_audio.py +18 -3
generate_audio.py CHANGED
@@ -71,9 +71,24 @@ class TTSGenerator:
71
  np.array: Audio array.
72
  int: Sampling rate.
73
  """
74
- input_ids = self.parler_tokenizer(self.speaker1_description, return_tensors="pt").input_ids.to(self.device)
75
- prompt_input_ids = self.parler_tokenizer(text, return_tensors="pt").input_ids.to(self.device)
76
- generation = self.parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  audio_arr = generation.cpu().numpy().squeeze()
78
  return audio_arr, self.parler_model.config.sampling_rate
79
 
 
71
  np.array: Audio array.
72
  int: Sampling rate.
73
  """
74
+ # input_ids = self.parler_tokenizer(self.speaker1_description, return_tensors="pt").input_ids.to(self.device)
75
+ # prompt_input_ids = self.parler_tokenizer(text, return_tensors="pt").input_ids.to(self.device)
76
+ # generation = self.parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
77
+ # audio_arr = generation.cpu().numpy().squeeze()
78
+ # return audio_arr, self.parler_model.config.sampling_rate
79
+ input_ids = self.parler_tokenizer(self.speaker1_description, return_tensors="pt", padding=True).input_ids.to(self.device)
80
+ attention_mask_input = self.parler_tokenizer(self.speaker1_description, return_tensors="pt", padding=True).attention_mask.to(self.device)
81
+
82
+ prompt_input_ids = self.parler_tokenizer(text, return_tensors="pt", padding=True).input_ids.to(self.device)
83
+ attention_mask_prompt = self.parler_tokenizer(text, return_tensors="pt", padding=True).attention_mask.to(self.device)
84
+
85
+ # Generate audio with input IDs and attention masks
86
+ generation = self.parler_model.generate(
87
+ input_ids=input_ids,
88
+ attention_mask=attention_mask_input,
89
+ prompt_input_ids=prompt_input_ids,
90
+ prompt_attention_mask=attention_mask_prompt
91
+ )
92
  audio_arr = generation.cpu().numpy().squeeze()
93
  return audio_arr, self.parler_model.config.sampling_rate
94