sooks committed
Commit 3118b47 · 1 Parent(s): 3c01444

Create server.py

Files changed (1)
  1. detector/server.py +155 -0
detector/server.py ADDED
import os
import sys
from http.server import HTTPServer, SimpleHTTPRequestHandler
from multiprocessing import Process
import subprocess
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import json
import fire
import torch
import re
from urllib.parse import urlparse, unquote

model: RobertaForSequenceClassification = None
tokenizer: RobertaTokenizer = None
device: str = None

# Strip any trailing __theme parameter (added by the hosting UI) from the query string
regex = r"__theme=(.+)"
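# Illustrative behavior of the substitution in do_GET (hypothetical query):
#   "Some%20text&__theme=light"  ->  "Some%20text&"
# Everything from "__theme=" to the end of the line is removed; any leftover
# "&" simply remains part of the decoded text.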


def log(*args):
    # Tag output with the worker's RANK so logs from multiple GPUs are attributable
    print(f"[{os.environ.get('RANK', '')}]", *args, file=sys.stderr)


class RequestHandler(SimpleHTTPRequestHandler):

    def do_POST(self):
        self.begin_content('application/json;charset=UTF-8')

        content_length = int(self.headers['Content-Length'])
        if content_length > 0:
            post_data = self.rfile.read(content_length).decode('utf-8')
            try:
                post_data = json.loads(post_data)

                if 'text' not in post_data:
                    self.wfile.write(json.dumps({"error": "missing key 'text'"}).encode('utf-8'))
                else:
                    all_tokens, used_tokens, fake, real = self.infer(post_data['text'])

                    self.wfile.write(json.dumps(dict(
                        all_tokens=all_tokens,
                        used_tokens=used_tokens,
                        real_probability=real,
                        fake_probability=fake
                    )).encode('utf-8'))

            except Exception as e:
                self.wfile.write(json.dumps({"error": str(e)}).encode('utf-8'))

    def do_GET(self):
        # The text to classify arrives as the raw query string; drop the UI's
        # __theme parameter and URL-decode the rest
        query = urlparse(self.path).query
        query = re.sub(regex, "", query, flags=re.MULTILINE)
        query = unquote(query)

        if not query:
            # No query: serve the bundled landing page
            self.begin_content('text/html')

            html = os.path.join(os.path.dirname(__file__), 'index.html')
            self.wfile.write(open(html).read().encode())
            return

        self.begin_content('application/json;charset=UTF-8')

        all_tokens, used_tokens, fake, real = self.infer(query)

        self.wfile.write(json.dumps(dict(
            all_tokens=all_tokens,
            used_tokens=used_tokens,
            real_probability=real,
            fake_probability=fake
        )).encode())
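
    # Shape of the JSON exchanged above (values purely illustrative):
    #   request:  POST / {"text": "..."}   or   GET /?<url-encoded text>
    #   response: {"all_tokens": 42, "used_tokens": 42,
    #              "real_probability": 0.98, "fake_probability": 0.02}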

    def infer(self, query):
        # Tokenize without special tokens; BOS/EOS are added manually below
        tokens = tokenizer.encode(query, add_special_tokens=False)
        all_tokens = len(tokens)
        # Truncate, reserving two positions for the BOS and EOS tokens
        # (model_max_length caps the model's input length; formerly tokenizer.max_len)
        tokens = tokens[:tokenizer.model_max_length - 2]
        used_tokens = len(tokens)
        tokens = torch.tensor([tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]).unsqueeze(0)
        mask = torch.ones_like(tokens)

        with torch.no_grad():
            logits = model(tokens.to(device), attention_mask=mask.to(device))[0]
            probs = logits.softmax(dim=-1)

        # Index 0 of the two-way classification head is "fake", index 1 is "real"
        fake, real = probs.detach().cpu().flatten().numpy().tolist()

        return all_tokens, used_tokens, fake, real
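
    # Worked example of the softmax above: logits [2.0, -1.0] give
    # probs = [e^2, e^-1] / (e^2 + e^-1) ≈ [0.9526, 0.0474],
    # i.e. fake ≈ 0.95 and real ≈ 0.05 for that input.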

    def begin_content(self, content_type):
        self.send_response(200)
        self.send_header('Content-Type', content_type)
        self.send_header('Access-Control-Allow-Origin', '*')
        self.end_headers()

    def log_message(self, format, *args):
        log(format % args)


def serve_forever(server, model, tokenizer, device):
    log('Process has started; loading the model ...')
    # Publish the model, tokenizer and device as module globals for RequestHandler
    globals()['model'] = model.to(device)
    globals()['tokenizer'] = tokenizer
    globals()['device'] = device

    log(f'Ready to serve at http://localhost:{server.server_address[1]}')
    server.serve_forever()


def main(checkpoint, port=8080, device='cuda' if torch.cuda.is_available() else 'cpu'):
    if checkpoint.startswith('gs://'):
        print(f'Downloading {checkpoint}', file=sys.stderr)
        subprocess.check_output(['gsutil', 'cp', checkpoint, '.'])
        checkpoint = os.path.basename(checkpoint)
        assert os.path.isfile(checkpoint)

    print(f'Loading checkpoint from {checkpoint}', file=sys.stderr)
    data = torch.load(checkpoint, map_location='cpu')

    model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'
    model = RobertaForSequenceClassification.from_pretrained(model_name)
    tokenizer = RobertaTokenizer.from_pretrained(model_name)

    model.load_state_dict(data['model_state_dict'])
    model.eval()

    print(f'Starting HTTP server on port {port}', file=sys.stderr)
    server = HTTPServer(('0.0.0.0', port), RequestHandler)

    # Avoid calling the CUDA API before forking; querying the device count
    # in a subprocess keeps the parent process CUDA-free.
    num_workers = int(subprocess.check_output(
        [sys.executable, '-c', 'import torch; print(torch.cuda.device_count())']))

    if num_workers <= 1:
        serve_forever(server, model, tokenizer, device)
    else:
        print(f'Launching {num_workers} worker processes...', file=sys.stderr)

        subprocesses = []

        for i in range(num_workers):
            # Each worker sees exactly one GPU and identifies itself via RANK
            os.environ['RANK'] = f'{i}'
            os.environ['CUDA_VISIBLE_DEVICES'] = f'{i}'
            process = Process(target=serve_forever, args=(server, model, tokenizer, device))
            process.start()
            subprocesses.append(process)

        del os.environ['RANK']
        del os.environ['CUDA_VISIBLE_DEVICES']

        for process in subprocesses:
            process.join()


if __name__ == '__main__':
    fire.Fire(main)
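
For reference, a minimal client sketch for the API above. It is illustrative only: it assumes the server was started locally on the default port 8080 (e.g. python server.py detector-base.pt, where detector-base.pt is a placeholder checkpoint name), and the sample text is made up.

import json
from urllib.parse import quote
from urllib.request import Request, urlopen

BASE = 'http://localhost:8080'  # assumed host/port; see main() above

TEXT = 'A sample passage whose origin we want to check.'

# GET: the URL-encoded text is the entire query string
with urlopen(f'{BASE}/?{quote(TEXT)}') as resp:
    print(json.load(resp))

# POST: the text travels as JSON under the 'text' key
req = Request(BASE, data=json.dumps({'text': TEXT}).encode('utf-8'),
              headers={'Content-Type': 'application/json'})
with urlopen(req) as resp:
    print(json.load(resp))

Both calls return the same four fields produced by infer(): all_tokens, used_tokens, real_probability and fake_probability.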