from transformers import T5ForConditionalGeneration, T5Tokenizer

from cog import BasePredictor, Input


class Predictor(BasePredictor):
    def setup(self):
        """Load the model and tokenizer into memory to make running multiple predictions efficient"""
        self.model = T5ForConditionalGeneration.from_pretrained("aaurelions/t5-grammar-corrector")
        self.tokenizer = T5Tokenizer.from_pretrained("aaurelions/t5-grammar-corrector")

    def predict(self, text: str = Input(description="Text to correct")) -> str:
        """Run a single prediction on the model"""
        # Prepend the task prefix the grammar-correction model expects
        input_text = "fix grammar: " + text
        input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids
        # Generate the corrected sequence, capped at 128 tokens
        output_ids = self.model.generate(input_ids, max_length=128)
        corrected_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return corrected_text
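
# Minimal usage sketch (assumes this file is registered as the predictor in cog.yaml):
#
#   cog predict -i text="She go to school every days."
#
# Cog calls setup() once when the container starts, then predict() for each request.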