#!/usr/bin/env python3
"""
T5 Detoxification API for Hugging Face Spaces
FastAPI service that can be called from external WebSocket servers
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging
import time
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="T5 Detoxification API", version="1.0.0")


class TextRequest(BaseModel):
    text: str
    max_length: int = 256


class TextResponse(BaseModel):
    original_text: str
    detoxified_text: str
    processing_time: float
    device: str


class T5Service:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.loaded = False
        self.load_model()

    def load_model(self):
        """Load T5 detoxification model"""
        try:
            logger.info(f"Loading T5 model on {self.device}...")
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained('s-nlp/t5-paranmt-detox')
            logger.info("Tokenizer loaded")
            # Load model: FP16 on GPU to halve memory, FP32 on CPU
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                's-nlp/t5-paranmt-detox',
                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                low_cpu_mem_usage=True
            )
            # Move to device and switch to inference (eval) mode
            self.model = self.model.to(self.device)
            self.model.eval()
            # Try torch.compile (PyTorch 2.x only) for better performance
            try:
                if torch.__version__.startswith("2"):
                    self.model = torch.compile(self.model, mode="reduce-overhead")
                    logger.info("Model compiled with torch.compile()")
            except Exception as e:
                logger.warning(f"torch.compile failed: {e}")
            self.loaded = True
            logger.info(f"T5 model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            self.loaded = False

    def detoxify_text(self, text: str, max_length: int = 256) -> str:
        """Detoxify text using T5 model"""
        if not self.loaded or not text.strip():
            return text
        try:
            # Tokenize
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=max_length
            )
            inputs = inputs.to(self.device)
            # Generate with greedy decoding (num_beams=1) for lowest latency;
            # early_stopping is omitted because it only applies to beam search
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    num_beams=1,
                    do_sample=False
                )
            # Decode, skipping special tokens
            detoxified = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True
            ).strip()
            # Fall back to the original text if generation produced nothing
            return detoxified if detoxified else text
        except Exception as e:
            logger.error(f"Error in detoxification: {e}")
            return text


# Initialize the service at import time so the model is loaded before requests
t5_service = T5Service()
# Record startup time so /status can report real uptime
START_TIME = time.time()
@app.get("/")
async def root():
"""Health check endpoint"""
return {
"message": "T5 Detoxification API",
"status": "running",
"model_loaded": t5_service.loaded,
"device": str(t5_service.device)
}
@app.get("/health")
async def health_check():
"""Detailed health check"""
return {
"status": "healthy" if t5_service.loaded else "unhealthy",
"model_loaded": t5_service.loaded,
"device": str(t5_service.device),
"timestamp": time.time()
}


@app.post("/detoxify", response_model=TextResponse)
async def detoxify_text(request: TextRequest):
    """Detoxify text using T5 model"""
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if not t5_service.loaded:
        raise HTTPException(status_code=503, detail="T5 model not loaded")
    start_time = time.time()
    try:
        detoxified_text = t5_service.detoxify_text(
            request.text,
            request.max_length
        )
        processing_time = time.time() - start_time
        return TextResponse(
            original_text=request.text,
            detoxified_text=detoxified_text,
            processing_time=round(processing_time, 3),
            device=str(t5_service.device)
        )
    except Exception as e:
        logger.error(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
@app.get("/status")
async def get_status():
"""Get service status"""
return {
"model_loaded": t5_service.loaded,
"device": str(t5_service.device),
"uptime": time.time()
}


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
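
# Local run sketch (assumptions: this file is saved as app.py, the usual
# Spaces entry point, and the dependencies below cover the model's needs):
#   pip install fastapi uvicorn torch transformers sentencepiece
#   PORT=7860 python app.py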