File size: 5,594 Bytes
2b0bfdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python
# encoding: utf-8
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import argparse
from transformers import AutoModel, AutoTokenizer

# Command-line options for the demo.
parser = argparse.ArgumentParser(description='demo')
# Only 'cpu' is accepted. Using `choices` lets argparse reject any other value
# with a usage error; the former `assert` would be stripped under `python -O`
# and silently let an unsupported device through.
parser.add_argument('--device', type=str, default='cpu', choices=['cpu'], help='cpu')
parser.add_argument('--dtype', type=str, default='fp32', help='fp32')
args = parser.parse_args()
device = args.device

# Select the torch dtype: 'fp32' maps to float32, every other value falls
# back to float16 (unchanged from the original behavior).
dtype = torch.float32 if args.dtype == 'fp32' else torch.float16

# Fetch the checkpoint and cast its weights to the precision chosen above.
# trust_remote_code=True is passed to both loaders, as in the original.
model_path = 'openbmb/MiniCPM-V-2'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model = model.to(dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Place the model on the requested device and call eval() on it.
model = model.to(device=device)
model.eval()

# Generic failure text returned to the chat window, and a display name.
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.0'

# Keyword settings for the decode-type radio selector.
form_radio = dict(
    choices=['Beam Search', 'Sampling'],
    value='Sampling',
    interactive=True,
    label='Decode Type',
)

# Keyword settings for each generation-parameter slider. There are two
# repetition-penalty variants with different default values (1.2 and 1.05).
num_beams_slider = dict(
    minimum=0, maximum=5, value=3, step=1,
    interactive=True, label='Num Beams',
)
repetition_penalty_slider = dict(
    minimum=0, maximum=3, value=1.2, step=0.01,
    interactive=True, label='Repetition Penalty',
)
repetition_penalty_slider2 = dict(
    minimum=0, maximum=3, value=1.05, step=0.01,
    interactive=True, label='Repetition Penalty',
)
max_new_tokens_slider = dict(
    minimum=1, maximum=4096, value=1024, step=1,
    interactive=True, label='Max New Tokens',
)
top_p_slider = dict(
    minimum=0, maximum=1, value=0.8, step=0.05,
    interactive=True, label='Top P',
)
top_k_slider = dict(
    minimum=0, maximum=200, value=100, step=1,
    interactive=True, label='Top K',
)
temperature_slider = dict(
    minimum=0, maximum=2, value=0.7, step=0.05,
    interactive=True, label='Temperature',
)

def create_component(params, comp='Slider'):
    """Build a Gradio widget of the requested kind from a settings dict.

    Args:
        params: Settings mapping. For 'Slider' it is expanded as keyword
            arguments; for 'Radio' the keys 'choices', 'value',
            'interactive' and 'label' are read; for 'Button' only 'value'.
        comp: One of 'Slider', 'Radio' or 'Button'.

    Returns:
        The created component, or None when ``comp`` is none of the above.
    """
    if comp == 'Button':
        # Buttons are always created interactive, whatever params says.
        return gr.Button(value=params['value'], interactive=True)
    if comp == 'Radio':
        radio_kwargs = {
            key: params[key]
            for key in ('choices', 'value', 'interactive', 'label')
        }
        return gr.Radio(**radio_kwargs)
    if comp == 'Slider':
        return gr.Slider(**params)
    return None

def chat(img, msgs, ctx, params=None):
    """Run one model turn on an image and report the outcome as a status tuple.

    Args:
        img: Image object exposing ``convert('RGB')``, or None when no image
            has been provided.
        msgs: Conversation messages forwarded unchanged to ``model.chat``.
        ctx: Unused; kept so existing callers keep working.
        params: Generation keyword arguments forwarded to ``model.chat``.
            When None, beam-search defaults are used.

    Returns:
        A 4-tuple ``(code, text, None, None)``. ``code`` is 0 on success with
        ``text`` holding the cleaned answer, or -1 on failure with ``text``
        holding an error message.
    """
    if params is None:
        params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
    if img is None:
        return -1, "Error, invalid image, please upload a new image", None, None
    try:
        image = img.convert('RGB')
        answer, _, _ = model.chat(image=image, msgs=msgs, context=None, tokenizer=tokenizer, **params)
        # Remove each <box>...</box> span on its own. The match is non-greedy
        # so ordinary text sitting between two separate box spans survives
        # (a greedy `.*` would swallow everything from the first <box> to the
        # last </box>).
        res = re.sub(r'<box>.*?</box>', '', answer)
        # Strip any leftover grounding tags, including unpaired ones.
        for tag in ('<ref>', '</ref>', '<box>', '</box>'):
            res = res.replace(tag, '')
        return 0, res, None, None
    except Exception as err:
        # Top-level boundary for the UI: log the failure and return a generic
        # message instead of propagating into the Gradio callback.
        print(err)
        traceback.print_exc()
        return -1, ERROR_MSG, None, None

def upload_img(image, _chatbot, _app_session):
    """Start a fresh conversation around a newly uploaded image.

    Args:
        image: Array accepted by ``Image.fromarray``.
        _chatbot: Chat history list; a confirmation entry is appended to it.
        _app_session: Session dict; its 'sts', 'ctx' and 'img' entries are
            overwritten.

    Returns:
        The updated ``(_chatbot, _app_session)`` pair.
    """
    # The conversion runs before the session is touched, so a failed
    # conversion leaves the previous session untouched.
    _app_session.update({
        'sts': None,
        'ctx': [],
        'img': Image.fromarray(image),
    })
    greeting = ('', 'Image uploaded successfully, you can talk to me now')
    _chatbot.append(greeting)
    return _chatbot, _app_session

def respond(_question, _chat_bot, _app_cfg, params_form='Sampling', num_beams=3,
            repetition_penalty=1.2, repetition_penalty_2=1.05, top_p=0.8,
            top_k=100, temperature=0.7):
    """Answer a user question about the session image and update the chat.

    The generation parameters are optional so that the handler also works when
    the UI wires only the question, chat history and session state; calling it
    with just those three inputs previously raised TypeError. The defaults
    match the default values of the decode-type radio and slider settings
    declared in this file.

    Args:
        _question: Text entered by the user.
        _chat_bot: Chat history list of ``(question, answer)`` tuples;
            appended to in place.
        _app_cfg: Session dict with 'ctx', 'img' and 'sts' entries.
        params_form: 'Beam Search' selects beam search; anything else selects
            sampling.
        num_beams: Beam count (beam search only).
        repetition_penalty: Repetition penalty used with beam search.
        repetition_penalty_2: Repetition penalty used with sampling.
        top_p: Sampling parameter forwarded as ``top_p``.
        top_k: Sampling parameter forwarded as ``top_k``.
        temperature: Sampling parameter forwarded as ``temperature``.

    Returns:
        ``('', _chat_bot, _app_cfg)``; the empty string clears the input box.
    """
    # 'ctx' stays None until an image has been uploaded.
    if _app_cfg.get('ctx', None) is None:
        _chat_bot.append((_question, 'Please upload an image to start'))
        return '', _chat_bot, _app_cfg

    # Work on a copy so a failed turn does not pollute the stored history.
    _context = _app_cfg['ctx'].copy()
    _context.append({"role": "user", "content": _question})

    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            "max_new_tokens": 896,
        }
    else:  # Sampling
        params = {
            'sampling': True,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty_2,
            "max_new_tokens": 896,
        }

    code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)

    # The answer (or error text) is always shown, but the turn is persisted in
    # the session only when the model call succeeded.
    _chat_bot.append((_question, _answer))
    if code == 0:
        _context.append({"role": "assistant", "content": _answer})
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
    return '', _chat_bot, _app_cfg

def clear(chat_bot, app_session):
    """Forget the session image and empty the chat history in place.

    Args:
        chat_bot: Chat history container; emptied via its ``clear()`` method.
        app_session: Session dict; its 'img' entry is set to None. Other
            entries are left as they are.

    Returns:
        The same, now empty, ``chat_bot`` object.
    """
    app_session.update(img=None)
    chat_bot.clear()
    return chat_bot

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Medical Assistant</h1>")

    with gr.Row():
        # Left column: session state, image upload and the question box.
        with gr.Column(scale=2, min_width=300):
            app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
            bt_pic = gr.Image(label="Upload an image to start")
            txt_message = gr.Textbox(label="Ask your question...")

        # Right column: the conversation transcript and a reset button.
        with gr.Column(scale=2, min_width=300):
            chat_bot = gr.Chatbot(label="Chatbot")
            clear_button = gr.Button(value='Clear')

            # Submitting the textbox runs `respond`; its three return values
            # go back into the same three components.
            # NOTE(review): only these three inputs are wired; respond's
            # remaining parameters are not supplied from the UI.
            txt_message.submit(
                respond,
                inputs=[txt_message, chat_bot, app_session],
                outputs=[txt_message, chat_bot, app_session],
            )

            # On upload, first blank the transcript, then register the image
            # in the session and post the confirmation message.
            transcript_reset = bt_pic.upload(lambda: None, None, chat_bot, queue=False)
            transcript_reset.then(
                upload_img,
                inputs=[bt_pic, chat_bot, app_session],
                outputs=[chat_bot, app_session],
            )

            clear_button.click(clear, [chat_bot, app_session], chat_bot)

# Launch
demo.launch(share=True, debug=True, show_api=False)