sontal committed
Commit ac4da2d
0 Parent(s):

init commit

.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ venv
+ *.log*
+ .DS_Store
+ __pycache__/
+ *.py[cod]
+ serve_images
+ *-conv.json
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.9
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ CMD ["python", "gradio_web_server.py", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Yi 34B VL
+ emoji: 😻
+ colorFrom: indigo
+ colorTo: pink
+ sdk: docker
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app/.DS_Store ADDED
Binary file (6.15 kB).
 
app/__init__.py ADDED
File without changes
app/constants.py ADDED
@@ -0,0 +1,13 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
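
A minimal sketch (not part of this commit) of how these constants are typically consumed when building model inputs: the literal `DEFAULT_IMAGE_TOKEN` in a prompt is mapped to the out-of-vocabulary sentinel `IMAGE_TOKEN_INDEX`, so downstream model code can splice image embeddings into the token sequence at those positions. The function and toy tokenizer below are illustrative assumptions, not code from this repo.

```python
from app.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX

def tokenize_with_image_slots(prompt, tokenize):
    # Split on the literal "<image>" placeholder and tokenize each text
    # chunk separately, so the placeholder itself never reaches the tokenizer.
    chunks = [tokenize(c) for c in prompt.split(DEFAULT_IMAGE_TOKEN)]
    input_ids = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            # Mark the image slot with the sentinel index -200; vision
            # features are substituted at these positions later.
            input_ids.append(IMAGE_TOKEN_INDEX)
        input_ids.extend(chunk)
    return input_ids

# Toy "tokenizer" that maps each word to its length:
print(tokenize_with_image_slots("Describe <image> briefly",
                                lambda s: [len(w) for w in s.split()]))
# -> [8, -200, 7]
```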
app/conversation.py ADDED
@@ -0,0 +1,393 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+
+
+ class SeparatorStyle(Enum):
+     """Different separator styles."""
+     SINGLE = auto()
+     TWO = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+     GPT = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.GPT
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             init_msg = init_msg[0].replace("<image>", "").strip()
+             if 'mmtag' in self.version:
+                 messages[0] = (init_role, init_msg)
+                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                 messages.insert(1, (self.roles[1], "Received."))
+             else:
+                 messages[0] = (init_role, "<image> " + init_msg)
+
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i == 0: message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         elif self.sep_style == SeparatorStyle.GPT:
+             # Build an OpenAI-style list of role/content dicts instead of a
+             # single formatted string; None placeholders are skipped.
+             ret = []
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret.append({
+                         "role": role.lower(),
+                         "content": message
+                     })
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     from PIL import Image
+                     msg, image, image_process_mode = msg
+                     if image_process_mode == "Pad":
+                         def expand2square(pil_img, background_color=(122, 116, 104)):
+                             width, height = pil_img.size
+                             if width == height:
+                                 return pil_img
+                             elif width > height:
+                                 result = Image.new(pil_img.mode, (width, width), background_color)
+                                 result.paste(pil_img, (0, (width - height) // 2))
+                                 return result
+                             else:
+                                 result = Image.new(pil_img.mode, (height, height), background_color)
+                                 result.paste(pil_img, ((height - width) // 2, 0))
+                                 return result
+                         image = expand2square(image)
+                     elif image_process_mode in ["Default", "Crop"]:
+                         pass
+                     elif image_process_mode == "Resize":
+                         image = image.resize((336, 336))
+                     else:
+                         raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if longest_edge != max(image.size):
+                         if H > W:
+                             H, W = longest_edge, shortest_edge
+                         else:
+                             H, W = shortest_edge, longest_edge
+                         image = image.resize((W, H))
+                     if return_pil:
+                         images.append(image)
+                     else:
+                         buffered = BytesIO()
+                         image.save(buffered, format="JPEG")
+                         img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                         images.append(img_b64_str)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     msg, image, image_process_mode = msg
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     buffered = BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                     # The bytes are JPEG, so the data URI must say jpeg (not png)
+                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_vicuna_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+         ("Assistant",
+          "Renewable energy sources are those that can be replenished naturally in a relatively "
+          "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+          "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+          "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+          "renewable and non-renewable energy sources:\n"
+          "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+          "energy sources are finite and will eventually run out.\n"
+          "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+          "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+          "and other negative effects.\n"
+          "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+          "have lower operational costs than non-renewable sources.\n"
+          "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+          "locations than non-renewable sources.\n"
+          "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+          "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+          "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+          "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.GPT,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llama_2 = Conversation(
+     system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_llava_llama_2 = Conversation(
+     system="You are a helpful language and vision assistant. "
+            "You are able to understand the visual content that the user provides, "
+            "and assist the user with a variety of tasks using natural language.",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_mpt = Conversation(
+     system="""<|im_start|>system
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_llava_plain = Conversation(
+     system="",
+     roles=("", ""),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="\n",
+ )
+
+ conv_llava_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_llava_v0_mmtag = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+            "The visual content will be provided with the following format: <Image>visual content</Image>.",
+     roles=("Human", "Assistant"),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+     version="v0_mmtag",
+ )
+
+ conv_llava_v1 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llava_v1_mmtag = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+            "The visual content will be provided with the following format: <Image>visual content</Image>.",
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+     version="v1_mmtag",
+ )
+
+ default_conversation = conv_vicuna_v1
+ conv_templates = {
+     "default": conv_vicuna_v0,
+     "v0": conv_vicuna_v0,
+     "v1": conv_vicuna_v1,
+     "vicuna_v1": conv_vicuna_v1,
+     "llama_2": conv_llama_2,
+
+     "plain": conv_llava_plain,
+     "v0_plain": conv_llava_plain,
+     "llava_v0": conv_llava_v0,
+     "v0_mmtag": conv_llava_v0_mmtag,
+     "llava_v1": conv_llava_v1,
+     "v1_mmtag": conv_llava_v1_mmtag,
+     "llava_llama_2": conv_llava_llama_2,
+
+     "mpt": conv_mpt,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
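
A short usage sketch for the templates above. Note that `copy()` is needed before appending: the module-level templates store `messages` as an immutable tuple, and `copy()` rebuilds it as a mutable list. Because `conv_vicuna_v1` (the `default_conversation`) uses `SeparatorStyle.GPT`, `get_prompt()` returns a list of role/content dicts rather than a single formatted string:

```python
from app.conversation import conv_templates

conv = conv_templates["vicuna_v1"].copy()
conv.append_message(conv.roles[0], "What is in this image?")
conv.append_message(conv.roles[1], None)  # placeholder for the model's reply
print(conv.get_prompt())
# [{'role': 'user', 'content': 'What is in this image?'}]
```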
app/utils.py ADDED
@@ -0,0 +1,126 @@
+ import datetime
+ import logging
+ import logging.handlers
+ import os
+ import sys
+
+ import requests
+
+ from app.constants import LOGDIR
+
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+ handler = None
+
+
+ def build_logger(logger_name, logger_filename):
+     global handler
+
+     formatter = logging.Formatter(
+         fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+
+     # Set the format of root handlers
+     if not logging.getLogger().handlers:
+         logging.basicConfig(level=logging.INFO)
+     logging.getLogger().handlers[0].setFormatter(formatter)
+
+     # Redirect stdout and stderr to loggers
+     stdout_logger = logging.getLogger("stdout")
+     stdout_logger.setLevel(logging.INFO)
+     sl = StreamToLogger(stdout_logger, logging.INFO)
+     sys.stdout = sl
+
+     stderr_logger = logging.getLogger("stderr")
+     stderr_logger.setLevel(logging.ERROR)
+     sl = StreamToLogger(stderr_logger, logging.ERROR)
+     sys.stderr = sl
+
+     # Get logger
+     logger = logging.getLogger(logger_name)
+     logger.setLevel(logging.INFO)
+
+     # Add a file handler for all loggers
+     if handler is None:
+         os.makedirs(LOGDIR, exist_ok=True)
+         filename = os.path.join(LOGDIR, logger_filename)
+         handler = logging.handlers.TimedRotatingFileHandler(
+             filename, when='D', utc=True, encoding='UTF-8')
+         handler.setFormatter(formatter)
+
+         for name, item in logging.root.manager.loggerDict.items():
+             if isinstance(item, logging.Logger):
+                 item.addHandler(handler)
+
+     return logger
+
+
+ class StreamToLogger(object):
+     """
+     Fake file-like stream object that redirects writes to a logger instance.
+     """
+     def __init__(self, logger, log_level=logging.INFO):
+         self.terminal = sys.stdout
+         self.logger = logger
+         self.log_level = log_level
+         self.linebuf = ''
+
+     def __getattr__(self, attr):
+         return getattr(self.terminal, attr)
+
+     def write(self, buf):
+         temp_linebuf = self.linebuf + buf
+         self.linebuf = ''
+         for line in temp_linebuf.splitlines(True):
+             # From the io.TextIOWrapper docs:
+             #   On output, if newline is None, any '\n' characters written
+             #   are translated to the system default line separator.
+             # By default sys.stdout.write() expects '\n' newlines and then
+             # translates them so this is still cross platform.
+             if line[-1] == '\n':
+                 self.logger.log(self.log_level, line.rstrip())
+             else:
+                 self.linebuf += line
+
+     def flush(self):
+         if self.linebuf != '':
+             self.logger.log(self.log_level, self.linebuf.rstrip())
+         self.linebuf = ''
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def violates_moderation(text):
+     """
+     Check whether the text violates OpenAI moderation API.
+     """
+     url = "https://api.openai.com/v1/moderations"
+     headers = {"Content-Type": "application/json",
+                "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+     text = text.replace("\n", "")
+     try:
+         # Let requests serialize the payload; building the JSON body by
+         # string concatenation breaks on quotes and other special characters.
+         ret = requests.post(url, headers=headers, json={"input": text}, timeout=5)
+         flagged = ret.json()["results"][0]["flagged"]
+     except requests.exceptions.RequestException:
+         flagged = False
+     except KeyError:
+         flagged = False
+
+     return flagged
+
+
+ def pretty_print_semaphore(semaphore):
+     if semaphore is None:
+         return "None"
+     return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
examples/extreme_ironing.jpg ADDED
examples/waterview.jpg ADDED
gradio_web_server.py ADDED
@@ -0,0 +1,477 @@
+ import argparse
+ import datetime
+ import json
+ import os
+ import time
+
+ import gradio as gr
+ import requests
+ from PIL import Image
+ import base64
+ from io import BytesIO
+
+ from app.conversation import (default_conversation, conv_templates,
+                               SeparatorStyle)
+ from app.constants import LOGDIR
+ from app.utils import (build_logger, server_error_msg,
+                        violates_moderation, moderation_msg)
+ import hashlib
+
+ worker_addr = os.getenv('WORKER_ADDR')
+ apikey = os.getenv('AUTHORIZATION')
+
+
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+ headers = {"Authorization": apikey}
+
+ no_change_btn = gr.Button()
+ enable_btn = gr.Button(interactive=True)
+ disable_btn = gr.Button(interactive=False)
+
+ priority = {
+     "vicuna-13b": "aaaaaaa",
+     "koala-13b": "aaaaaab",
+ }
+
+
+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+     return name
+
+
+ def get_model_list():
+     ret = requests.post(args.controller_url + "/refresh_all_workers")
+     assert ret.status_code == 200
+     ret = requests.post(args.controller_url + "/list_models")
+     models = ret.json()["models"]
+     models.sort(key=lambda x: priority.get(x, x))
+     logger.info(f"Models: {models}")
+     return models
+
+
+ get_window_url_params = """
+ function() {
+     const params = new URLSearchParams(window.location.search);
+     url_params = Object.fromEntries(params);
+     console.log(url_params);
+     return url_params;
+ }
+ """
+
+
+ def load_demo(url_params, request: gr.Request):
+     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+     dropdown_update = gr.Dropdown.update(visible=True)
+     if "model" in url_params:
+         model = url_params["model"]
+         if model in models:
+             dropdown_update = gr.Dropdown.update(
+                 value=model, visible=True)
+
+     state = default_conversation.copy()
+     return state, dropdown_update
+
+
+ def load_demo_refresh_model_list(request: gr.Request):
+     logger.info(f"load_demo. ip: {request.client.host}")
+     models = get_model_list()
+     state = default_conversation.copy()
+     dropdown_update = gr.Dropdown(
+         choices=models,
+         value=models[0] if len(models) > 0 else ""
+     )
+     return state, dropdown_update
+
+
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(time.time(), 4),
+             "type": vote_type,
+             "model": model_selector,
+             "state": state.dict(),
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+
+ def upvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"upvote. ip: {request.client.host}")
+     vote_last_response(state, "upvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def downvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"downvote. ip: {request.client.host}")
+     vote_last_response(state, "downvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def flag_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"flag. ip: {request.client.host}")
+     vote_last_response(state, "flag", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def regenerate(state, image_process_mode, request: gr.Request):
+     logger.info(f"regenerate. ip: {request.client.host}")
+     state.messages[-1][-1] = None
+     prev_human_msg = state.messages[-2]
+     if type(prev_human_msg[1]) in (tuple, list):
+         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+ def clear_history(request: gr.Request):
+     logger.info(f"clear_history. ip: {request.client.host}")
+     state = default_conversation.copy()
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
+     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+     if len(text) <= 0 and image is None:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+     if args.moderate:
+         flagged = violates_moderation(text)
+         if flagged:
+             state.skip_next = True
+             return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
+                 no_change_btn,) * 5
+
+     text = text[:1536]  # Hard cut-off
+     if image is not None:
+         text = text[:1200]  # Hard cut-off for images
+         if '<image>' not in text:
+             # text = '<Image><image></Image>' + text
+             text = text + '\n<image>'
+         text = (text, image, image_process_mode)
+         if len(state.get_images(return_pil=True)) > 0:
+             state = default_conversation.copy()
+     state.append_message(state.roles[0], text)
+     state.append_message(state.roles[1], None)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+ def convert_image_to_base64(image, format=None):
+     # If no format is specified, default to JPEG
+     if format is None:
+         format = "jpeg"
+
+     # Save the image into an in-memory byte stream
+     buffered = BytesIO()
+     image.save(buffered, format=format)
+
+     # Base64-encode the bytes
+     img_str = base64.b64encode(buffered.getvalue()).decode()
+
+     # Assemble the final data URI
+     return f"data:image/{format.lower()};base64,{img_str}"
+
+
+ # The model-request logic is modified here (the worker is queried directly
+ # via WORKER_ADDR instead of going through a controller)
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+     logger.info(f"http_bot. ip: {request.client.host}")
+     start_tstamp = time.time()
+     model_name = model_selector
+
+     if state.skip_next:
+         # This generate call is skipped due to invalid inputs
+         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+         return
+
+     if len(state.messages) == state.offset + 2:
+         # First round of conversation
+         if "llava" in model_name.lower():
+             if 'llama-2' in model_name.lower():
+                 template_name = "llava_llama_2"
+             elif "v1" in model_name.lower():
+                 if 'mmtag' in model_name.lower():
+                     template_name = "v1_mmtag"
+                 elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+                     template_name = "v1_mmtag"
+                 else:
+                     template_name = "llava_v1"
+             elif "mpt" in model_name.lower():
+                 template_name = "mpt"
+             else:
+                 if 'mmtag' in model_name.lower():
+                     template_name = "v0_mmtag"
+                 elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+                     template_name = "v0_mmtag"
+                 else:
+                     template_name = "llava_v0"
+         elif "mpt" in model_name:
+             template_name = "mpt_text"
+         elif "llama-2" in model_name:
+             template_name = "llama_2"
+         else:
+             template_name = "vicuna_v1"
+         new_state = conv_templates[template_name].copy()
+         new_state.append_message(new_state.roles[0], state.messages[-2][1])
+         new_state.append_message(new_state.roles[1], None)
+         state = new_state
+
+     # Query worker address
+     # controller_url = args.controller_url
+     # ret = requests.post(controller_url + "/get_worker_address",
+     #                     json={"model": model_name})
+     # worker_addr = ret.json()["address"]
+     # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+     # # No available worker
+     # if worker_addr == "":
+     #     state.messages[-1][-1] = server_error_msg
+     #     yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+     #     return
+
+     # Construct prompt
+     prompt = state.get_prompt()
+
+     all_images = state.get_images(return_pil=True)
+     all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+
+     # Guard against text-only turns: the conversation may contain no image
+     image = convert_image_to_base64(all_images[0]) if all_images else None
+
+     # Make requests
+     pload = {
+         "image_path": image,
+         "model": model_name,
+         "messages": prompt,
+         "stream": True,
+         "max_tokens": 512
+     }
+     # Avoid dumping the base64-encoded image into the log
+     logger.info(f"==== request ====\n{json.dumps({k: v for k, v in pload.items() if k != 'image_path'})}")
+
+     pload['images'] = state.get_images()
+
+     state.messages[-1][-1] = "▌"
+     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+     try:
+         # Request the worker (the response is read in one shot, not streamed)
+         response = requests.post(worker_addr,
+                                  headers=headers, json=pload, timeout=60)
+
+         output = json.loads(response.text)['message']['content']
+         logger.info(f"the response is {output}")
+         state.messages[-1][-1] = output
+         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+     except (requests.exceptions.RequestException, KeyError, json.JSONDecodeError):
+         state.messages[-1][-1] = server_error_msg
+         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+         return
+
+     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+     finish_tstamp = time.time()
+     logger.info(f"{output}")
+
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(finish_tstamp, 4),
+             "type": "chat",
+             "model": model_name,
+             "start": round(start_tstamp, 4),
+             "finish": round(finish_tstamp, 4),
+             "state": state.dict(),
+             "images": all_image_hash,
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+ title_markdown = ("""
+ # 🌋 LLaVA: Large Language and Vision Assistant
+ [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
+ """)
+
+ tos_markdown = ("""
+ ### Terms of use
+ By using this service, users are required to agree to the following terms:
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+ """)
+
+
+ learn_more_markdown = ("""
+ ### License
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+ """)
+
+ block_css = """
+
+ #buttons button {
+     min-width: min(120px,100%);
+ }
+
+ """
+
+
+ def build_demo(embed_mode):
+     textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+     with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
+         state = gr.State()
+
+         if not embed_mode:
+             gr.Markdown(title_markdown)
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 with gr.Row(elem_id="model_selector_row"):
+                     model_selector = gr.Dropdown(
+                         choices=models,
+                         value=models[0] if len(models) > 0 else "",
+                         interactive=True,
+                         show_label=False,
+                         container=False)
+
+                 imagebox = gr.Image(type="pil")
+                 image_process_mode = gr.Radio(
+                     ["Crop", "Resize", "Pad", "Default"],
+                     value="Default",
+                     label="Preprocess for non-square image", visible=False)
+
+                 cur_dir = os.path.dirname(os.path.abspath(__file__))
+                 gr.Examples(examples=[
+                     [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
+                     [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
+                 ], inputs=[imagebox, textbox])
+
+                 with gr.Accordion("Parameters", open=False) as parameter_row:
+                     temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+                     top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+                     max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+             with gr.Column(scale=8):
+                 chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
+                 with gr.Row():
+                     with gr.Column(scale=8):
+                         textbox.render()
+                     with gr.Column(scale=1, min_width=50):
+                         submit_btn = gr.Button(value="Send", variant="primary")
+                 with gr.Row(elem_id="buttons") as button_row:
+                     upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+                     downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+                     flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+                     # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+                     clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+         if not embed_mode:
+             gr.Markdown(tos_markdown)
+             gr.Markdown(learn_more_markdown)
+         url_params = gr.JSON(visible=False)
+
+         # Register listeners
+         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+         upvote_btn.click(
+             upvote_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+         downvote_btn.click(
+             downvote_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+         flag_btn.click(
+             flag_last_response,
+             [state, model_selector],
+             [textbox, upvote_btn, downvote_btn, flag_btn],
+             queue=False
+         )
+
+         regenerate_btn.click(
+             regenerate,
+             [state, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             [state, model_selector, temperature, top_p, max_output_tokens],
+             [state, chatbot] + btn_list
+         )
+
+         clear_btn.click(
+             clear_history,
+             None,
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         )
+
+         textbox.submit(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             [state, model_selector, temperature, top_p, max_output_tokens],
+             [state, chatbot] + btn_list
+         )
+
+         submit_btn.click(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             [state, model_selector, temperature, top_p, max_output_tokens],
+             [state, chatbot] + btn_list
+         )
+
+         if args.model_list_mode == "once":
+             demo.load(
+                 load_demo,
+                 [url_params],
+                 [state, model_selector],
+                 _js=get_window_url_params,
+                 queue=False
+             )
+         elif args.model_list_mode == "reload":
+             demo.load(
+                 load_demo_refresh_model_list,
+                 None,
+                 [state, model_selector],
+                 queue=False
+             )
+         else:
+             raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+     return demo
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="0.0.0.0")
+     parser.add_argument("--port", type=int)
+     parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+     parser.add_argument("--concurrency-count", type=int, default=10)
+     parser.add_argument("--model-list-mode", type=str, default="once",
+                         choices=["once", "reload"])
+     parser.add_argument("--share", action="store_true")
+     parser.add_argument("--moderate", action="store_true")
+     parser.add_argument("--embed", action="store_true")
+     args = parser.parse_args()
+     logger.info(f"args: {args}")
+
+     models = ["yi-34b-vl"]
+
+     demo = build_demo(args.embed)
+     demo.queue(
+         concurrency_count=args.concurrency_count,
+         api_open=False
+     ).launch(
+         server_name=args.host,
+         server_port=args.port,
+         share=args.share
+     )
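
The worker behind `WORKER_ADDR` is not part of this commit. From `http_bot`'s request and response handling, it must accept the JSON payload built above (`image_path`, `model`, `messages`, `stream`, `max_tokens`, `images`) and reply with JSON of the shape `{"message": {"content": ...}}`. Below is a throwaway echo stand-in for local testing, using the FastAPI/uvicorn versions already pinned in `requirements.txt`; the route path and echo behavior are assumptions for illustration, not the real Yi-34B worker:

```python
# fake_worker.py -- run with: uvicorn fake_worker:app --port 8000
# then start the demo with WORKER_ADDR=http://localhost:8000/
from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/")
async def fake_worker(request: Request):
    pload = await request.json()
    # Echo the last user message back in the shape http_bot expects:
    # json.loads(response.text)["message"]["content"]
    last_user = next(m["content"] for m in reversed(pload["messages"])
                     if m["role"] == "user")
    return {"message": {"content": f"(echo) {last_user}"}}
```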
requirements.txt ADDED
@@ -0,0 +1,56 @@
+ aiofiles==23.2.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==3.7.1
+ attrs==23.1.0
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ click==8.1.7
+ contourpy==1.2.0
+ cycler==0.12.1
+ fastapi==0.105.0
+ ffmpy==0.3.1
+ filelock==3.13.1
+ fonttools==4.47.0
+ fsspec==2023.12.2
+ gradio==3.47.1
+ gradio_client==0.6.0
+ h11==0.14.0
+ httpcore==1.0.2
+ httpx==0.26.0
+ huggingface-hub==0.20.1
+ idna==3.6
+ importlib-resources==6.1.1
+ Jinja2==3.1.2
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.11.2
+ kiwisolver==1.4.5
+ MarkupSafe==2.1.3
+ matplotlib==3.8.2
+ numpy==1.26.2
+ orjson==3.9.10
+ packaging==23.2
+ pandas==2.1.4
+ Pillow==10.1.0
+ pydantic==2.5.2
+ pydantic_core==2.14.5
+ pydub==0.25.1
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ referencing==0.32.0
+ requests==2.31.0
+ rpds-py==0.15.2
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.27.0
+ toolz==0.12.0
+ tqdm==4.66.1
+ typing_extensions==4.9.0
+ tzdata==2023.3
+ urllib3==2.1.0
+ uvicorn==0.25.0
+ websockets==11.0.3