puppala13 commited on
Commit
dfb8d4c
·
verified ·
1 Parent(s): 59a20a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -22,13 +22,14 @@ def main():
22
 
23
  def translate_text(input_text, model, tokenizer):
24
  # Tokenize input text
25
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids
26
-
27
- # Generate translation
28
- translated_ids = model.generate(input_ids)
 
29
 
30
  # Decode translated text
31
- translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
32
 
33
  return translated_text
34
 
 
22
 
23
  def translate_text(input_text, model, tokenizer):
24
  # Tokenize input text
25
+ model_inputs = tokenizer(input_text, return_tensors="pt").input_ids
26
+ generated_tokens = model.generate(
27
+ **model_inputs,
28
+ forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"]
29
+ )
30
 
31
  # Decode translated text
32
+ translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
33
 
34
  return translated_text
35