LukasHug committed
Commit 6d594d5 · verified · 1 Parent(s): a51c9a7

Update app.py

Files changed (1):
  1. app.py  +32 -67
app.py CHANGED
@@ -6,7 +6,7 @@ import logging
 import os
 import sys
 import time
-import spaces
+from huggingface_hub import spaces
 import gradio as gr
 import torch
 from PIL import Image
@@ -34,6 +34,10 @@ logger = logging.getLogger("gradio_web_server")
 LOGDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
 os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
 
+# Get default model from environment variable or use a fallback
+DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf")
+logger.info(f"Using model: {DEFAULT_MODEL}")
+
 default_taxonomy = policy_v1
 
 
@@ -147,6 +151,7 @@ disable_btn = gr.Button(interactive=False)
 
 
 # Model loading function
+@spaces.GPU
 def load_model(model_path):
     global tokenizer, model, processor, context_len
 
@@ -183,16 +188,6 @@ def load_model(model_path):
     return  # Remove return value to avoid Gradio warnings
 
 
-def get_model_list():
-    models = [
-        'AIML-TUDA/QwenGuard-v1.2-3B',
-        'AIML-TUDA/QwenGuard-v1.2-7B',
-        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
-        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
-    ]
-    return models
-
-
 def get_conv_log_filename():
     t = datetime.datetime.now()
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
@@ -206,7 +201,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
     global model, tokenizer, processor
 
     if model is None or processor is None:
-        return "Model not loaded. Please select a model first."
+        return "Model not loaded. Please wait for model to initialize."
     try:
         # Check if it's a Qwen model
         if isinstance(model, Qwen2_5_VLForConditionalGeneration):
@@ -290,57 +285,43 @@ function() {
 
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    models = get_model_list()
-
-    dropdown_update = gr.Dropdown(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown(value=model, visible=True)
-            load_model(model)
-
     state = default_conversation.copy()
-    return state, dropdown_update
+    return state
 
 
-def load_demo_refresh_model_list(request: gr.Request):
+def load_demo_refresh(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
-    models = get_model_list()
     state = default_conversation.copy()
-    dropdown_update = gr.Dropdown(
-        choices=models,
-        value=models[0] if len(models) > 0 else ""
-    )
-    return state, dropdown_update
+    return state
 
 
-def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+def vote_last_response(state, vote_type, request: gr.Request):
     with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(time.time(), 4),
             "type": vote_type,
-            "model": model_selector,
+            "model": DEFAULT_MODEL,
             "state": state.dict(),
             "ip": request.client.host,
         }
         fout.write(json.dumps(data) + "\n")
 
 
-def upvote_last_response(state, model_selector, request: gr.Request):
+def upvote_last_response(state, request: gr.Request):
     logger.info(f"upvote. ip: {request.client.host}")
-    vote_last_response(state, "upvote", model_selector, request)
+    vote_last_response(state, "upvote", request)
     return ("",) + (disable_btn,) * 3
 
 
-def downvote_last_response(state, model_selector, request: gr.Request):
+def downvote_last_response(state, request: gr.Request):
     logger.info(f"downvote. ip: {request.client.host}")
-    vote_last_response(state, "downvote", model_selector, request)
+    vote_last_response(state, "downvote", request)
     return ("",) + (disable_btn,) * 3
 
 
-def flag_last_response(state, model_selector, request: gr.Request):
+def flag_last_response(state, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
-    vote_last_response(state, "flag", model_selector, request)
+    vote_last_response(state, "flag", request)
     return ("",) + (disable_btn,) * 3
 
 
@@ -390,7 +371,7 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
 
 
-def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+def llava_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
 
     if state.skip_next:
@@ -410,10 +391,6 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
         yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         return
 
-    # Load model if needed
-    if model is None or model_selector != getattr(model, "_name_or_path", ""):
-        load_model(model_selector)
-
     # Run inference
     output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
 
@@ -434,7 +411,7 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
         data = {
             "tstamp": round(finish_tstamp, 4),
             "type": "chat",
-            "model": model_selector,
+            "model": DEFAULT_MODEL,
             "start": round(start_tstamp, 4),
             "finish": round(finish_tstamp, 4),
             "state": state.dict(),
@@ -477,8 +454,6 @@ block_css = """
 
 
 def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
-    models = get_model_list()
-
     with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 
@@ -487,13 +462,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
         with gr.Row():
             with gr.Column(scale=3):
-                with gr.Row(elem_id="model_selector_row"):
-                    model_selector = gr.Dropdown(
-                        choices=models,
-                        value=models[0] if len(models) > 0 else "",
-                        interactive=True,
-                        show_label=False,
-                        container=False)
+                # Model selector removed
 
                 imagebox = gr.Image(type="pil", label="Image", container=False)
                 image_process_mode = gr.Radio(
@@ -559,35 +528,29 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
         upvote_btn.click(
             upvote_last_response,
-            [state, model_selector],
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
         downvote_btn.click(
             downvote_last_response,
-            [state, model_selector],
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
         flag_btn.click(
             flag_last_response,
-            [state, model_selector],
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
-        model_selector.change(
-            load_model,
-            [model_selector],
-            None
-        )
-
         regenerate_btn.click(
             regenerate,
             [state, image_process_mode],
             [state, chatbot, textbox, imagebox] + btn_list
         ).then(
             llava_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
         )
@@ -606,7 +569,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
             queue=False
         ).then(
             llava_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
         )
@@ -617,15 +580,15 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
             [state, chatbot, textbox, imagebox] + btn_list
         ).then(
             llava_bot,
-            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
        )
 
         demo.load(
-            load_demo_refresh_model_list,
+            load_demo_refresh,
             None,
-            [state, model_selector],
+            [state],
             queue=False
         )
 
@@ -658,6 +621,8 @@ if api_key:
     login(token=api_key)
     logger.info("Logged in to Hugging Face Hub")
 
+# Load model at startup
+load_model(DEFAULT_MODEL)
 
 demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
 demo.queue(
@@ -667,4 +632,4 @@ demo.queue(
     server_name=args.host,
     server_port=args.port,
     share=args.share
-)
+)
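
For readers skimming the diff, the sketch below shows the general shape of the pattern this commit moves to: a single checkpoint picked via the DEFAULT_MODEL environment variable, a load_model function wrapped in the spaces GPU decorator, and a one-time load at startup in place of the removed model selector. It is a minimal, illustrative sketch rather than the app's actual code; the standalone `import spaces` package, the AutoProcessor/AutoModelForVision2Seq classes, and the dtype/device handling are assumptions made for the example.

# Illustrative sketch only; names mirror the diff, the loading details are assumed.
import os

import spaces  # assumes the standalone ZeroGPU helper package available on Spaces
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf")

model = None
processor = None


@spaces.GPU  # on ZeroGPU Spaces, a GPU is attached only while this function runs
def load_model(model_path):
    """Load one checkpoint into module-level globals, as the updated app.py does."""
    global model, processor
    processor = AutoProcessor.from_pretrained(model_path)
    model = AutoModelForVision2Seq.from_pretrained(model_path, torch_dtype=torch.float16)
    model.to("cuda")


# One-time load at startup replaces the per-request selector and reload logic.
load_model(DEFAULT_MODEL)

Loading once at startup keeps every request on the same checkpoint, which is consistent with dropping get_model_list(), the model dropdown wiring, and the per-request reload inside llava_bot.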