sontal committed
Commit · ac4da2d
0 Parent(s)
init commit
Browse files
- .DS_Store +0 -0
- .gitattributes +35 -0
- .gitignore +7 -0
- Dockerfile +21 -0
- README.md +10 -0
- app/.DS_Store +0 -0
- app/__init__.py +0 -0
- app/constants.py +13 -0
- app/conversation.py +393 -0
- app/utils.py +126 -0
- examples/extreme_ironing.jpg +0 -0
- examples/waterview.jpg +0 -0
- gradio_web_server.py +477 -0
- requirements.txt +56 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,7 @@
+venv
+*.log*
+.DS_Store
+__pycache__/
+*.py[cod]
+serve_images
+*-conv.json
Dockerfile
ADDED
@@ -0,0 +1,21 @@
+# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+# you will also find guides on how best to write your Dockerfile
+
+FROM python:3.9
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+WORKDIR $HOME/app
+
+COPY --chown=user . $HOME/app
+
+CMD ["python", "gradio_web_server.py", "--host", "0.0.0.0", "--port", "7860"]
README.md
ADDED
@@ -0,0 +1,10 @@
+---
+title: Yi 34B VL
+emoji: 😻
+colorFrom: indigo
+colorTo: pink
+sdk: docker
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app/.DS_Store
ADDED
Binary file (6.15 kB)
app/__init__.py
ADDED
File without changes
app/constants.py
ADDED
@@ -0,0 +1,13 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+IMAGE_PLACEHOLDER = "<image-placeholder>"
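These tokens are the shared vocabulary between prompt construction and the model worker. An illustrative snippet (not part of the commit) of the pattern that add_text in gradio_web_server.py follows when an image is attached; note the server hardcodes the '<image>' literal rather than importing the constant:

    from app.constants import DEFAULT_IMAGE_TOKEN

    text = "What is in this picture?"
    if DEFAULT_IMAGE_TOKEN not in text:
        # mirrors add_text(): the token marks where the image belongs in the prompt
        text = text + "\n" + DEFAULT_IMAGE_TOKEN
    print(text)  # "What is in this picture?\n<image>"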
app/conversation.py
ADDED
@@ -0,0 +1,393 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+    """Different separator styles."""
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+    GPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.GPT
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+            if 'mmtag' in self.version:
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                messages.insert(1, (self.roles[1], "Received."))
+            else:
+                messages[0] = (init_role, "<image> " + init_msg)
+
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.MPT:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+        elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i == 0: message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        elif self.sep_style == SeparatorStyle.GPT:
+            # Build an OpenAI-style list of role/content dicts instead of a flat string
+            ret = []
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret.append({
+                        "role": role.lower(),
+                        "content": message
+                    })
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def get_images(self, return_pil=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    from PIL import Image
+                    msg, image, image_process_mode = msg
+                    if image_process_mode == "Pad":
+                        def expand2square(pil_img, background_color=(122, 116, 104)):
+                            width, height = pil_img.size
+                            if width == height:
+                                return pil_img
+                            elif width > height:
+                                result = Image.new(pil_img.mode, (width, width), background_color)
+                                result.paste(pil_img, (0, (width - height) // 2))
+                                return result
+                            else:
+                                result = Image.new(pil_img.mode, (height, height), background_color)
+                                result.paste(pil_img, ((height - width) // 2, 0))
+                                return result
+                        image = expand2square(image)
+                    elif image_process_mode in ["Default", "Crop"]:
+                        pass
+                    elif image_process_mode == "Resize":
+                        image = image.resize((336, 336))
+                    else:
+                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if longest_edge != max(image.size):
+                        if H > W:
+                            H, W = longest_edge, shortest_edge
+                        else:
+                            H, W = shortest_edge, longest_edge
+                        image = image.resize((W, H))
+                    if return_pil:
+                        images.append(image)
+                    else:
+                        buffered = BytesIO()
+                        image.save(buffered, format="JPEG")
+                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                        images.append(img_b64_str)
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    msg, image, image_process_mode = msg
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    buffered = BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    # The image is saved as JPEG, so the data URI must say image/jpeg
+                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = img_str + msg.replace('<image>', '').strip()
+                    ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version)
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_vicuna_v0 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+        ("Assistant",
+         "Renewable energy sources are those that can be replenished naturally in a relatively "
+         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+         "renewable and non-renewable energy sources:\n"
+         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+         "energy sources are finite and will eventually run out.\n"
+         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+         "and other negative effects.\n"
+         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+         "have lower operational costs than non-renewable sources.\n"
+         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+         "locations than non-renewable sources.\n"
+         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.GPT,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_llama_2 = Conversation(
+    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+    roles=("USER", "ASSISTANT"),
+    version="llama_v2",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
+conv_llava_llama_2 = Conversation(
+    system="You are a helpful language and vision assistant. "
+           "You are able to understand the visual content that the user provides, "
+           "and assist the user with a variety of tasks using natural language.",
+    roles=("USER", "ASSISTANT"),
+    version="llama_v2",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
+conv_mpt = Conversation(
+    system="""<|im_start|>system
+A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+conv_llava_plain = Conversation(
+    system="",
+    roles=("", ""),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.PLAIN,
+    sep="\n",
+)
+
+conv_llava_v0 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_llava_v0_mmtag = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. "
+           "The visual content will be provided with the following format: <Image>visual content</Image>.",
+    roles=("Human", "Assistant"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+    version="v0_mmtag",
+)
+
+conv_llava_v1 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_llava_v1_mmtag = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. "
+           "The visual content will be provided with the following format: <Image>visual content</Image>.",
+    roles=("USER", "ASSISTANT"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+    version="v1_mmtag",
+)
+
+default_conversation = conv_vicuna_v1
+conv_templates = {
+    "default": conv_vicuna_v0,
+    "v0": conv_vicuna_v0,
+    "v1": conv_vicuna_v1,
+    "vicuna_v1": conv_vicuna_v1,
+    "llama_2": conv_llama_2,
+
+    "plain": conv_llava_plain,
+    "v0_plain": conv_llava_plain,
+    "llava_v0": conv_llava_v0,
+    "v0_mmtag": conv_llava_v0_mmtag,
+    "llava_v1": conv_llava_v1,
+    "v1_mmtag": conv_llava_v1_mmtag,
+    "llava_llama_2": conv_llava_llama_2,
+
+    "mpt": conv_mpt,
+}
+
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
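A minimal usage sketch of these templates (illustrative, not part of the commit), assuming app.conversation is importable. Note that default_conversation is conv_vicuna_v1, whose SeparatorStyle.GPT makes get_prompt() return a list of role/content dicts rather than a flat string:

    from app.conversation import conv_templates

    conv = conv_templates["vicuna_v1"].copy()  # copy() so the shared template stays pristine
    conv.append_message(conv.roles[0], "What is in this image?\n<image>")
    conv.append_message(conv.roles[1], None)   # placeholder slot for the assistant's reply

    # With SeparatorStyle.GPT this prints:
    # [{'role': 'user', 'content': 'What is in this image?\n<image>'}]
    print(conv.get_prompt())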
app/utils.py
ADDED
@@ -0,0 +1,126 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from app.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
+    global handler
+
+    formatter = logging.Formatter(
+        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    # Set the format of root handlers
+    if not logging.getLogger().handlers:
+        logging.basicConfig(level=logging.INFO)
+    logging.getLogger().handlers[0].setFormatter(formatter)
+
+    # Redirect stdout and stderr to loggers
+    stdout_logger = logging.getLogger("stdout")
+    stdout_logger.setLevel(logging.INFO)
+    sl = StreamToLogger(stdout_logger, logging.INFO)
+    sys.stdout = sl
+
+    stderr_logger = logging.getLogger("stderr")
+    stderr_logger.setLevel(logging.ERROR)
+    sl = StreamToLogger(stderr_logger, logging.ERROR)
+    sys.stderr = sl
+
+    # Get logger
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(logging.INFO)
+
+    # Add a file handler for all loggers
+    if handler is None:
+        os.makedirs(LOGDIR, exist_ok=True)
+        filename = os.path.join(LOGDIR, logger_filename)
+        handler = logging.handlers.TimedRotatingFileHandler(
+            filename, when='D', utc=True, encoding='UTF-8')
+        handler.setFormatter(formatter)
+
+        for name, item in logging.root.manager.loggerDict.items():
+            if isinstance(item, logging.Logger):
+                item.addHandler(handler)
+
+    return logger
+
+
+class StreamToLogger(object):
+    """
+    Fake file-like stream object that redirects writes to a logger instance.
+    """
+    def __init__(self, logger, log_level=logging.INFO):
+        self.terminal = sys.stdout
+        self.logger = logger
+        self.log_level = log_level
+        self.linebuf = ''
+
+    def __getattr__(self, attr):
+        return getattr(self.terminal, attr)
+
+    def write(self, buf):
+        temp_linebuf = self.linebuf + buf
+        self.linebuf = ''
+        for line in temp_linebuf.splitlines(True):
+            # From the io.TextIOWrapper docs:
+            #   On output, if newline is None, any '\n' characters written
+            #   are translated to the system default line separator.
+            # By default sys.stdout.write() expects '\n' newlines and then
+            # translates them so this is still cross platform.
+            if line[-1] == '\n':
+                self.logger.log(self.log_level, line.rstrip())
+            else:
+                self.linebuf += line
+
+    def flush(self):
+        if self.linebuf != '':
+            self.logger.log(self.log_level, self.linebuf.rstrip())
+        self.linebuf = ''
+
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+    """
+    Check whether the text violates OpenAI moderation API.
+    """
+    url = "https://api.openai.com/v1/moderations"
+    headers = {"Content-Type": "application/json",
+               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+    text = text.replace("\n", "")
+    data = "{" + '"input": ' + f'"{text}"' + "}"
+    data = data.encode("utf-8")
+    try:
+        ret = requests.post(url, headers=headers, data=data, timeout=5)
+        flagged = ret.json()["results"][0]["flagged"]
+    except requests.exceptions.RequestException as e:
+        flagged = False
+    except KeyError as e:
+        flagged = False
+
+    return flagged
+
+
+def pretty_print_semaphore(semaphore):
+    if semaphore is None:
+        return "None"
+    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
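A short sketch (illustrative, not part of the commit) of what build_logger does to the process: it returns a named logger, attaches a single daily-rotating file handler under LOGDIR, and replaces sys.stdout and sys.stderr with StreamToLogger wrappers, so even bare print calls end up in the log file:

    from app.utils import build_logger

    logger = build_logger("demo", "demo.log")  # writes ./demo.log, since LOGDIR = "."
    logger.info("reaches both the console handler and the rotating file")
    print("also captured: sys.stdout is now a StreamToLogger at INFO level")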
examples/extreme_ironing.jpg
ADDED
examples/waterview.jpg
ADDED
gradio_web_server.py
ADDED
@@ -0,0 +1,477 @@
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+from PIL import Image
+import base64
+from io import BytesIO
+
+from app.conversation import (default_conversation, conv_templates,
+                              SeparatorStyle)
+from app.constants import LOGDIR
+from app.utils import (build_logger, server_error_msg,
+                       violates_moderation, moderation_msg)
+import hashlib
+
+worker_addr = os.getenv('WORKER_ADDR')
+apikey = os.getenv('AUTHORIZATION')
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"Authorization": apikey}
+
+no_change_btn = gr.Button()
+enable_btn = gr.Button(interactive=True)
+disable_btn = gr.Button(interactive=False)
+
+priority = {
+    "vicuna-13b": "aaaaaaa",
+    "koala-13b": "aaaaaab",
+}
+
+
+def get_conv_log_filename():
+    t = datetime.datetime.now()
+    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+    return name
+
+
+def get_model_list():
+    ret = requests.post(args.controller_url + "/refresh_all_workers")
+    assert ret.status_code == 200
+    ret = requests.post(args.controller_url + "/list_models")
+    models = ret.json()["models"]
+    models.sort(key=lambda x: priority.get(x, x))
+    logger.info(f"Models: {models}")
+    return models
+
+
+get_window_url_params = """
+function() {
+    const params = new URLSearchParams(window.location.search);
+    url_params = Object.fromEntries(params);
+    console.log(url_params);
+    return url_params;
+}
+"""
+
+
+def load_demo(url_params, request: gr.Request):
+    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(
+                value=model, visible=True)
+
+    state = default_conversation.copy()
+    return state, dropdown_update
+
+
+def load_demo_refresh_model_list(request: gr.Request):
+    logger.info(f"load_demo. ip: {request.client.host}")
+    models = get_model_list()
+    state = default_conversation.copy()
+    dropdown_update = gr.Dropdown(
+        choices=models,
+        value=models[0] if len(models) > 0 else ""
+    )
+    return state, dropdown_update
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+    with open(get_conv_log_filename(), "a") as fout:
+        data = {
+            "tstamp": round(time.time(), 4),
+            "type": vote_type,
+            "model": model_selector,
+            "state": state.dict(),
+            "ip": request.client.host,
+        }
+        fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"upvote. ip: {request.client.host}")
+    vote_last_response(state, "upvote", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"downvote. ip: {request.client.host}")
+    vote_last_response(state, "downvote", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"flag. ip: {request.client.host}")
+    vote_last_response(state, "flag", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, image_process_mode, request: gr.Request):
+    logger.info(f"regenerate. ip: {request.client.host}")
+    state.messages[-1][-1] = None
+    prev_human_msg = state.messages[-2]
+    if type(prev_human_msg[1]) in (tuple, list):
+        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+    logger.info(f"clear_history. ip: {request.client.host}")
+    state = default_conversation.copy()
+    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def add_text(state, text, image, image_process_mode, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+    if len(text) <= 0 and image is None:
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+    if args.moderate:
+        flagged = violates_moderation(text)
+        if flagged:
+            state.skip_next = True
+            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
+                no_change_btn,) * 5
+
+    text = text[:1536]  # Hard cut-off
+    if image is not None:
+        text = text[:1200]  # Hard cut-off for images
+        if '<image>' not in text:
+            # text = '<Image><image></Image>' + text
+            text = text + '\n<image>'
+        text = (text, image, image_process_mode)
+        if len(state.get_images(return_pil=True)) > 0:
+            state = default_conversation.copy()
+    state.append_message(state.roles[0], text)
+    state.append_message(state.roles[1], None)
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def convert_image_to_base64(image, format=None):
+    # If no format is specified, default to JPEG
+    if format is None:
+        format = "jpeg"
+
+    # Save the image to an in-memory byte stream
+    buffered = BytesIO()
+    image.save(buffered, format=format)
+
+    # Encode as base64
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+
+    # Format the final data-URI string
+    return f"data:image/{format.lower()};base64,{img_str}"
+
+
+# The logic for requesting the model was changed here: the worker is called
+# directly via WORKER_ADDR instead of going through the controller.
+def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+    logger.info(f"http_bot. ip: {request.client.host}")
+    start_tstamp = time.time()
+    model_name = model_selector
+
+    if state.skip_next:
+        # This generate call is skipped due to invalid inputs
+        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+        return
+
+    if len(state.messages) == state.offset + 2:
+        # First round of conversation
+        if "llava" in model_name.lower():
+            if 'llama-2' in model_name.lower():
+                template_name = "llava_llama_2"
+            elif "v1" in model_name.lower():
+                if 'mmtag' in model_name.lower():
+                    template_name = "v1_mmtag"
+                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+                    template_name = "v1_mmtag"
+                else:
+                    template_name = "llava_v1"
+            elif "mpt" in model_name.lower():
+                template_name = "mpt"
+            else:
+                if 'mmtag' in model_name.lower():
+                    template_name = "v0_mmtag"
+                elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+                    template_name = "v0_mmtag"
+                else:
+                    template_name = "llava_v0"
+        elif "mpt" in model_name:
+            template_name = "mpt_text"
+        elif "llama-2" in model_name:
+            template_name = "llama_2"
+        else:
+            template_name = "vicuna_v1"
+        new_state = conv_templates[template_name].copy()
+        new_state.append_message(new_state.roles[0], state.messages[-2][1])
+        new_state.append_message(new_state.roles[1], None)
+        state = new_state
+
+    # Query worker address
+    # controller_url = args.controller_url
+    # ret = requests.post(controller_url + "/get_worker_address",
+    #                     json={"model": model_name})
+    # worker_addr = ret.json()["address"]
+    # logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+    # # No available worker
+    # if worker_addr == "":
+    #     state.messages[-1][-1] = server_error_msg
+    #     yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+    #     return
+
+    # Construct prompt
+    prompt = state.get_prompt()
+
+    all_images = state.get_images(return_pil=True)
+    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+
+    image = convert_image_to_base64(all_images[0])
+
+    # Make requests
+    pload = {
+        "image_path": image,
+        "model": model_name,
+        "messages": prompt,
+        "stream": True,
+        "max_tokens": 512
+    }
+    logger.info(f"==== request ====\n{json.dumps(pload)}")
+
+    pload['images'] = state.get_images()
+
+    state.messages[-1][-1] = "▌"
+    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+    try:
+        # Stream output
+        response = requests.post(worker_addr,
+                                 headers=headers, json=pload, timeout=60)
+
+        output = json.loads(response.text)['message']['content']
+        logger.info(f"the response is {output}")
+        state.messages[-1][-1] = output
+        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+    except requests.exceptions.RequestException as e:
+        state.messages[-1][-1] = server_error_msg
+        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+        return
+
+    # The "▌" placeholder was already replaced with the full output above
+    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+    finish_tstamp = time.time()
+    logger.info(f"{output}")
+
+    with open(get_conv_log_filename(), "a") as fout:
+        data = {
+            "tstamp": round(finish_tstamp, 4),
+            "type": "chat",
+            "model": model_name,
+            "start": round(start_tstamp, 4),
+            "finish": round(finish_tstamp, 4),
+            "state": state.dict(),
+            "images": all_image_hash,
+            "ip": request.client.host,
+        }
+        fout.write(json.dumps(data) + "\n")
+
+title_markdown = ("""
+# 🌋 LLaVA: Large Language and Vision Assistant
+[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
+""")
+
+tos_markdown = ("""
+### Terms of use
+By using this service, users are required to agree to the following terms:
+The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+""")
+
+
+learn_more_markdown = ("""
+### License
+The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+""")
+
+block_css = """
+
+#buttons button {
+    min-width: min(120px,100%);
+}
+
+"""
+
+
+def build_demo(embed_mode):
+    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
+        state = gr.State()
+
+        if not embed_mode:
+            gr.Markdown(title_markdown)
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                with gr.Row(elem_id="model_selector_row"):
+                    model_selector = gr.Dropdown(
+                        choices=models,
+                        value=models[0] if len(models) > 0 else "",
+                        interactive=True,
+                        show_label=False,
+                        container=False)
+
+                imagebox = gr.Image(type="pil")
+                image_process_mode = gr.Radio(
+                    ["Crop", "Resize", "Pad", "Default"],
+                    value="Default",
+                    label="Preprocess for non-square image", visible=False)
+
+                cur_dir = os.path.dirname(os.path.abspath(__file__))
+                gr.Examples(examples=[
+                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
+                    [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
+                ], inputs=[imagebox, textbox])
+
+                with gr.Accordion("Parameters", open=False) as parameter_row:
+                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+            with gr.Column(scale=8):
+                chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
+                with gr.Row():
+                    with gr.Column(scale=8):
+                        textbox.render()
+                    with gr.Column(scale=1, min_width=50):
+                        submit_btn = gr.Button(value="Send", variant="primary")
+                with gr.Row(elem_id="buttons") as button_row:
+                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+        if not embed_mode:
+            gr.Markdown(tos_markdown)
+            gr.Markdown(learn_more_markdown)
+        url_params = gr.JSON(visible=False)
+
+        # Register listeners
+        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+        upvote_btn.click(
+            upvote_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn],
+            queue=False
+        )
+        downvote_btn.click(
+            downvote_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn],
+            queue=False
+        )
+        flag_btn.click(
+            flag_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn],
+            queue=False
+        )
+
+        regenerate_btn.click(
+            regenerate,
+            [state, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list,
+            queue=False
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, chatbot] + btn_list
+        )
+
+        clear_btn.click(
+            clear_history,
+            None,
+            [state, chatbot, textbox, imagebox] + btn_list,
+            queue=False
+        )
+
+        textbox.submit(
+            add_text,
+            [state, textbox, imagebox, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list,
+            queue=False
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, chatbot] + btn_list
+        )
+
+        submit_btn.click(
+            add_text,
+            [state, textbox, imagebox, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list,
+            queue=False
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, top_p, max_output_tokens],
+            [state, chatbot] + btn_list
+        )
+
+        if args.model_list_mode == "once":
+            demo.load(
+                load_demo,
+                [url_params],
+                [state, model_selector],
+                _js=get_window_url_params,
+                queue=False
+            )
+        elif args.model_list_mode == "reload":
+            demo.load(
+                load_demo_refresh_model_list,
+                None,
+                [state, model_selector],
+                queue=False
+            )
+        else:
+            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+    return demo
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+    parser.add_argument("--concurrency-count", type=int, default=10)
+    parser.add_argument("--model-list-mode", type=str, default="once",
+                        choices=["once", "reload"])
+    parser.add_argument("--share", action="store_true")
+    parser.add_argument("--moderate", action="store_true")
+    parser.add_argument("--embed", action="store_true")
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    models = ["yi-34b-vl"]
+
+    logger.info(args)
+    demo = build_demo(args.embed)
+    demo.queue(
+        concurrency_count=args.concurrency_count,
+        api_open=False
+    ).launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share
+    )
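http_bot above posts the payload directly to WORKER_ADDR and expects a JSON reply shaped like {"message": {"content": "..."}}. The worker itself is not part of this commit; a hypothetical local stub that satisfies this contract (FastAPI and uvicorn are already pinned in requirements.txt; the name stub_worker is made up) might look like:

    # stub_worker.py: stand-in for the real Yi-34B-VL worker, for local UI testing only
    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.post("/")
    async def chat(request: Request):
        pload = await request.json()  # contains image_path (data URI), model, messages, ...
        msgs = pload.get("messages") or []
        last_user = msgs[-1]["content"] if msgs else ""
        # Reply in exactly the shape http_bot parses: json["message"]["content"]
        return {"message": {"content": f"(stub) you said: {last_user}"}}

Run it with uvicorn stub_worker:app --port 8000 and set WORKER_ADDR=http://localhost:8000/ before launching the Gradio server.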
requirements.txt
ADDED
@@ -0,0 +1,56 @@
+aiofiles==23.2.1
+altair==5.2.0
+annotated-types==0.6.0
+anyio==3.7.1
+attrs==23.1.0
+certifi==2023.11.17
+charset-normalizer==3.3.2
+click==8.1.7
+contourpy==1.2.0
+cycler==0.12.1
+fastapi==0.105.0
+ffmpy==0.3.1
+filelock==3.13.1
+fonttools==4.47.0
+fsspec==2023.12.2
+gradio==3.47.1
+gradio_client==0.6.0
+h11==0.14.0
+httpcore==1.0.2
+httpx==0.26.0
+huggingface-hub==0.20.1
+idna==3.6
+importlib-resources==6.1.1
+Jinja2==3.1.2
+jsonschema==4.20.0
+jsonschema-specifications==2023.11.2
+kiwisolver==1.4.5
+MarkupSafe==2.1.3
+matplotlib==3.8.2
+numpy==1.26.2
+orjson==3.9.10
+packaging==23.2
+pandas==2.1.4
+Pillow==10.1.0
+pydantic==2.5.2
+pydantic_core==2.14.5
+pydub==0.25.1
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-multipart==0.0.6
+pytz==2023.3.post1
+PyYAML==6.0.1
+referencing==0.32.0
+requests==2.31.0
+rpds-py==0.15.2
+semantic-version==2.10.0
+six==1.16.0
+sniffio==1.3.0
+starlette==0.27.0
+toolz==0.12.0
+tqdm==4.66.1
+typing_extensions==4.9.0
+tzdata==2023.3
+urllib3==2.1.0
+uvicorn==0.25.0
+websockets==11.0.3