Update app.py
app.py CHANGED
@@ -6,7 +6,7 @@ import logging
 import os
 import sys
 import time
-import spaces
+from huggingface_hub import spaces
 import gradio as gr
 import torch
 from PIL import Image
@@ -34,6 +34,10 @@ logger = logging.getLogger("gradio_web_server")
 LOGDIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
 os.makedirs(os.path.join(LOGDIR, "serve_images"), exist_ok=True)
 
+# Get default model from environment variable or use a fallback
+DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf")
+logger.info(f"Using model: {DEFAULT_MODEL}")
+
 default_taxonomy = policy_v1
 
 
@@ -147,6 +151,7 @@ disable_btn = gr.Button(interactive=False)
 
 
 # Model loading function
+@spaces.GPU
 def load_model(model_path):
     global tokenizer, model, processor, context_len
 
@@ -183,16 +188,6 @@ def load_model(model_path):
     return  # Remove return value to avoid Gradio warnings
 
 
-def get_model_list():
-    models = [
-        'AIML-TUDA/QwenGuard-v1.2-3B',
-        'AIML-TUDA/QwenGuard-v1.2-7B',
-        'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf',
-        'AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf',
-    ]
-    return models
-
-
 def get_conv_log_filename():
     t = datetime.datetime.now()
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
@@ -206,7 +201,7 @@ def run_inference(prompt, image, temperature=0.2, top_p=0.95, max_tokens=512):
     global model, tokenizer, processor
 
     if model is None or processor is None:
-        return "Model not loaded. Please
+        return "Model not loaded. Please wait for model to initialize."
     try:
         # Check if it's a Qwen model
        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
@@ -290,57 +285,43 @@ function() {
 
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-    models = get_model_list()
-
-    dropdown_update = gr.Dropdown(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown(value=model, visible=True)
-            load_model(model)
-
     state = default_conversation.copy()
-    return state
+    return state
 
 
-def
+def load_demo_refresh(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
-    models = get_model_list()
     state = default_conversation.copy()
-
-        choices=models,
-        value=models[0] if len(models) > 0 else ""
-    )
-    return state, dropdown_update
+    return state
 
 
-def vote_last_response(state, vote_type,
+def vote_last_response(state, vote_type, request: gr.Request):
     with open(get_conv_log_filename(), "a") as fout:
         data = {
             "tstamp": round(time.time(), 4),
             "type": vote_type,
-            "model":
+            "model": DEFAULT_MODEL,
             "state": state.dict(),
             "ip": request.client.host,
         }
         fout.write(json.dumps(data) + "\n")
 
 
-def upvote_last_response(state,
+def upvote_last_response(state, request: gr.Request):
     logger.info(f"upvote. ip: {request.client.host}")
-    vote_last_response(state, "upvote",
+    vote_last_response(state, "upvote", request)
     return ("",) + (disable_btn,) * 3
 
 
-def downvote_last_response(state,
+def downvote_last_response(state, request: gr.Request):
     logger.info(f"downvote. ip: {request.client.host}")
-    vote_last_response(state, "downvote",
+    vote_last_response(state, "downvote", request)
    return ("",) + (disable_btn,) * 3
 
 
-def flag_last_response(state,
+def flag_last_response(state, request: gr.Request):
     logger.info(f"flag. ip: {request.client.host}")
-    vote_last_response(state, "flag",
+    vote_last_response(state, "flag", request)
     return ("",) + (disable_btn,) * 3
 
 
@@ -390,7 +371,7 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
     return (state, state.to_gradio_chatbot(), default_taxonomy, None) + (disable_btn,) * 5
 
 
-def llava_bot(state,
+def llava_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
 
     if state.skip_next:
@@ -410,10 +391,6 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
         yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         return
 
-    # Load model if needed
-    if model is None or model_selector != getattr(model, "_name_or_path", ""):
-        load_model(model_selector)
-
     # Run inference
     output = run_inference(prompt, all_images[0], temperature, top_p, max_new_tokens)
 
@@ -434,7 +411,7 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
     data = {
         "tstamp": round(finish_tstamp, 4),
         "type": "chat",
-        "model":
+        "model": DEFAULT_MODEL,
         "start": round(start_tstamp, 4),
         "finish": round(finish_tstamp, 4),
         "state": state.dict(),
@@ -477,8 +454,6 @@ block_css = """
 
 
 def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
-    models = get_model_list()
-
     with gr.Blocks(title="LlavaGuard", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 
@@ -487,13 +462,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
         with gr.Row():
             with gr.Column(scale=3):
-
-                model_selector = gr.Dropdown(
-                    choices=models,
-                    value=models[0] if len(models) > 0 else "",
-                    interactive=True,
-                    show_label=False,
-                    container=False)
+                # Model selector removed
 
                 imagebox = gr.Image(type="pil", label="Image", container=False)
                 image_process_mode = gr.Radio(
@@ -559,35 +528,29 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
 
         upvote_btn.click(
             upvote_last_response,
-            [state
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
         downvote_btn.click(
             downvote_last_response,
-            [state
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
         flag_btn.click(
             flag_last_response,
-            [state
+            [state],
             [textbox, upvote_btn, downvote_btn, flag_btn]
         )
 
-        model_selector.change(
-            load_model,
-            [model_selector],
-            None
-        )
-
         regenerate_btn.click(
             regenerate,
             [state, image_process_mode],
             [state, chatbot, textbox, imagebox] + btn_list
         ).then(
             llava_bot,
-            [state,
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
         )
@@ -606,7 +569,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
             queue=False
         ).then(
             llava_bot,
-            [state,
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
         )
@@ -617,15 +580,15 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
             [state, chatbot, textbox, imagebox] + btn_list
         ).then(
             llava_bot,
-            [state,
+            [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
        )
 
         demo.load(
-
+            load_demo_refresh,
             None,
-            [state
+            [state],
             queue=False
         )
 
@@ -658,6 +621,8 @@ if api_key:
     login(token=api_key)
     logger.info("Logged in to Hugging Face Hub")
 
+# Load model at startup
+load_model(DEFAULT_MODEL)
 
 demo = build_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=args.concurrency_count)
 demo.queue(
@@ -667,4 +632,4 @@ demo.queue(
     server_name=args.host,
     server_port=args.port,
     share=args.share
-)
+)
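
The net effect of the commit is a single-model app: the checkpoint id comes from the DEFAULT_MODEL environment variable, load_model is decorated for GPU allocation, and the model is loaded once at startup instead of through a dropdown. Below is a minimal sketch of that startup pattern, not the full app.py; it assumes the standalone ZeroGPU spaces package (the commit itself imports spaces from huggingface_hub) and uses placeholder loader classes rather than the app's actual model-specific code.

import os
import logging

import spaces  # ZeroGPU helper package; provides the spaces.GPU decorator on Hugging Face Spaces
from transformers import AutoProcessor, AutoModelForVision2Seq  # placeholder loader classes

logger = logging.getLogger("gradio_web_server")

# Same configuration pattern as the commit: environment variable with a hard-coded fallback.
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "AIML-TUDA/LlavaGuard-v1.2-7B-OV-hf")

model = None
processor = None


@spaces.GPU  # request a GPU slot for the duration of the call on ZeroGPU hardware
def load_model(model_path):
    """Populate the module-level globals that the request handlers read."""
    global model, processor
    processor = AutoProcessor.from_pretrained(model_path)
    model = AutoModelForVision2Seq.from_pretrained(model_path, device_map="auto")
    logger.info(f"Loaded {model_path}")


# Eager load at startup, replacing the removed model_selector.change(...) handler.
load_model(DEFAULT_MODEL)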
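
Because the checkpoint is now chosen by configuration rather than by the removed dropdown, switching to another of the ids that get_model_list() used to offer means setting DEFAULT_MODEL before app.py starts. A hedged example of a local override; on a Space the same value would normally be set as a repository variable rather than in code.

import os

# Hypothetical override: must run before app.py reads os.environ at import time.
os.environ["DEFAULT_MODEL"] = "AIML-TUDA/QwenGuard-v1.2-7B"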