sashtech's picture
Update app.py
01902a2 verified
raw
history blame
3.66 kB
import os
import subprocess
import sys
import uuid # To generate unique link IDs

import gradio as gr
import nltk
import spacy
import uvicorn
from fastapi import FastAPI, HTTPException
from nltk.corpus import wordnet
from pydantic import BaseModel
from spellchecker import SpellChecker
from transformers import pipeline
# FastAPI application that hosts the link-management endpoints.
api_app = FastAPI()

# Text-classification pipeline backed by the RoBERTa ChatGPT detector; used by
# predict_en for AI-generated-text detection. Built once, at import time.
pipeline_en = pipeline(
    "text-classification",
    model="Hello-SimpleAI/chatgpt-detector-roberta",
)

# Spell checker instance (not referenced elsewhere in the code shown here).
spell = SpellChecker()

# Fetch the NLTK WordNet data (and the Open Multilingual Wordnet) that
# nltk.corpus.wordnet relies on.
nltk.download("wordnet")
nltk.download("omw-1.4")
# Load the small English spaCy model, downloading it on first run if missing.
# spacy.load raises OSError when the model package is not installed.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Run the download with the *current* interpreter rather than whatever
    # "python" resolves to on PATH, so the model lands in the same environment
    # that imports it. check=True surfaces a failed download immediately
    # instead of letting the retry below fail with a confusing OSError.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    nlp = spacy.load("en_core_web_sm")

# In-memory registry of generated links: link_id -> {"task": ..., "url": ...}.
# NOTE(review): process-local and unbounded -- nothing in the code shown here
# ever removes entries, and everything is lost on restart. Swap for a real
# store if links must expire or persist.
temporary_links = {}
# Pydantic request model(s) for the FastAPI endpoints.
class TextRequest(BaseModel):
    """Request body carrying a single piece of input text.

    NOTE(review): not referenced by any endpoint in the code shown here.
    """

    text: str  # raw text supplied by the client
# AI-text detection for English input.
def predict_en(text):
    """Classify ``text`` with the English AI-generated-text detector.

    Args:
        text: Input text to classify.

    Returns:
        A ``(label, score)`` tuple: the top label predicted by ``pipeline_en``
        and its confidence score.
    """
    # The RoBERTa-based detector has a fixed maximum sequence length (512
    # tokens for RoBERTa-base checkpoints). Without truncation, longer inputs
    # make the pipeline raise instead of returning a prediction. With it, only
    # the leading tokens are classified.
    res = pipeline_en(text, truncation=True)[0]
    return res['label'], res['score']
# Paraphrasing / grammar correction (not implemented yet).
def paraphrase_and_correct(text):
    """Paraphrase and grammar-correct ``text``.

    Currently a placeholder: the input is returned unchanged.

    TODO: implement the actual paraphrasing and grammar-correction logic.
    """
    corrected = text  # no transformation is applied yet
    return corrected
def _build_ai_detection_ui():
    """Build the Gradio Blocks UI for AI-generated-text detection."""
    with gr.Blocks() as demo:
        t1 = gr.Textbox(lines=5, label='Text')
        button1 = gr.Button("🤖 Predict!")
        label1 = gr.Textbox(lines=1, label='Predicted Label 🎃')
        score1 = gr.Textbox(lines=1, label='Prob')
        # Text in, (label, score) out -- matches predict_en's return tuple.
        button1.click(fn=predict_en, inputs=t1, outputs=[label1, score1])
    return demo


def _build_paraphrase_ui():
    """Build the Gradio Blocks UI for paraphrasing / grammar correction."""
    with gr.Blocks() as demo:
        t2 = gr.Textbox(
            lines=5,
            label='Enter text for paraphrasing and grammar correction',
        )
        button2 = gr.Button("🔄 Paraphrase and Correct")
        result2 = gr.Textbox(
            lines=10,
            label='Corrected Text',
            placeholder="The corrected text will appear here...",
        )
        button2.click(fn=paraphrase_and_correct, inputs=t2, outputs=result2)
    return demo


# Supported task names, each mapped to the function that builds its UI.
_UI_BUILDERS = {
    "ai-detection": _build_ai_detection_ui,
    "paraphrase": _build_paraphrase_ui,
}


# API Endpoint to create a new temporary link for Gradio interface
@api_app.post("/generate-link/")
async def generate_temporary_link(task: str):
    """Launch a Gradio UI for ``task`` and register a link ID for it.

    Args:
        task: Either ``"ai-detection"`` or ``"paraphrase"``.

    Returns:
        ``{"link_id": <uuid4 string>, "url": <Gradio URL>}``. The URL is the
        public share URL when one was created, otherwise the local URL.

    Raises:
        HTTPException: 400 if ``task`` is not a supported task name.
    """
    builder = _UI_BUILDERS.get(task)
    if builder is None:
        raise HTTPException(status_code=400, detail="Invalid task type.")

    link_id = str(uuid.uuid4())
    demo = builder()

    # Blocks.launch() returns an (app, local_url, share_url) tuple, not a URL.
    # Returning that tuple (which holds a server app object) as JSON fails, so
    # unpack it. share_url can be None if the public tunnel was not created.
    # NOTE(review): launch() runs synchronously inside this async handler, and
    # every call starts another server that nothing here ever closes.
    _, local_url, share_url = demo.launch(share=True, prevent_thread_lock=True)
    demo_url = share_url or local_url

    temporary_links[link_id] = {"task": task, "url": demo_url}
    return {"link_id": link_id, "url": demo_url}
# Endpoint: resolve a previously generated link ID to its stored Gradio URL.
@api_app.get("/get-link/{link_id}")
async def get_temporary_link(link_id: str):
    """Return the URL registered under ``link_id``.

    Args:
        link_id: ID returned by ``/generate-link/``.

    Returns:
        ``{"link": <stored url>}``.

    Raises:
        HTTPException: 404 when ``link_id`` is unknown.
    """
    entry = temporary_links.get(link_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Link not found.")
    return {"link": entry["url"]}
# Script entry point: serve the FastAPI app with uvicorn.
if __name__ == "__main__":
    # Bind on all interfaces so the API is reachable from outside the host.
    uvicorn.run(app=api_app, host="0.0.0.0", port=8000)