MonsterMMORPG committed on
Commit 16f7413 · verified · 1 Parent(s): 9206968

Upload gradio_web_server.py

Files changed (1)
  1. gradio_web_server.py +94 -92
gradio_web_server.py CHANGED
@@ -1,3 +1,4 @@
+
 import argparse
 import datetime
 import json
@@ -19,9 +20,9 @@ logger = build_logger("gradio_web_server", "gradio_web_server.log")
 
 headers = {"User-Agent": "LLaVA Client"}
 
-no_change_btn = gr.Button.update()
-enable_btn = gr.Button.update(interactive=True)
-disable_btn = gr.Button.update(interactive=False)
+no_change_btn = gr.Button()
+enable_btn = gr.Button(interactive=True)
+disable_btn = gr.Button(interactive=False)
 
 priority = {
     "vicuna-13b": "aaaaaaa",
@@ -58,12 +59,11 @@ function() {
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
 
-    dropdown_update = gr.Dropdown.update(visible=True)
+    dropdown_update = gr.Dropdown(visible=True)
     if "model" in url_params:
         model = url_params["model"]
         if model in models:
-            dropdown_update = gr.Dropdown.update(
-                value=model, visible=True)
+            dropdown_update = gr.Dropdown(value=model, visible=True)
 
     state = default_conversation.copy()
     return state, dropdown_update
@@ -73,7 +73,7 @@ def load_demo_refresh_model_list(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
     models = get_model_list()
     state = default_conversation.copy()
-    dropdown_update = gr.Dropdown.update(
+    dropdown_update = gr.Dropdown(
         choices=models,
         value=models[0] if len(models) > 0 else ""
     )
@@ -124,8 +124,7 @@ def clear_history(request: gr.Request):
     logger.info(f"clear_history. ip: {request.client.host}")
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
-
+
 
 def add_text(state, text, image, image_process_mode, request: gr.Request):
     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
@@ -153,65 +152,14 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
-
-def batch_process_images(folder_path, textbox, model_selector, temperature, top_p, max_output_tokens, request: gr.Request):
-    print("Starting batch processing of images")
-
-    # Initialize counters and timer
-    image_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]
-    total_images = len(image_files)
-    processed_images = 0
-    total_processing_time = 0
-
-    # Process each image file
-    for filename in image_files:
-        image_path = os.path.join(folder_path, filename)
-        start_time = time.time()
-
-        with Image.open(image_path) as image:
-            state = default_conversation.copy()
-            state, _, _, _, _, _, _, _, _ = add_text(state, textbox, image, "Default", request)
-
-            # Call http_bot and iterate over the generator
-            response_text = ""
-            for state_update in http_bot(state, model_selector, temperature, top_p, max_output_tokens, request):
-                # Update state and extract response text
-                state, chatbot_output, *_ = state_update
-                response_text = chatbot_output
-
-            # Save the final response to a file
-            try:
-                with open(os.path.splitext(image_path)[0] + '.txt', 'w') as f:
-                    f.write(response_text[0][1])
-            except Exception as e:
-                print(f"An error occurred: {e}")
-
-        # Update processing information
-        processed_images += 1
-        processing_time = time.time() - start_time
-        total_processing_time += processing_time
-        average_processing_time = total_processing_time / processed_images
-        images_left = total_images - processed_images
-        eta_seconds = average_processing_time * images_left
-        eta = datetime.timedelta(seconds=int(eta_seconds))
-
-        # Display progress information
-        print(f"{processed_images}/{total_images} images processed, {images_left} left, average process time {average_processing_time :.2f} seconds, ETA: {str(eta)}")
-
-    return "Batch processing completed."
-
-
-
 
 def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
     logger.info(f"http_bot. ip: {request.client.host}")
-    print(f"model_selector {model_selector}")
     start_tstamp = time.time()
     model_name = model_selector
-
+
     if state.skip_next:
         # This generate call is skipped due to invalid inputs
-        print("invalid input state.skip_next")
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
@@ -220,6 +168,15 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
     if "llava" in model_name.lower():
         if 'llama-2' in model_name.lower():
             template_name = "llava_llama_2"
+        elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
+            if 'orca' in model_name.lower():
+                template_name = "mistral_orca"
+            elif 'hermes' in model_name.lower():
+                template_name = "chatml_direct"
+            else:
+                template_name = "mistral_instruct"
+        elif 'llava-v1.6-34b' in model_name.lower():
+            template_name = "chatml_direct"
         elif "v1" in model_name.lower():
             if 'mmtag' in model_name.lower():
                 template_name = "v1_mmtag"
@@ -242,7 +199,6 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
             template_name = "llama_2"
     else:
         template_name = "vicuna_v1"
-    print(f"template_name {template_name}")
     new_state = conv_templates[template_name].copy()
     new_state.append_message(new_state.roles[0], state.messages[-2][1])
     new_state.append_message(new_state.roles[1], None)
@@ -258,7 +214,6 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
     # No available worker
     if worker_addr == "":
         state.messages[-1][-1] = server_error_msg
-        print(f"error No available worker")
         yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
         return
 
@@ -285,12 +240,12 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
         "images": f'List of {len(state.get_images())} images: {all_image_hash}',
     }
     logger.info(f"==== request ====\n{pload}")
-
+
     pload['images'] = state.get_images()
 
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-    print(f"entering Stream output")
+
     try:
         # Stream output
         response = requests.post(worker_addr + "/worker_generate_stream",
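
Note: for context, the worker behind `/worker_generate_stream` streams the reply back as NUL-delimited JSON chunks, and `http_bot` re-yields a refreshed chat state per chunk. A sketch of the consuming loop as it typically appears in LLaVA-style servers (field names `text`/`error_code` assumed from that convention; `prompt` and `state` come from the surrounding function):

    response = requests.post(worker_addr + "/worker_generate_stream",
                             headers=headers, json=pload, stream=True, timeout=10)
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            # Drop the echoed prompt; re-render the partial answer with a cursor.
            state.messages[-1][-1] = data["text"][len(prompt):].strip() + "▌"
        else:
            state.messages[-1][-1] = f"{data['text']} (error_code: {data['error_code']})"
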
@@ -334,7 +289,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
 
 title_markdown = ("""
 Most Up To Date Scripts On : https://www.patreon.com/posts/sota-very-best-90744385 \n
-Original Project : https://llava-vl.github.io
+Original Project : https://llava-vl.github.io\n
+REFRESH PAGE AFTER PART 3 TO SEE LOADED MODEL
 """)
 
 tos_markdown = ("""
@@ -359,17 +315,56 @@ block_css = """
 
 """
 
-
-
-
-
-def build_demo(embed_mode):
+def batch_process_images(folder_path, textbox, model_selector, temperature, top_p, max_output_tokens, request: gr.Request):
+    print("Starting batch processing of images")
+
+    # Initialize counters and timer
+    image_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]
+    total_images = len(image_files)
+    processed_images = 0
+    total_processing_time = 0
+
+    # Process each image file
+    for filename in image_files:
+        image_path = os.path.join(folder_path, filename)
+        start_time = time.time()
+
+        with Image.open(image_path) as image:
+            state = default_conversation.copy()
+            state, _, _, _, _, _, _, _, _ = add_text(state, textbox, image, "Default", request)
+
+            # Call http_bot and iterate over the generator
+            response_text = ""
+            for state_update in http_bot(state, model_selector, temperature, top_p, max_output_tokens, request):
+                # Update state and extract response text
+                state, chatbot_output, *_ = state_update
+                response_text = chatbot_output
+
+            # Save the final response to a file
+            try:
+                with open(os.path.splitext(image_path)[0] + '.txt', 'w') as f:
+                    f.write(response_text[0][1])
+            except Exception as e:
+                print(f"An error occurred: {e}")
+
+        # Update processing information
+        processed_images += 1
+        processing_time = time.time() - start_time
+        total_processing_time += processing_time
+        average_processing_time = total_processing_time / processed_images
+        images_left = total_images - processed_images
+        eta_seconds = average_processing_time * images_left
+        eta = datetime.timedelta(seconds=int(eta_seconds))
+
+        # Display progress information
+        print(f"{processed_images}/{total_images} images processed, {images_left} left, average process time {average_processing_time :.2f} seconds, ETA: {str(eta)}")
+
+    return "Batch processing completed."
+
+def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
     textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
-
-    # New components for batch processing
     folder_input = gr.Textbox(label="Enter Folder Path for Batch Processing")
     batch_btn = gr.Button("Batch Process")
-
     with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 
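Note: `batch_process_images`, relocated here below `block_css`, drives the same `add_text`/`http_bot` pipeline without the UI: it drains the `http_bot` generator and keeps only the last yield. Because each image starts from a fresh conversation, the chat history holds exactly one (user, assistant) pair, which is why `response_text[0][1]` recovers the caption. The pattern in isolation (names as in the file, simplified):

    final_history = None
    for update in http_bot(state, model_selector, temperature, top_p, max_output_tokens, request):
        # Each yield is (state, chatbot_history, *button_updates); keep the latest.
        state, final_history, *_ = update
    caption = final_history[0][1]  # sole (user, assistant) pair -> assistant reply
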
@@ -392,7 +387,8 @@ def build_demo(embed_mode):
                     value="Default",
                     label="Preprocess for non-square image", visible=False)
 
-                cur_dir = os.path.dirname(os.path.abspath(__file__))
+                if cur_dir is None:
+                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                 gr.Examples(examples=[
                     [f"{cur_dir}/examples/extreme_ironing.jpg", "just caption the image with details, colors, items, objects, emotions, art style, drawing style and objects but do not add any description or comment. do not miss any item in the given image"],
                 ], inputs=[imagebox, textbox])
@@ -403,7 +399,12 @@ def build_demo(embed_mode):
                     max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
 
             with gr.Column(scale=8):
-                chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
+                chatbot = gr.Chatbot(
+                    elem_id="chatbot",
+                    label="LLaVA Chatbot",
+                    height=650,
+                    layout="panel",
+                )
                 with gr.Row():
                     with gr.Column(scale=8):
                         textbox.render()
@@ -413,9 +414,13 @@ def build_demo(embed_mode):
                     upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                     downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                     flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+                    #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                     clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
 
+        if not embed_mode:
+            gr.Markdown(tos_markdown)
+            gr.Markdown(learn_more_markdown)
         url_params = gr.JSON(visible=False)
 
         # Add new components for batch processing
@@ -435,31 +440,28 @@ def build_demo(embed_mode):
         upvote_btn.click(
             upvote_last_response,
             [state, model_selector],
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
+            [textbox, upvote_btn, downvote_btn, flag_btn]
         )
         downvote_btn.click(
             downvote_last_response,
             [state, model_selector],
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
+            [textbox, upvote_btn, downvote_btn, flag_btn]
         )
         flag_btn.click(
             flag_last_response,
             [state, model_selector],
-            [textbox, upvote_btn, downvote_btn, flag_btn],
-            queue=False
+            [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
         regenerate_btn.click(
             regenerate,
             [state, image_process_mode],
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
+            [state, chatbot, textbox, imagebox] + btn_list
         ).then(
             http_bot,
             [state, model_selector, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
+            [state, chatbot] + btn_list,
+            concurrency_limit=concurrency_count
         )
 
         clear_btn.click(
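
Note: these handler changes track the Gradio 4.x queue rework. `queue=False` is dropped from the vote handlers, and the global `demo.queue(concurrency_count=...)` knob is replaced by a per-event `concurrency_limit`, threaded through via the new `concurrency_count` parameter of `build_demo`. Roughly, the two spellings (a sketch, not from this file):

    # Gradio 3.x: one global worker count for the whole queue
    demo.queue(concurrency_count=16, api_open=False)

    # Gradio 4.x: limits are set per event, with an optional queue-wide default
    btn.click(fn, inputs, outputs, concurrency_limit=16)
    demo.queue(default_concurrency_limit=16, api_open=False)
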
@@ -477,18 +479,19 @@ def build_demo(embed_mode):
         ).then(
             http_bot,
             [state, model_selector, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
+            [state, chatbot] + btn_list,
+            concurrency_limit=concurrency_count
         )
 
         submit_btn.click(
             add_text,
             [state, textbox, imagebox, image_process_mode],
-            [state, chatbot, textbox, imagebox] + btn_list,
-            queue=False
+            [state, chatbot, textbox, imagebox] + btn_list
         ).then(
             http_bot,
             [state, model_selector, temperature, top_p, max_output_tokens],
-            [state, chatbot] + btn_list
+            [state, chatbot] + btn_list,
+            concurrency_limit=concurrency_count
         )
 
         if args.model_list_mode == "once":
@@ -496,8 +499,7 @@ def build_demo(embed_mode):
             load_demo,
             [url_params],
             [state, model_selector],
-            _js=get_window_url_params,
-            queue=False
+            _js=get_window_url_params
         )
     elif args.model_list_mode == "reload":
         demo.load(
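
Note: one caveat worth flagging: this call keeps `_js=get_window_url_params`, but Gradio 4.x renamed the event keyword `_js` to `js`. If the rest of this migration targets 4.x (as the `gr.Button()` and `concurrency_limit` changes suggest), the load call would need the rename too, e.g.:

    demo.load(
        load_demo,
        [url_params],
        [state, model_selector],
        js=get_window_url_params  # Gradio 4.x spelling of the former `_js`
    )
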
@@ -517,8 +519,8 @@ if __name__ == "__main__":
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int)
     parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
-    parser.add_argument("--concurrency-count", type=int, default=10)
-    parser.add_argument("--model-list-mode", type=str, default="reload",
+    parser.add_argument("--concurrency-count", type=int, default=16)
+    parser.add_argument("--model-list-mode", type=str, default="once",
                         choices=["once", "reload"])
     parser.add_argument("--share", action="store_true")
     parser.add_argument("--moderate", action="store_true")
@@ -529,12 +531,12 @@ if __name__ == "__main__":
     models = get_model_list()
 
     logger.info(args)
-    demo = build_demo(args.embed)
+    demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
     demo.queue(
-        concurrency_count=args.concurrency_count,
         api_open=False
     ).launch(
         server_name=args.host,
         server_port=args.port,
-        share=args.share
+        share=args.share,
+        inbrowser=True
     )
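
Note: `inbrowser=True` makes `launch()` open the UI in the default local browser once the server is up, while a public tunnel stays opt-in behind `--share`. With the new defaults (`--model-list-mode once`, `--concurrency-count 16`), a typical local run would look something like this (illustrative invocation, not from the commit):

    python gradio_web_server.py --controller-url http://localhost:10000 --concurrency-count 16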