AnwenHu committed
Commit 7914e1b · verified · Parent: d39ff0f

Delete mplug_docowl/local_serve

mplug_docowl/local_serve/__init__.py DELETED
Empty file (no content to diff)
mplug_docowl/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg DELETED
Binary file (18.9 kB)
 
mplug_docowl/local_serve/examples/extreme_ironing.jpg DELETED
Binary file (62.6 kB)
 
mplug_docowl/local_serve/local_web_server.py DELETED
@@ -1,392 +0,0 @@
-import argparse
-import datetime
-import json
-import os
-import time
-
-import gradio as gr
-import requests
-
-from mplug_owl2.conversation import (default_conversation, conv_templates,
-                                     SeparatorStyle)
-from mplug_owl2.constants import LOGDIR
-from mplug_owl2.utils import (build_logger, server_error_msg,
-                              violates_moderation, moderation_msg)
-from .model_worker import ModelWorker
-import hashlib
-
-logger = build_logger("gradio_web_server_local", "gradio_web_server_local.log")
-
-headers = {"User-Agent": "mPLUG-Owl2 Client"}
-
-no_change_btn = gr.Button.update()
-enable_btn = gr.Button.update(interactive=True)
-disable_btn = gr.Button.update(interactive=False)
-
-
-def get_conv_log_filename():
-    t = datetime.datetime.now()
-    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
-    return name
-
-
-get_window_url_params = """
-function() {
-    const params = new URLSearchParams(window.location.search);
-    url_params = Object.fromEntries(params);
-    console.log(url_params);
-    return url_params;
-}
-"""
-
-
-def load_demo(url_params, request: gr.Request):
-    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    state = default_conversation.copy()
-    return state
-
-
-def vote_last_response(state, vote_type, request: gr.Request):
-    with open(get_conv_log_filename(), "a") as fout:
-        data = {
-            "tstamp": round(time.time(), 4),
-            "type": vote_type,
-            "state": state.dict(),
-            "ip": request.client.host,
-        }
-        fout.write(json.dumps(data) + "\n")
-
-
-def upvote_last_response(state, request: gr.Request):
-    logger.info(f"upvote. ip: {request.client.host}")
-    vote_last_response(state, "upvote", request)
-    return ("",) + (disable_btn,) * 3
-
-
-def downvote_last_response(state, request: gr.Request):
-    logger.info(f"downvote. ip: {request.client.host}")
-    vote_last_response(state, "downvote", request)
-    return ("",) + (disable_btn,) * 3
-
-
-def flag_last_response(state, request: gr.Request):
-    logger.info(f"flag. ip: {request.client.host}")
-    vote_last_response(state, "flag", request)
-    return ("",) + (disable_btn,) * 3
-
-
-def regenerate(state, image_process_mode, request: gr.Request):
-    logger.info(f"regenerate. ip: {request.client.host}")
-    state.messages[-1][-1] = None
-    prev_human_msg = state.messages[-2]
-    if type(prev_human_msg[1]) in (tuple, list):
-        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-    state.skip_next = False
-    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
-
-def clear_history(request: gr.Request):
-    logger.info(f"clear_history. ip: {request.client.host}")
-    state = default_conversation.copy()
-    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
-
-def add_text(state, text, image, image_process_mode, request: gr.Request):
-    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
-    if len(text) <= 0 and image is None:
-        state.skip_next = True
-        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
-    if args.moderate:
-        flagged = violates_moderation(text)
-        if flagged:
-            state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
-                no_change_btn,) * 5
-
-    text = text[:3584]  # Hard cut-off
-    if image is not None:
-        text = text[:3500]  # Hard cut-off for images
-        if '<|image|>' not in text:
-            text = '<|image|>' + text
-        text = (text, image, image_process_mode)
-        if len(state.get_images(return_pil=True)) > 0:
-            state = default_conversation.copy()
-    state.append_message(state.roles[0], text)
-    state.append_message(state.roles[1], None)
-    state.skip_next = False
-    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
-
-def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
-    logger.info(f"http_bot. ip: {request.client.host}")
-    start_tstamp = time.time()
-
-    if state.skip_next:
-        # This generate call is skipped due to invalid inputs
-        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-        return
-
-    if len(state.messages) == state.offset + 2:
-        # First round of conversation
-        template_name = "mplug_owl2"
-        new_state = conv_templates[template_name].copy()
-        new_state.append_message(new_state.roles[0], state.messages[-2][1])
-        new_state.append_message(new_state.roles[1], None)
-        state = new_state
-
-    # Construct prompt
-    prompt = state.get_prompt()
-
-    all_images = state.get_images(return_pil=True)
-    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
-    for image, hash in zip(all_images, all_image_hash):
-        t = datetime.datetime.now()
-        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
-        if not os.path.isfile(filename):
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            image.save(filename)
-
-    # Make requests
-    pload = {
-        "prompt": prompt,
-        "temperature": float(temperature),
-        "top_p": float(top_p),
-        "max_new_tokens": min(int(max_new_tokens), 2048),
-        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
-        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
-    }
-    logger.info(f"==== request ====\n{pload}")
-
-    pload['images'] = state.get_images()
-
-    state.messages[-1][-1] = "▌"
-    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-
-    try:
-        # Stream output
-        # response = requests.post(worker_addr + "/worker_generate_stream",
-        #                          headers=headers, json=pload, stream=True, timeout=10)
-        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        response = model.generate_stream_gate(pload)
-        for chunk in response:
-            if chunk:
-                data = json.loads(chunk.decode())
-                if data["error_code"] == 0:
-                    output = data["text"][len(prompt):].strip()
-                    state.messages[-1][-1] = output + "▌"
-                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-                else:
-                    output = data["text"] + f" (error_code: {data['error_code']})"
-                    state.messages[-1][-1] = output
-                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-                    return
-                time.sleep(0.03)
-    except requests.exceptions.RequestException as e:
-        state.messages[-1][-1] = server_error_msg
-        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-        return
-
-    state.messages[-1][-1] = state.messages[-1][-1][:-1]
-    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-
-    finish_tstamp = time.time()
-    logger.info(f"{output}")
-
-    with open(get_conv_log_filename(), "a") as fout:
-        data = {
-            "tstamp": round(finish_tstamp, 4),
-            "type": "chat",
-            "start": round(start_tstamp, 4),
-            "finish": round(finish_tstamp, 4),
-            "state": state.dict(),
-            "images": all_image_hash,
-            "ip": request.client.host,
-        }
-        fout.write(json.dumps(data) + "\n")
-
-
-title_markdown = ("""
-<h1 align="center"><a href="https://github.com/X-PLUG/mPLUG-Owl"><img src="https://z1.ax1x.com/2023/11/03/piM1rGQ.md.png" alt="mPLUG-Owl" border="0" style="margin: 0 auto; height: 200px;" /></a></h1>
-
-<h2 align="center">mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration</h2>
-
-<h5 align="center">If you like our project, please give us a star ✨ on GitHub for the latest updates.</h5>
-
-<div align="center">
-    <div style="display:flex; gap: 0.25rem;" align="center">
-        <a href='https://github.com/X-PLUG/mPLUG-Owl'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
-        <a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
-        <a href='https://github.com/X-PLUG/mPLUG-Owl/stargazers'><img src='https://img.shields.io/github/stars/X-PLUG/mPLUG-Owl.svg?style=social'></a>
-    </div>
-</div>
-""")
-
-
-tos_markdown = ("""
-### Terms of use
-By using this service, users are required to agree to the following terms:
-The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
-Please click the "Flag" button if you get an inappropriate answer! We will collect those to keep improving our moderator.
-For an optimal experience, please use a desktop computer for this demo, as mobile devices may compromise its quality.
-""")
-
-
-learn_more_markdown = ("""
-### License
-The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violations.
-""")
-
-block_css = """
-
-#buttons button {
-    min-width: min(120px,100%);
-}
-
-"""
-
-
-def build_demo(embed_mode):
-    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
-    with gr.Blocks(title="mPLUG-Owl2", theme=gr.themes.Default(), css=block_css) as demo:
-        state = gr.State()
-
-        if not embed_mode:
-            gr.Markdown(title_markdown)
-
-        with gr.Row():
-            with gr.Column(scale=3):
-                imagebox = gr.Image(type="pil")
-                image_process_mode = gr.Radio(
-                    ["Crop", "Resize", "Pad", "Default"],
-                    value="Default",
-                    label="Preprocess for non-square image", visible=False)
-
-                cur_dir = os.path.dirname(os.path.abspath(__file__))
-                gr.Examples(examples=[
-                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
-                    [f"{cur_dir}/examples/Rebecca_(1939_poster)_Small.jpeg", "What is the name of the movie in the poster?"],
-                ], inputs=[imagebox, textbox])
-
-                with gr.Accordion("Parameters", open=True) as parameter_row:
-                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
-                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
-                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")
-
-            with gr.Column(scale=8):
-                chatbot = gr.Chatbot(elem_id="Chatbot", label="mPLUG-Owl2 Chatbot", height=600)
-                with gr.Row():
-                    with gr.Column(scale=8):
-                        textbox.render()
-                    with gr.Column(scale=1, min_width=50):
-                        submit_btn = gr.Button(value="Send", variant="primary")
-                with gr.Row(elem_id="buttons") as button_row:
-                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
-                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
-                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
-                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
-                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
-                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
-
-        if not embed_mode:
-            gr.Markdown(tos_markdown)
-            gr.Markdown(learn_more_markdown)
-        url_params = gr.JSON(visible=False)
-
-        # Register listeners
-        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-        upvote_btn.click(
-            upvote_last_response,
-            state,
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
-        )
-        downvote_btn.click(
-            downvote_last_response,
-            state,
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
-        )
-        flag_btn.click(
-            flag_last_response,
-            state,
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
-        )
-
-        regenerate_btn.click(
-            regenerate,
-            [state, image_process_mode],
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
-        ).then(
-            http_bot,
-            [state, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
-        )
-
-        clear_btn.click(
-            clear_history,
-            None,
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
-        )
-
-        textbox.submit(
-            add_text,
-            [state, textbox, imagebox, image_process_mode],
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
-        ).then(
-            http_bot,
-            [state, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
-        )
-
-        submit_btn.click(
-            add_text,
-            [state, textbox, imagebox, image_process_mode],
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
-        ).then(
-            http_bot,
-            [state, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
-        )
-
-        demo.load(
-            load_demo,
-            [url_params],
-            state,
-            _js=get_window_url_params,
-            queue=False
-        )
-
-    return demo
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int)
-    parser.add_argument("--concurrency-count", type=int, default=10)
-    parser.add_argument("--model-list-mode", type=str, default="once",
-                        choices=["once", "reload"])
-    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
-    parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument("--load-8bit", action="store_true")
-    parser.add_argument("--load-4bit", action="store_true")
-    parser.add_argument("--moderate", action="store_true")
-    parser.add_argument("--embed", action="store_true")
-    args = parser.parse_args()
-    logger.info(f"args: {args}")
-
-    model = ModelWorker(args.model_path, None, None, args.load_8bit, args.load_4bit, args.device)
-
-    logger.info(args)
-    demo = build_demo(args.embed)
-    demo.queue(
-        concurrency_count=args.concurrency_count,
-        api_open=False
-    ).launch(
-        server_name=args.host,
-        server_port=args.port,
-        share=False
-    )
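Note that the deleted script was self-contained: it instantiated ModelWorker in-process rather than talking to a separate worker over HTTP. Judging from its argparse flags and its relative import of model_worker, it would have been launched roughly as follows (the checkpoint path is a placeholder, not recorded in this commit):

    python -m mplug_docowl.local_serve.local_web_server --model-path <path-to-mplug-owl2-checkpoint> --port 7860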
mplug_docowl/local_serve/model_worker.py DELETED
@@ -1,143 +0,0 @@
-"""
-A model worker executes the model.
-"""
-import argparse
-import asyncio
-import json
-import time
-import threading
-import uuid
-
-import requests
-import torch
-from functools import partial
-
-from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
-from mplug_owl2.utils import (build_logger, server_error_msg,
-                              pretty_print_semaphore)
-from mplug_owl2.model.builder import load_pretrained_model
-from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
-from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
-from transformers import TextIteratorStreamer
-from threading import Thread
-
-GB = 1 << 30
-
-worker_id = str(uuid.uuid4())[:6]
-logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
-
-class ModelWorker:
-    def __init__(self, model_path, model_base, model_name, load_8bit, load_4bit, device):
-        self.worker_id = worker_id
-        if model_path.endswith("/"):
-            model_path = model_path[:-1]
-        if model_name is None:
-            model_paths = model_path.split("/")
-            if model_paths[-1].startswith('checkpoint-'):
-                self.model_name = model_paths[-2] + "_" + model_paths[-1]
-            else:
-                self.model_name = model_paths[-1]
-        else:
-            self.model_name = model_name
-
-        self.device = device
-        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
-        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
-        self.is_multimodal = True
-
-    @torch.inference_mode()
-    def generate_stream(self, params):
-        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
-
-        prompt = params["prompt"]
-        ori_prompt = prompt
-        images = params.get("images", None)
-        num_image_tokens = 0
-        if images is not None and len(images) > 0 and self.is_multimodal:
-            if len(images) > 0:
-                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                    raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
-
-                images = [load_image_from_base64(image) for image in images]
-                images = process_images(images, image_processor, model.config)
-
-                if type(images) is list:
-                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
-                else:
-                    images = images.to(self.model.device, dtype=torch.float16)
-
-                replace_token = DEFAULT_IMAGE_TOKEN
-                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-
-                num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
-            else:
-                images = None
-            image_args = {"images": images}
-        else:
-            images = None
-            image_args = {}
-
-        temperature = float(params.get("temperature", 1.0))
-        top_p = float(params.get("top_p", 1.0))
-        max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
-        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
-        stop_str = params.get("stop", None)
-        do_sample = True if temperature > 0.001 else False
-
-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
-        keywords = [stop_str]
-        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
-
-        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
-
-        if max_new_tokens < 1:
-            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()
-            return
-
-        thread = Thread(target=model.generate, kwargs=dict(
-            inputs=input_ids,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            max_new_tokens=max_new_tokens,
-            streamer=streamer,
-            stopping_criteria=[stopping_criteria],
-            use_cache=True,
-            **image_args
-        ))
-        thread.start()
-
-        generated_text = ori_prompt
-        for new_text in streamer:
-            generated_text += new_text
-            if generated_text.endswith(stop_str):
-                generated_text = generated_text[:-len(stop_str)]
-            yield json.dumps({"text": generated_text, "error_code": 0}).encode()
-
-    def generate_stream_gate(self, params):
-        try:
-            for x in self.generate_stream(params):
-                yield x
-        except ValueError as e:
-            print("Caught ValueError:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode()
-        except torch.cuda.CudaError as e:
-            print("Caught torch.cuda.CudaError:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode()
-        except Exception as e:
-            print("Caught unknown error:", e)
-            ret = {
-                "text": server_error_msg,
-                "error_code": 1,
-            }
-            yield json.dumps(ret).encode()
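For context, the deleted web server drove this worker in-process through generate_stream_gate, which yields JSON objects encoded as bytes of the form {"text": ..., "error_code": ...}. Below is a minimal sketch of that protocol, assuming the module were still present and an mPLUG-Owl2 checkpoint were available locally; the checkpoint path, prompt, and stop string are illustrative assumptions, not taken from this commit.

import json

from mplug_docowl.local_serve.model_worker import ModelWorker  # removed by this commit

# Hypothetical checkpoint path; ModelWorker loads it via load_pretrained_model.
worker = ModelWorker("checkpoints/mplug-owl2-llama2-7b", None, None,
                     load_8bit=False, load_4bit=False, device="cuda")

# Mirrors the payload that http_bot() in local_web_server.py builds; "images"
# holds base64-encoded images, one per <|image|> token in the prompt.
pload = {
    "prompt": "USER: Describe extreme ironing. ASSISTANT:",
    "temperature": 0.2,
    "top_p": 0.7,
    "max_new_tokens": 512,
    "stop": "</s>",  # assumed stop string; the server takes it from the conversation template
    "images": [],
}

for chunk in worker.generate_stream_gate(pload):
    data = json.loads(chunk.decode())
    if data["error_code"] != 0:
        break
    # "text" contains the prompt followed by the tokens generated so far.
    print(data["text"][len(pload["prompt"]):].strip())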