Mister56 committed
Commit 2b0bfdd · verified · 1 Parent(s): 0459724

Initial commit

Files changed (2):
  1. requirements.txt +35 -0
  2. web_demo3.py +141 -0
requirements.txt ADDED
@@ -0,0 +1,35 @@
+ packaging==23.2
+ addict==2.4.0
+ editdistance==0.6.2
+ einops==0.7.0
+ fairscale==0.4.0
+ jsonlines==4.0.0
+ markdown2==2.4.10
+ matplotlib==3.7.4
+ more_itertools==10.1.0
+ nltk==3.8.1
+ numpy==1.24.4
+ opencv_python_headless==4.5.5.64
+ openpyxl==3.1.2
+ Pillow==10.1.0
+ sacrebleu==2.3.2
+ seaborn==0.13.0
+ shortuuid==1.0.11
+ #spacy==3.7.2
+ timm==0.9.10
+ torch==2.1.2
+ torchvision==0.16.2
+ tqdm==4.66.1
+ protobuf==4.25.0
+ transformers==4.40.0
+ typing_extensions==4.8.0
+ uvicorn==0.24.0.post1
+ #xformers==0.0.22.post7
+ #flash_attn==2.3.4
+ sentencepiece==0.1.99
+ accelerate==0.30.1
+ socksio==1.0.0
+ gradio==4.41.0
+ gradio_client
+ http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl
+ decord
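These pins describe a CPU-only stack (torch 2.1.2, torchvision 0.16.2, transformers 4.40.0, gradio 4.41.0), with spacy, xformers, and flash_attn commented out. As a minimal sketch, assuming a recent Python with pip available, the file would be installed the usual way:

pip install -r requirements.txt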
web_demo3.py ADDED
@@ -0,0 +1,141 @@
+ #!/usr/bin/env python
+ # encoding: utf-8
+ import gradio as gr
+ from PIL import Image
+ import traceback
+ import re
+ import torch
+ import argparse
+ from transformers import AutoModel, AutoTokenizer
+
+ # Argparser
+ parser = argparse.ArgumentParser(description='demo')
+ parser.add_argument('--device', type=str, default='cpu', help='cpu')
+ parser.add_argument('--dtype', type=str, default='fp32', help='fp32')
+ args = parser.parse_args()
+ device = args.device
+ assert device in ['cpu']
+
+ # Set dtype
+ if args.dtype == 'fp32':
+     dtype = torch.float32
+ else:
+     dtype = torch.float16
+
+ # Load model
+ model_path = 'openbmb/MiniCPM-V-2'
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=dtype)
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ model = model.to(device=device)
+ model.eval()
+
+ ERROR_MSG = "Error, please retry"
+ model_name = 'MiniCPM-V 2.0'
+
+ # UI Components
+ form_radio = {
+     'choices': ['Beam Search', 'Sampling'],
+     'value': 'Sampling',
+     'interactive': True,
+     'label': 'Decode Type'
+ }
+
+ # Sliders and their settings
+ num_beams_slider = {'minimum': 0, 'maximum': 5, 'value': 3, 'step': 1, 'interactive': True, 'label': 'Num Beams'}
+ repetition_penalty_slider = {'minimum': 0, 'maximum': 3, 'value': 1.2, 'step': 0.01, 'interactive': True, 'label': 'Repetition Penalty'}
+ repetition_penalty_slider2 = {'minimum': 0, 'maximum': 3, 'value': 1.05, 'step': 0.01, 'interactive': True, 'label': 'Repetition Penalty'}
+ max_new_tokens_slider = {'minimum': 1, 'maximum': 4096, 'value': 1024, 'step': 1, 'interactive': True, 'label': 'Max New Tokens'}
+ top_p_slider = {'minimum': 0, 'maximum': 1, 'value': 0.8, 'step': 0.05, 'interactive': True, 'label': 'Top P'}
+ top_k_slider = {'minimum': 0, 'maximum': 200, 'value': 100, 'step': 1, 'interactive': True, 'label': 'Top K'}
+ temperature_slider = {'minimum': 0, 'maximum': 2, 'value': 0.7, 'step': 0.05, 'interactive': True, 'label': 'Temperature'}
+
+ def create_component(params, comp='Slider'):
+     if comp == 'Slider':
+         return gr.Slider(**params)
+     elif comp == 'Radio':
+         return gr.Radio(choices=params['choices'], value=params['value'], interactive=params['interactive'], label=params['label'])
+     elif comp == 'Button':
+         return gr.Button(value=params['value'], interactive=True)
+
+ def chat(img, msgs, ctx, params=None):
+     default_params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
+     if params is None:
+         params = default_params
+     if img is None:
+         return -1, "Error, invalid image, please upload a new image", None, None
+     try:
+         image = img.convert('RGB')
+         answer, context, _ = model.chat(image=image, msgs=msgs, context=None, tokenizer=tokenizer, **params)
+         res = re.sub(r'(<box>.*</box>)', '', answer).replace('<ref>', '').replace('</ref>', '').replace('<box>', '').replace('</box>', '')
+         return 0, res, None, None
+     except Exception as err:
+         print(err)
+         traceback.print_exc()
+         return -1, ERROR_MSG, None, None
+
+ def upload_img(image, _chatbot, _app_session):
+     image = Image.fromarray(image)
+     _app_session['sts'] = None
+     _app_session['ctx'] = []
+     _app_session['img'] = image
+     _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
+     return _chatbot, _app_session
+
+ # The slider/radio settings above are not wired into the UI below, so the extra
+ # generation parameters take defaults mirroring those settings; txt_message.submit
+ # only passes the first three arguments.
+ def respond(_question, _chat_bot, _app_cfg, params_form='Sampling', num_beams=3, repetition_penalty=1.2,
+             repetition_penalty_2=1.05, top_p=0.8, top_k=100, temperature=0.7):
+     if _app_cfg.get('ctx', None) is None:
+         _chat_bot.append((_question, 'Please upload an image to start'))
+         return '', _chat_bot, _app_cfg
+
+     _context = _app_cfg['ctx'].copy()
+     _context.append({"role": "user", "content": _question})
+
+     if params_form == 'Beam Search':
+         params = {'sampling': False, 'num_beams': num_beams, 'repetition_penalty': repetition_penalty, "max_new_tokens": 896}
+     else:  # Sampling
+         params = {
+             'sampling': True,
+             'top_p': top_p,
+             'top_k': top_k,
+             'temperature': temperature,
+             'repetition_penalty': repetition_penalty_2,
+             "max_new_tokens": 896
+         }
+
+     code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
+
+     _context.append({"role": "assistant", "content": _answer})
+     _chat_bot.append((_question, _answer))
+     if code == 0:
+         _app_cfg['ctx'] = _context
+         _app_cfg['sts'] = sts
+     return '', _chat_bot, _app_cfg
+
+ def clear(chat_bot, app_session):
+     # Reset the session state so the next message prompts for a fresh image.
+     app_session['img'] = None
+     app_session['ctx'] = None
+     app_session['sts'] = None
+     chat_bot.clear()
+     return chat_bot
+
+ with gr.Blocks() as demo:
+     gr.Markdown("<h1 style='text-align: center;'>Medical Assistant</h1>")
+
+     with gr.Row():
+         with gr.Column(scale=2, min_width=300):
+             app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
+             bt_pic = gr.Image(label="Upload an image to start")
+             txt_message = gr.Textbox(label="Ask your question...")
+
+         with gr.Column(scale=2, min_width=300):
+             chat_bot = gr.Chatbot(label="Chatbot")
+             clear_button = gr.Button(value='Clear')
+             txt_message.submit(
+                 respond,
+                 [txt_message, chat_bot, app_session],
+                 [txt_message, chat_bot, app_session]
+             )
+
+     bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic, chat_bot, app_session], outputs=[chat_bot, app_session])
+     clear_button.click(clear, [chat_bot, app_session], chat_bot)
+
+ # Launch
+ demo.launch(share=True, debug=True, show_api=False)
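Assuming the pinned requirements above are installed, the demo would typically be launched with the script's own argparse defaults (the assert only allows CPU):

python web_demo3.py --device cpu --dtype fp32

With share=True, demo.launch also requests a temporary public Gradio link in addition to the local URL.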