Update translator.py
Browse files- translator.py +9 -2
translator.py
CHANGED
|
@@ -368,9 +368,16 @@ def handle_translation_request(request):
     tokenized = tokenizer(input_text, return_tensors="pt", padding=True)
     tokenized = {k: v.to(model.device) for k, v in tokenized.items()}

-    # Generate translation
     with torch.no_grad():
-        translated = model.generate(
+        translated = model.generate(
+            **tokenized,
+            max_length=100,          # Reasonable output length
+            num_beams=4,             # Same as in training
+            length_penalty=0.6,      # Same as in training
+            early_stopping=True,     # Same as in training
+            repetition_penalty=1.5,  # Add this to prevent repetition
+            no_repeat_ngram_size=3   # Add this to prevent repetition
+        )

     # Decode the translation
     result = tokenizer.decode(translated[0], skip_special_tokens=True)