Update translator.py
Browse files- translator.py +9 -2
translator.py
CHANGED
@@ -368,9 +368,16 @@ def handle_translation_request(request):
     tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
     tokenized = {k: v.to(model.device) for k, v in tokenized.items()}

-    # Generate translation
     with torch.no_grad():
-        translated = model.generate(...)
+        translated = model.generate(
+            **tokenized,
+            max_length=100,          # Reasonable output length
+            num_beams=4,             # Same as in training
+            length_penalty=0.6,      # Same as in training
+            early_stopping=True,     # Same as in training
+            repetition_penalty=1.5,  # Add this to prevent repetition
+            no_repeat_ngram_size=3   # Add this to prevent repetition
+        )

     # Decode the translation
     result = tokenizer.decode(translated[0], skip_special_tokens=True)

(Note from review: the argument list of the removed one-line `model.generate(...)` call was lost in page extraction and is shown as `...`; the hunk header `-368,9 +368,16` and the `+9 -2` stat confirm exactly two removed lines and nine added lines as reconstructed above.)