shakii's picture
Update app.py
8990b6e verified
raw
history blame
1.81 kB
import logging
import os
import threading

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
app = FastAPI()

# Browser origins allowed to call this API, as a comma-separated list in the
# CORS_ALLOW_ORIGINS environment variable, e.g.
# "https://example.com,https://app.example.com".
# Defaults to "*" (any origin): handy for development, but it should be
# restricted in production.
_allowed_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Credentialed CORS (cookies / auth headers) must not be combined with a
# wildcard origin. In that case Starlette echoes back whatever Origin the
# caller sent, which would let any website make credentialed requests.
# No endpoint visible in this file reads cookies or auth headers, so
# credentials are enabled only when the origin list is explicit.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hugging Face checkpoint used for AI-vs-human text classification.
# An alternative that was tried earlier is "SuperAnnotate/ai-detector".
# NOTE(review): predict() reads probability index 0 as "human" and index 1
# as "AI". Confirm that label order before swapping checkpoints.
model_name = "fakespot-ai/roberta-base-ai-text-detection-v1"

# Both objects are created once, at import time, and shared by every request.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
class TextRequest(BaseModel):
    """JSON request body accepted by ``POST /predict``."""

    # The passage to classify.
    text: str
# Inference runs on worker threads (see predict). This lock makes those
# threads use the shared tokenizer/model one at a time: a Hugging Face fast
# tokenizer can raise "RuntimeError: Already borrowed" when called
# concurrently with truncation settings.
_inference_lock = threading.Lock()


def _score_text(text: str) -> tuple[float, float]:
    """Run the classifier on *text* and return ``(human_prob, ai_prob)``.

    Both values are softmax probabilities in [0, 1]. This is blocking,
    CPU-bound work, so call it off the event loop.
    """
    with _inference_lock:
        # Inputs longer than 512 tokens are truncated, so only the beginning
        # of a long passage is scored.
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        )
        with torch.no_grad():
            outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # NOTE(review): assumes class index 0 = human and 1 = AI for this
    # checkpoint. Confirm against model.config.id2label.
    return probabilities[0][0].item(), probabilities[0][1].item()


@app.post("/predict")
async def predict(request: TextRequest):
    """Classify ``request.text`` as human-written or AI-generated.

    Returns:
        A dict containing:
        - the echoed text;
        - both class probabilities as percentages rounded to 2 decimals;
        - a label naming the more probable class. Ties go to
          "Human-written".

    Raises:
        HTTPException: status 500, with the underlying error message, if
            tokenization or inference fails.
    """
    try:
        # Inference is blocking. Running it in the threadpool keeps the event
        # loop free to serve other requests while the model runs.
        human_prob, ai_prob = await run_in_threadpool(_score_text, request.text)
    except Exception as e:
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "text": request.text,
        "human_probability": round(human_prob * 100, 2),
        "ai_probability": round(ai_prob * 100, 2),
        "prediction": "AI-generated" if ai_prob > human_prob else "Human-written",
    }
@app.get("/")
async def root():
    """Liveness probe: report that the service is up."""
    payload = {"message": "AI Text Detection API is running"}
    return payload
if __name__ == "__main__":
    # Serve the app directly when the module is run as a script, listening
    # on all interfaces at port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)