Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,209 Bytes
c1ee666 d5c312e c1ee666 d5c312e c1ee666 d5c312e c1ee666 d5c312e c1ee666 d5c312e c1ee666 d5c312e c1ee666 d5c312e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from transformers import MBartForConditionalGeneration, AutoTokenizer
# Hugging Face Hub model ID; Verbalizer loads both its model and its tokenizer from it.
verbalizer_model_name: str = "skypro1111/mbart-large-50-verbalization"
class Verbalizer:
    """Thin wrapper around an mBART seq2seq model used as a text verbalizer.

    The model (``MBartForConditionalGeneration``) and its tokenizer are both
    loaded from ``verbalizer_model_name``. The tokenizer's source and target
    language codes are both set to ``"uk_XX"``. Each input is prefixed with
    ``PROMPT_PREFIX`` before encoding.
    """

    # Task prefix prepended to every input string before tokenization.
    PROMPT_PREFIX = "<verbalization>:"

    def __init__(self, device):
        """Load the model and tokenizer.

        Args:
            device: Forwarded to ``from_pretrained`` as ``device_map`` and
                stored on ``self.device``. Inputs are moved to
                ``self.model.device`` (where the weights actually ended up),
                so either a concrete device such as ``"cuda"``/``"cpu"`` or a
                placement strategy accepted by ``device_map`` (e.g.
                ``"auto"``) can be used.
        """
        self.device = device
        self.model = MBartForConditionalGeneration.from_pretrained(
            verbalizer_model_name,
            low_cpu_mem_usage=True,
            device_map=device,
        )
        # Inference only: switch layers such as dropout to eval behavior.
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(verbalizer_model_name)
        # mBART-50 tokenizers use these to choose the language-code special
        # tokens; input and output are the same language here.
        self.tokenizer.src_lang = "uk_XX"
        self.tokenizer.tgt_lang = "uk_XX"

    def generate_text(self, text, *, max_length=1024, num_beams=5):
        """Generate text for a single input.

        Args:
            text: Input string; ``PROMPT_PREFIX`` is prepended before encoding.
            max_length: Token limit used both to truncate the encoded input
                and as the ``max_length`` passed to ``generate``.
            num_beams: Beam width for beam search (``early_stopping=True``).

        Returns:
            The first generated sequence, decoded with special tokens removed
            and surrounding whitespace stripped.
        """
        input_text = self.PROMPT_PREFIX + text
        # Move inputs to wherever the weights live. ``self.device`` may be a
        # device_map strategy (e.g. "auto") that is not a valid tensor device.
        encoded_input = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.model.device)
        output_ids = self.model.generate(
            **encoded_input,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
        normalized_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return normalized_text.strip()