MoritzLaurer committed
Commit 01f35e8 · verified · 1 Parent(s): 479ac18

Update handler.py

Files changed (1)
  1. handler.py +16 -15
handler.py CHANGED
@@ -14,20 +14,21 @@ class EndpointHandler:
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         #self.feature_extractor = AutoFeatureExtractor.from_pretrained(path)
         self.model = ParlerTTSForConditionalGeneration.from_pretrained(path).to(device) #torch_dtype=torch.float16
+        self.number_normalizer = EnglishNumberNormalizer()  # Initialize number normalizer
 
-    def preprocess_text(self, text):
-        """Implement the same preprocessing as the Gradio app"""
-        text = self.number_normalizer(text).strip()
-        text = text.replace("-", " ")
-        if text[-1] not in punctuation:
-            text = f"{text}."
-
-        abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
-        abbreviations = re.findall(abbreviations_pattern, text)
-        for abv in abbreviations:
-            if abv in text:
-                text = text.replace(abv, " ".join(abv.replace(".","")))
-        return text
+    def preprocess_text(self, text):
+        """Implement the same preprocessing as the Gradio app"""
+        text = self.number_normalizer(text).strip()
+        text = text.replace("-", " ")
+        if text[-1] not in punctuation:
+            text = f"{text}."
+
+        abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
+        abbreviations = re.findall(abbreviations_pattern, text)
+        for abv in abbreviations:
+            if abv in text:
+                text = text.replace(abv, " ".join(abv.replace(".","")))
+        return text
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
         """
@@ -62,8 +63,8 @@ class EndpointHandler:
         with torch.autocast(device):
             outputs = self.model.generate(
                 **voice_description, prompt_input_ids=inputs.input_ids,
-                prompt_attention_mask=inputs.attention_mask, attention_mask=inputs.attention_mask,
-                **parameters
+                prompt_attention_mask=voice_description.attention_mask, attention_mask=inputs.attention_mask,
+                **gen_kwargs
             )
 
         # postprocess the prediction
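
For reference, the preprocessing this commit wires up (initializing the number normalizer that preprocess_text relies on) can be exercised in isolation. The sketch below mirrors preprocess_text outside the handler class; the import of EnglishNumberNormalizer from Whisper's English normalizer in transformers is an assumption (the diff does not show the handler's import block), and the exact normalized output depends on that class.

import re
from string import punctuation

# Assumed import location for the normalizer initialized in __init__ above
from transformers.models.whisper.english_normalizer import EnglishNumberNormalizer

number_normalizer = EnglishNumberNormalizer()

def preprocess_text(text):
    """Standalone mirror of EndpointHandler.preprocess_text from the diff above."""
    # Normalize numbers to a canonical form and trim surrounding whitespace
    text = number_normalizer(text).strip()
    text = text.replace("-", " ")
    # Ensure the prompt ends with punctuation
    if text[-1] not in punctuation:
        text = f"{text}."

    # Spell out all-caps abbreviations letter by letter, e.g. "TTS" -> "T T S"
    abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'
    abbreviations = re.findall(abbreviations_pattern, text)
    for abv in abbreviations:
        if abv in text:
            text = text.replace(abv, " ".join(abv.replace(".", "")))
    return text

print(preprocess_text("Generate speech with Parler TTS"))
# Roughly: "Generate speech with Parler T T S."

The docstring notes this mirrors the preprocessing used in the Parler-TTS Gradio app, so the handler should produce the same prompt text as the demo before tokenization.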