shakii's picture
Update app.py
4fa25e4 verified
raw
history blame
1.81 kB
import asyncio
import logging
import os

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
app = FastAPI()

# Origins allowed to call this API cross-origin, read from the comma-separated
# CORS_ALLOW_ORIGINS environment variable. The default "*" allows every origin,
# which is only appropriate for development; in production set explicit
# origins, e.g. CORS_ALLOW_ORIGINS="https://example.com,https://app.example.com".
# Blank entries (e.g. from a trailing comma) are ignored.
cors_allow_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

# NOTE(review): combining a "*" origin with allow_credentials=True makes
# Starlette's CORSMiddleware reflect the caller's Origin back (instead of
# sending "*"), so any site can make credentialed requests — verify for the
# installed Starlette version. Nothing in this file reads cookies or auth
# headers; confirm whether credentials are needed at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hugging Face checkpoint used by the /predict endpoint. An earlier revision
# used "fakespot-ai/roberta-base-ai-text-detection-v1" instead.
model_name = "SuperAnnotate/ai-detector"

# Loaded once at import time, so every request reuses the same tokenizer and
# model objects.
# NOTE(review): predict() reads softmax index 0 as "human" and index 1 as
# "AI"; confirm against model.config.id2label, and check the load log for
# "newly initialized" weights in case this checkpoint's head is not a
# standard sequence-classification head.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
class TextRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # The text to classify; tokenized and truncated to 512 tokens by predict().
    text: str
def _classify(text: str):
    """Run the detector on *text* and return ``(human_prob, ai_prob)``.

    Both values are softmax probabilities in ``[0, 1]`` for class index 0 and
    class index 1 respectively. Input longer than 512 tokens is truncated, so
    only the beginning of a long text is scored.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # NOTE(review): assumes index 0 = human-written and index 1 = AI-generated;
    # confirm against model.config.id2label for the loaded checkpoint.
    return predictions[0][0].item(), predictions[0][1].item()


@app.post("/predict")
async def predict(request: TextRequest):
    """Classify ``request.text`` as human-written or AI-generated.

    Returns:
        A dict with the echoed ``text``, ``human_probability`` and
        ``ai_probability`` as percentages rounded to two decimals, and
        ``prediction`` naming the more probable class (a tie yields
        "Human-written").

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            tokenization or inference fails.
    """
    try:
        # Tokenization and the forward pass are synchronous and compute-bound;
        # running them in the default thread-pool executor keeps the event loop
        # free to serve other requests (e.g. GET /) meanwhile. Concurrent
        # requests may therefore run inference in parallel worker threads.
        loop = asyncio.get_running_loop()
        human_prob, ai_prob = await loop.run_in_executor(None, _classify, request.text)
    except Exception as e:
        # Keep the full traceback in the server log; clients only see str(e).
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "text": request.text,
        "human_probability": round(human_prob * 100, 2),
        "ai_probability": round(ai_prob * 100, 2),
        "prediction": "AI-generated" if ai_prob > human_prob else "Human-written",
    }
@app.get("/")
async def root():
    """Report that the service is up.

    Returns a fixed JSON message and does not touch the model.
    """
    status = "AI Text Detection API is running"
    return dict(message=status)
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on every
    # interface (0.0.0.0) at port 7860.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)