from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model_name = "hassaanik/grammar-correction-model"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

# Half precision roughly halves GPU memory use and speeds up inference;
# skip it on CPU, where fp16 support is limited.
if torch.cuda.is_available():
    model.half()


def correct_grammar(text):
    # Tokenize the input, truncating anything beyond the 512-token limit.
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True).to(device)
    # Beam search over 5 beams; early_stopping halts once every beam has finished.
    outputs = model.generate(inputs, max_length=512, num_beams=5, early_stopping=True)
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return corrected_text
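

# A batched variant (a sketch, not part of the original script;
# "correct_grammar_batch" and "texts" are illustrative names). Calling the
# tokenizer on a list with padding=True packs the sentences into one tensor
# batch, and the same beam-search settings are applied to every item.
def correct_grammar_batch(texts):
    inputs = tokenizer(texts, return_tensors="pt", padding=True,
                       truncation=True, max_length=512).to(device)
    outputs = model.generate(**inputs, max_length=512, num_beams=5, early_stopping=True)
    return [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]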


if __name__ == "__main__":
    sample_text = "He go to the market yesturday."
    corrected_text = correct_grammar(sample_text)
    print("Original Text:", sample_text)
    print("Corrected Text:", corrected_text)