MominRuaf committed
Commit 35208ed · verified · 1 Parent(s): 4035ad1

Upload 2 files

Files changed (2)
  1. app.py +177 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,177 @@
+ #!/usr/bin/env python3
+ """
+ T5 Detoxification API for Hugging Face Spaces
+ FastAPI service that can be called from external WebSocket servers
+ """
+
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ import logging
+ import time
+ import os
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = FastAPI(title="T5 Detoxification API", version="1.0.0")
+ START_TIME = time.time()  # process start, used by /status to report uptime
+
+ class TextRequest(BaseModel):
+     text: str
+     max_length: int = 256
+
+ class TextResponse(BaseModel):
+     original_text: str
+     detoxified_text: str
+     processing_time: float
+     device: str
+
+ class T5Service:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.loaded = False
+         self.load_model()
+
+     def load_model(self):
+         """Load the T5 detoxification model."""
+         try:
+             logger.info(f"Loading T5 model on {self.device}...")
+
+             # Load tokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained('s-nlp/t5-paranmt-detox')
+             logger.info("Tokenizer loaded")
+
+             # Load model; fp16 on GPU halves memory, fp32 keeps CPU numerics stable
+             self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                 's-nlp/t5-paranmt-detox',
+                 torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
+                 low_cpu_mem_usage=True
+             )
+
+             # Move to device and switch to inference mode
+             self.model = self.model.to(self.device)
+             self.model.eval()
+
+             # Try torch.compile (PyTorch 2.x) for lower per-call overhead
+             try:
+                 if torch.__version__.startswith("2"):
+                     self.model = torch.compile(self.model, mode="reduce-overhead")
+                     logger.info("Model compiled with torch.compile()")
+             except Exception as e:
+                 logger.warning(f"torch.compile failed: {e}")
+
+             self.loaded = True
+             logger.info(f"T5 model loaded successfully on {self.device}")
+
+         except Exception as e:
+             logger.error(f"Failed to load model: {e}")
+             self.loaded = False
+
+     def detoxify_text(self, text: str, max_length: int = 256) -> str:
+         """Detoxify text with the T5 model; return the input unchanged on failure."""
+         if not self.loaded or not text.strip():
+             return text
+
+         try:
+             # Tokenize
+             inputs = self.tokenizer(
+                 text.strip(),
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=max_length
+             )
+
+             inputs = inputs.to(self.device)
+
+             # Generate detoxified text (greedy decoding)
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_length=max_length,
+                     num_beams=1,
+                     do_sample=False
+                 )
+
+             # Decode
+             detoxified = self.tokenizer.decode(
+                 outputs[0],
+                 skip_special_tokens=True
+             ).strip()
+
+             return detoxified if detoxified else text
+
+         except Exception as e:
+             logger.error(f"Error in detoxification: {e}")
+             return text
+
+ # Initialize the service once at import time
+ t5_service = T5Service()
+
+ @app.get("/")
+ async def root():
+     """Health check endpoint"""
+     return {
+         "message": "T5 Detoxification API",
+         "status": "running",
+         "model_loaded": t5_service.loaded,
+         "device": str(t5_service.device)
+     }
+
+ @app.get("/health")
+ async def health_check():
+     """Detailed health check"""
+     return {
+         "status": "healthy" if t5_service.loaded else "unhealthy",
+         "model_loaded": t5_service.loaded,
+         "device": str(t5_service.device),
+         "timestamp": time.time()
+     }
+
+ @app.post("/detoxify", response_model=TextResponse)
+ async def detoxify_text(request: TextRequest):
+     """Detoxify text using the T5 model."""
+     if not request.text.strip():
+         raise HTTPException(status_code=400, detail="Text cannot be empty")
+
+     if not t5_service.loaded:
+         raise HTTPException(status_code=503, detail="T5 model not loaded")
+
+     start_time = time.time()
+
+     try:
+         detoxified_text = t5_service.detoxify_text(
+             request.text,
+             request.max_length
+         )
+
+         processing_time = time.time() - start_time
+
+         return TextResponse(
+             original_text=request.text,
+             detoxified_text=detoxified_text,
+             processing_time=round(processing_time, 3),
+             device=str(t5_service.device)
+         )
+
+     except Exception as e:
+         logger.error(f"Error processing request: {e}")
+         raise HTTPException(status_code=500, detail="Internal server error")
+
+ @app.get("/status")
+ async def get_status():
+     """Get service status"""
+     return {
+         "model_loaded": t5_service.loaded,
+         "device": str(t5_service.device),
+         "uptime": round(time.time() - START_TIME, 1)
+     }
+
+ if __name__ == "__main__":
+     import uvicorn
+     port = int(os.getenv("PORT", "7860"))
+     uvicorn.run(app, host="0.0.0.0", port=port)
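
The endpoint contract above is everything an external caller (such as the WebSocket server mentioned in the docstring) needs. As an illustration, here is a minimal client sketch that is not part of this commit: it assumes the third-party requests package and a placeholder base URL (http://localhost:7860 for a local run, or your Space's public URL).

    # Example client (illustrative only): POST to /detoxify and read the result.
    # BASE_URL is a placeholder -- replace it with your actual Space URL.
    import requests

    BASE_URL = "http://localhost:7860"

    payload = {"text": "your input text here", "max_length": 256}
    resp = requests.post(f"{BASE_URL}/detoxify", json=payload, timeout=60)
    resp.raise_for_status()

    result = resp.json()  # fields mirror the TextResponse model in app.py
    print(result["detoxified_text"], result["processing_time"], result["device"])
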
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ fastapi>=0.100.0
+ uvicorn>=0.20.0
+ pydantic>=2.0.0
+ torch>=2.0.0
+ transformers>=4.36.0
+ accelerate>=0.20.0
+ sentencepiece>=0.1.99
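
Note that sentencepiece backs the T5 tokenizer and accelerate supports the low_cpu_mem_usage=True load path, so neither pin is optional. To run the service, install these requirements and start app.py; its __main__ block binds uvicorn to 0.0.0.0 on $PORT, defaulting to 7860. A minimal smoke test, assuming a local run on the default port:

    # Smoke test after startup -- assumes the server is local on port 7860.
    # /health reports "healthy" only once the model has finished loading.
    import requests

    health = requests.get("http://localhost:7860/health", timeout=10).json()
    print(health["status"], health["model_loaded"], health["device"])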