#!/usr/bin/env python3
"""
T5 Detoxification API for Hugging Face Spaces
FastAPI service that can be called from external WebSocket servers
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging
import time
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="T5 Detoxification API", version="1.0.0")

# Record startup time so /status can report a real uptime
START_TIME = time.time()


class TextRequest(BaseModel):
    text: str
    max_length: int = 256


class TextResponse(BaseModel):
    original_text: str
    detoxified_text: str
    processing_time: float
    device: str


class T5Service:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.loaded = False
        self.load_model()

    def load_model(self):
        """Load the T5 detoxification model and tokenizer."""
        try:
            logger.info(f"Loading T5 model on {self.device}...")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("s-nlp/t5-paranmt-detox")
            logger.info("Tokenizer loaded")

            # Load model: fp16 on GPU to halve memory, fp32 on CPU
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                "s-nlp/t5-paranmt-detox",
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                low_cpu_mem_usage=True,
            )

            # Move to device and switch to inference mode
            self.model = self.model.to(self.device)
            self.model.eval()

            # Try torch.compile (PyTorch 2.x) for better performance
            try:
                if torch.__version__.startswith("2"):
                    self.model = torch.compile(self.model, mode="reduce-overhead")
                    logger.info("Model compiled with torch.compile()")
            except Exception as e:
                logger.warning(f"torch.compile failed: {e}")

            self.loaded = True
            logger.info(f"T5 model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            self.loaded = False

    def detoxify_text(self, text: str, max_length: int = 256) -> str:
        """Detoxify text using the T5 model; fall back to the input on failure."""
        if not self.loaded or not text.strip():
            return text

        try:
            # Tokenize
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            )
            inputs = inputs.to(self.device)

            # Generate detoxified text with greedy decoding
            # (early_stopping is omitted: it only applies when num_beams > 1)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    num_beams=1,
                    do_sample=False,
                )

            # Decode
            detoxified = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
            ).strip()

            return detoxified if detoxified else text
        except Exception as e:
            logger.error(f"Error in detoxification: {e}")
            return text


# Initialize the service (blocks at import time until the model is loaded)
t5_service = T5Service()


@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "message": "T5 Detoxification API",
        "status": "running",
        "model_loaded": t5_service.loaded,
        "device": str(t5_service.device),
    }


@app.get("/health")
async def health_check():
    """Detailed health check"""
    return {
        "status": "healthy" if t5_service.loaded else "unhealthy",
        "model_loaded": t5_service.loaded,
        "device": str(t5_service.device),
        "timestamp": time.time(),
    }


@app.post("/detoxify", response_model=TextResponse)
def detoxify_text(request: TextRequest):
    """Detoxify text using the T5 model.

    Declared as a plain ``def`` so FastAPI runs it in a threadpool and the
    blocking model inference does not stall the event loop.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if not t5_service.loaded:
        raise HTTPException(status_code=503, detail="T5 model not loaded")

    start_time = time.time()
    try:
        detoxified_text = t5_service.detoxify_text(request.text, request.max_length)
        processing_time = time.time() - start_time
        return TextResponse(
            original_text=request.text,
            detoxified_text=detoxified_text,
            processing_time=round(processing_time, 3),
            device=str(t5_service.device),
        )
    except Exception as e:
        logger.error(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")


@app.get("/status")
async def get_status():
    """Get service status"""
    return {
        "model_loaded": t5_service.loaded,
        "device": str(t5_service.device),
        "uptime": time.time() - START_TIME,  # seconds since process start
    }


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
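
# A minimal client sketch for the /detoxify endpoint above. It assumes the
# service is running locally on the default port 7860 and that the `requests`
# library is installed; adjust the URL when calling a deployed Space.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/detoxify",
#       json={"text": "some toxic input", "max_length": 256},
#       timeout=60,
#   )
#   resp.raise_for_status()
#   print(resp.json()["detoxified_text"])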