Delete mplug_docowl/local_serve
mplug_docowl/local_serve/__init__.py
DELETED
Empty file
mplug_docowl/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg
DELETED
Binary file (18.9 kB)

mplug_docowl/local_serve/examples/extreme_ironing.jpg
DELETED
Binary file (62.6 kB)
mplug_docowl/local_serve/local_web_server.py
DELETED
@@ -1,392 +0,0 @@
import argparse
import datetime
import json
import os
import time

import gradio as gr
import requests

from mplug_owl2.conversation import (default_conversation, conv_templates,
                                     SeparatorStyle)
from mplug_owl2.constants import LOGDIR
from mplug_owl2.utils import (build_logger, server_error_msg,
                              violates_moderation, moderation_msg)
from .model_worker import ModelWorker
import hashlib

logger = build_logger("gradio_web_server_local", "gradio_web_server_local.log")

headers = {"User-Agent": "mPLUG-Owl2 Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)


def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    state = default_conversation.copy()
    return state


def vote_last_response(state, vote_type, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", request)
    return ("",) + (disable_btn,) * 3


def regenerate(state, image_process_mode, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def add_text(state, text, image, image_process_mode, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,) * 5

    text = text[:3584]  # Hard cut-off
    if image is not None:
        text = text[:3500]  # Hard cut-off for images
        if '<|image|>' not in text:
            text = '<|image|>' + text
        text = (text, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        template_name = "mplug_owl2"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    prompt = state.get_prompt()

    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 2048),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output. The original remote-worker HTTP path is kept for reference:
        # response = requests.post(worker_addr + "/worker_generate_stream",
        #     headers=headers, json=pload, stream=True, timeout=10)
        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        response = model.generate_stream_gate(pload)
        for chunk in response:
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


title_markdown = ("""
<h1 align="center"><a href="https://github.com/X-PLUG/mPLUG-Owl"><img src="https://z1.ax1x.com/2023/11/03/piM1rGQ.md.png" alt="mPLUG-Owl" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>

<h2 align="center"> mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration</h2>

<h5 align="center"> If you like our project, please give us a star ✨ on Github for the latest updates. </h5>

<div align="center">
    <div style="display:flex; gap: 0.25rem;" align="center">
        <a href='https://github.com/X-PLUG/mPLUG-Owl'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
        <a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
        <a href='https://github.com/X-PLUG/mPLUG-Owl/stargazers'><img src='https://img.shields.io/github/stars/X-PLUG/mPLUG-Owl.svg?style=social'></a>
    </div>
</div>
""")


tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")


learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")

block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""


def build_demo(embed_mode):
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="mPLUG-Owl2", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
                    [f"{cur_dir}/examples/Rebecca_(1939_poster)_Small.jpeg", "What is the name of the movie in the poster?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=True) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="Chatbot", label="mPLUG-Owl2 Chatbot", height=600)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            state,
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            state,
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            state,
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        demo.load(
            load_demo,
            [url_params],
            state,
            _js=get_window_url_params,
            queue=False
        )

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    model = ModelWorker(args.model_path, None, None, args.load_8bit, args.load_4bit, args.device)

    logger.info(args)
    demo = build_demo(args.embed)
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=False
    )
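Note that this server runs fully in-process: rather than POSTing to a remote worker (the commented-out `requests.post` path), `http_bot` feeds its `pload` dict straight into `model.generate_stream_gate`. A minimal sketch of driving the worker the same way without the Gradio UI; the prompt template, stop string, and sampling values below are illustrative assumptions, and `images_base64` is a hypothetical stand-in for your base64-encoded inputs:

import json

# Payload mirrors the `pload` dict built by http_bot above;
# `model` is a loaded ModelWorker (see model_worker.py below).
pload = {
    "prompt": "USER: <|image|>What is unusual about this image? ASSISTANT:",  # assumed template
    "temperature": 0.2,
    "top_p": 0.7,
    "max_new_tokens": 512,
    "stop": "</s>",            # assumed separator for the mplug_owl2 template
    "images": images_base64,   # one base64 string per <|image|> token in the prompt
}
for chunk in model.generate_stream_gate(pload):
    data = json.loads(chunk.decode())
    if data["error_code"] != 0:
        raise RuntimeError(data["text"])
    print(data["text"], end="\r")  # cumulative text, prompt included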
mplug_docowl/local_serve/model_worker.py
DELETED
@@ -1,143 +0,0 @@
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

import requests
import torch
from functools import partial

from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
from mplug_owl2.utils import (build_logger, server_error_msg,
                              pretty_print_semaphore)
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread

GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")


class ModelWorker:
    def __init__(self, model_path, model_base, model_name, load_8bit, load_4bit, device):
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
        self.is_multimodal = True

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <|image|> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]
                images = process_images(images, image_processor, model.config)

                if type(images) is list:
                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
                else:
                    images = images.to(self.model.device, dtype=torch.float16)

                replace_token = DEFAULT_IMAGE_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

                num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
            else:
                images = None
            image_args = {"images": images}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        # Leave room in the context window for the prompt and image tokens.
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()
            return

        # Run generation in a background thread and stream tokens via the streamer.
        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode()

    def generate_stream_gate(self, params):
        # Funnel all failures into a uniform error chunk for the caller.
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode()