sashtech committed (verified)
Commit 504dc4a · Parent(s): ee971be

Update app.py

Files changed (1): app.py (+6, -5)
app.py CHANGED
@@ -1,12 +1,11 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 import torch
 import spacy
 import subprocess
 import nltk
 from nltk.corpus import wordnet
 from gensim import downloader as api
-from textblob import TextBlob # Import TextBlob for simple grammar correction
 
 # Ensure necessary NLTK data is downloaded
 nltk.download('wordnet')
@@ -29,6 +28,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
 model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english").to(device)
 
+# Load grammar correction model from Hugging Face
+grammar_corrector = pipeline("text2text-generation", model="prithivida/grammar-error-correction", device=0 if torch.cuda.is_available() else -1)
+
 # AI detection function using DistilBERT
 def detect_ai_generated(text):
     inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
@@ -76,10 +78,9 @@ def paraphrase_with_spacy_nltk(text):
 
     return paraphrased_sentence
 
-# Grammar correction function using TextBlob
+# Grammar correction function using Hugging Face grammar correction model
 def correct_grammar(text):
-    blob = TextBlob(text)
-    corrected_text = str(blob.correct())
+    corrected_text = grammar_corrector(text)[0]['generated_text']
     return corrected_text
 
 # Combined function: Paraphrase -> Grammar Check
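
For context, a minimal standalone sketch of the grammar-correction path this commit switches to. The pipeline task, model id, and device selection are taken from the diff above; the max_length argument, the type hints, and the demo sentence are illustrative additions. If the "prithivida/grammar-error-correction" id does not resolve on the Hub, the similarly named prithivida/grammar_error_correcter_v1 checkpoint may be what was intended.

import torch
from transformers import pipeline

# Grammar corrector as loaded in this commit (task and model id taken from the diff).
grammar_corrector = pipeline(
    "text2text-generation",
    model="prithivida/grammar-error-correction",
    device=0 if torch.cuda.is_available() else -1,
)

def correct_grammar(text: str) -> str:
    # text2text-generation pipelines return a list of dicts; the corrected
    # string is stored under the 'generated_text' key.
    return grammar_corrector(text, max_length=128)[0]["generated_text"]

if __name__ == "__main__":
    # Expected to print a corrected variant of the sentence.
    print(correct_grammar("He go to school every days."))

Compared with the removed TextBlob call, which performs dictionary-based word-level spelling correction, the seq2seq pipeline rewrites the whole sentence, at the cost of downloading a model checkpoint at startup.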