from transformers import T5ForConditionalGeneration, T5Tokenizer
from cog import BasePredictor, Input

class Predictor(BasePredictor):
    def setup(self):
        """Load the model and tokenizer into memory to make running multiple predictions efficient"""
        self.model = T5ForConditionalGeneration.from_pretrained("aaurelions/t5-grammar-corrector")
        self.tokenizer = T5Tokenizer.from_pretrained("aaurelions/t5-grammar-corrector")

    def predict(self, text: str = Input(description="Text to correct")) -> str:
        """Run a single prediction on the model"""
        # Prepend the task prefix the model expects for grammar correction
        input_text = "fix grammar: " + text
        input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids
        # Cap generation at 128 tokens to keep corrected outputs bounded
        output_ids = self.model.generate(input_ids, max_length=128)
        corrected_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return corrected_text
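

# Quick local check: a minimal sketch, assuming the Hugging Face repo
# "aaurelions/t5-grammar-corrector" is reachable (setup() downloads the
# weights on first run). When deployed, Cog calls setup() and predict()
# itself, so this block is only for ad-hoc testing.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    print(predictor.predict(text="she go to school yesterday"))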