sashtech's picture
Update app.py
69fcd05 verified
raw
history blame
3.3 kB
import os
import subprocess
import sys
import gradio as gr
from transformers import pipeline
import spacy
import nltk
from nltk.corpus import wordnet
# Function to install GECToR
def install_gector():
    """Fetch GECToR and make its checkout importable.

    Clones the GECToR repository into ``./gector`` when that directory does
    not exist yet, installs the packages listed in its ``requirements.txt``
    using the running interpreter's pip, and adds the absolute path of the
    checkout to ``sys.path``.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` or ``pip install``
            exits with a non-zero status (both run with ``check=True``).
    """
    if not os.path.exists('gector'):
        print("Cloning GECToR repository...")
        subprocess.run(["git", "clone", "https://github.com/grammarly/gector.git"], check=True)

    # Install dependencies from GECToR requirements
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "gector/requirements.txt"], check=True)

    # Manually add GECToR to the Python path. Skip the append when the entry
    # is already present so repeated calls do not pile up duplicates.
    gector_dir = os.path.abspath('gector')
    if gector_dir not in sys.path:
        sys.path.append(gector_dir)
# Install and import GECToR
install_gector()
# This import has to come after install_gector(): that call is what clones
# the repository and puts it on sys.path.
from gector.gec_model import GecBERTModel
from urllib.request import urlretrieve

# Remote location of the pretrained checkpoint and where it is cached locally.
GECTOR_MODEL_URL = 'https://grammarly-nlp-data.s3.amazonaws.com/gector/roberta_1_gector.th'
GECTOR_MODEL_PATH = os.path.abspath('roberta_1_gector.th')


def _ensure_gector_weights(url=GECTOR_MODEL_URL, path=GECTOR_MODEL_PATH):
    """Download the GECToR checkpoint to *path* unless it already exists.

    The file is first written to ``path + '.part'`` and renamed afterwards, so
    an interrupted download is never mistaken for a complete checkpoint.

    Returns:
        *path*, for direct use as a ``model_paths`` entry.
    """
    if not os.path.exists(path):
        print("Downloading GECToR model weights...")
        partial_path = path + '.part'
        urlretrieve(url, partial_path)
        os.replace(partial_path, path)
    return path


# Initialize GECToR model for grammar correction
# NOTE(review): upstream GECToR appears to hand each entry of model_paths
# straight to torch.load, which reads local files only -- hence the local
# copy instead of the raw URL. Confirm against gector/gec_model.py.
gector_model = GecBERTModel(vocab_path='gector/data/output_vocabulary',
                            model_paths=[_ensure_gector_weights()],
                            is_ensemble=False)
# English text-classification pipeline used for AI detection. It is created
# once at import time and reused by predict_en().
pipeline_en = pipeline(
    task="text-classification",
    model="Hello-SimpleAI/chatgpt-detector-roberta",
)
# Function to predict the label and score for English text (AI Detection)
def predict_en(text):
    """Classify English *text* with the AI-detection pipeline.

    Args:
        text: The passage to classify.

    Returns:
        A ``(label, score)`` tuple taken from the first result the pipeline
        returns for *text*.
    """
    # truncation=True clips inputs that exceed the model's maximum sequence
    # length instead of letting the pipeline call fail on long passages.
    res = pipeline_en(text, truncation=True)[0]
    return res['label'], res['score']
# Ensure necessary NLTK data is downloaded for Humanifier
# Fetch each NLTK corpus in turn, in the same order as before.
for _nltk_resource in ('wordnet', 'omw-1.4'):
    nltk.download(_nltk_resource)
# Ensure the SpaCy model is installed for Humanifier
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # The model package is missing. Install it with the interpreter that is
    # running this app (a bare "python" may resolve to a different
    # environment), and fail loudly if the download does not succeed rather
    # than hitting a second, less informative OSError on the retry below.
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
    nlp = spacy.load("en_core_web_sm")
# Function to correct grammar using GECToR
def correct_grammar_with_gector(text):
    """Correct the grammar of *text* with the module-level GECToR model.

    Every non-blank line of *text* is treated as one sentence, split into
    whitespace-separated tokens, and all sentences are corrected in a single
    batch.

    Args:
        text: Raw input text; may be empty or ``None``.

    Returns:
        The corrected sentences joined with single spaces, or ``""`` when
        *text* contains no tokens.
    """
    # Nothing to correct: skip the model call entirely.
    if not text or not text.strip():
        return ''

    # NOTE(review): per upstream GECToR's predict.py, handle_batch takes a
    # batch of token lists and returns (corrected_token_lists, update_count).
    # Confirm against the installed gector version.
    token_batch = [line.split() for line in text.splitlines() if line.strip()]
    corrected_batch, _ = gector_model.handle_batch(token_batch)
    return ' '.join(' '.join(tokens) for tokens in corrected_batch)
# Gradio app setup with three tabs
with gr.Blocks() as demo:
    # Tab 1: run the AI-text detector and show its label and score.
    with gr.Tab("AI Detection"):
        t1 = gr.Textbox(lines=5, label='Text')
        button1 = gr.Button("🤖 Predict!")
        label1 = gr.Textbox(lines=1, label='Predicted Label 🎃')
        score1 = gr.Textbox(lines=1, label='Prob')

        # Connect the prediction function to the button
        button1.click(predict_en, inputs=[t1], outputs=[label1, score1], api_name='predict_en')

    # Tab 2: currently wired to the same GECToR corrector as the
    # "Grammar Correction" tab below.
    with gr.Tab("Humanifier"):
        text_input = gr.Textbox(lines=5, label="Input Text")
        paraphrase_button = gr.Button("Paraphrase & Correct")
        output_text = gr.Textbox(label="Paraphrased Text")

        # Connect the paraphrasing function to the button
        paraphrase_button.click(correct_grammar_with_gector, inputs=text_input, outputs=output_text)

    # Tab 3: grammar correction only.
    with gr.Tab("Grammar Correction"):
        grammar_input = gr.Textbox(lines=5, label="Input Text")
        grammar_button = gr.Button("Correct Grammar")
        grammar_output = gr.Textbox(label="Corrected Text")

        # Connect the GECToR grammar correction function to the button
        grammar_button.click(correct_grammar_with_gector, inputs=grammar_input, outputs=grammar_output)

# Launch the app with all functionalities
demo.launch()