t5-grammar-corrector / predict.py
aaurelions' picture
Add: python package
885e17b
raw
history blame
884 Bytes
from cog import BasePredictor, Input

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
class Predictor(BasePredictor):
    """Cog predictor wrapping the aaurelions/t5-grammar-corrector T5 model.

    `setup` loads the model/tokenizer once so repeated `predict` calls are cheap.
    """

    # Single source of truth for the Hub repo id used by both loads below.
    MODEL_ID = "aaurelions/t5-grammar-corrector"

    def setup(self):
        """Load the model and tokenizer into memory to make running multiple predictions efficient"""
        self.model = T5ForConditionalGeneration.from_pretrained(self.MODEL_ID)
        self.tokenizer = T5Tokenizer.from_pretrained(self.MODEL_ID)
        self.model.eval()  # inference-only predictor; make the mode explicit

    def predict(self, text: str = Input(description="Text to correct")) -> str:
        """Run a single prediction on the model"""
        # T5 grammar models expect this task prefix on the input.
        input_text = "fix grammar: " + text
        # Bound the encoder input: without truncation an arbitrarily long user
        # string is fed to the model untrimmed. 512 matches T5's trained length.
        input_ids = self.tokenizer(
            input_text, return_tensors="pt", truncation=True, max_length=512
        ).input_ids
        # No autograd bookkeeping needed for generation — saves memory and time.
        with torch.inference_mode():
            output_ids = self.model.generate(input_ids, max_length=128)
        corrected_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return corrected_text