fancyfeast committed · Commit 2a3d557 · 1 Parent(s): f4d3067
Man the chatinterface is weird #3
app.py
CHANGED
@@ -72,10 +72,10 @@ assert isinstance(end_of_header_id, int) and isinstance(end_of_turn_id, int)
 
 @spaces.GPU()
 @torch.no_grad()
-def chat_joycaption(message: dict, history, temperature: float, max_new_tokens: int) -> Generator[str, None, None]:
+def chat_joycaption(message: dict, history, temperature: float, top_p: float, max_new_tokens: int, log_prompt: bool) -> Generator[str, None, None]:
     torch.cuda.empty_cache()
 
-
+    chat_interface.chatbot_state
 
     # Prompts are always stripped in training for now
     prompt = message['text'].strip()
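Gradio's ChatInterface calls fn with the message and history first, then one positional argument per component in additional_inputs, in list order; that is why the signature above grows top_p and log_prompt along with the new slider and checkbox. (The bare chat_interface.chatbot_state statement added in this hunk only reads the attribute and discards the result, so it looks like a leftover no-op.) A minimal sketch of the wiring, with a stand-in echo handler rather than the Space's model code:

import gradio as gr

# Stand-in handler: each parameter after (message, history) lines up with one
# component in additional_inputs, in order.
def echo(message, history, temperature, top_p, max_new_tokens, log_prompt):
    return f"temp={temperature}, top_p={top_p}, max_new={max_new_tokens}, log={log_prompt}"

demo = gr.ChatInterface(
    fn=echo,
    additional_inputs=[
        gr.Slider(0, 1, value=0.6, label="Temperature"),
        gr.Slider(0, 1, value=0.9, step=0.05, label="Top p"),
        gr.Slider(8, 4096, value=1024, step=1, label="Max new tokens"),
        gr.Checkbox(value=True, label="Log prompt"),  # current Gradio spells the initial state value=, not default=
    ],
)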
@@ -88,7 +88,8 @@ def chat_joycaption(message: dict, history, temperature: float, max_new_tokens:
     image = Image.open(message["files"][0])
 
     # Log the prompt
-
+    if log_prompt:
+        print(f"Prompt: {prompt}")
 
     # Preprocess image
     # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
@@ -148,7 +149,7 @@ def chat_joycaption(message: dict, history, temperature: float, max_new_tokens:
         use_cache=True,
         temperature=temperature,
         top_k=None,
-        top_p=
+        top_p=top_p,
         streamer=streamer,
     )
 
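For reference, the generate call this hunk touches combines temperature scaling with nucleus sampling: with top_k=None no top-k truncation is applied, so the sampler keeps the smallest set of tokens whose cumulative probability reaches top_p. A self-contained sketch of the same streaming-generation pattern (the model name and prompt are placeholders, not the Space's):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("A photo of", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs in a worker thread while the
# streamer is consumed; this is the usual pattern behind streamer=streamer.
Thread(target=model.generate, kwargs=dict(
    **inputs,
    max_new_tokens=64,
    do_sample=True,        # temperature/top_p only apply when sampling
    temperature=0.6,
    top_k=None,            # disable top-k truncation
    top_p=0.9,             # nucleus cutoff, the value the new slider controls
    use_cache=True,
    streamer=streamer,
)).start()

for text in streamer:
    print(text, end="", flush=True)   # tokens arrive as they are generated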
@@ -170,14 +171,14 @@ textbox = gr.MultimodalTextbox(file_types=["image"], file_count="single")
 with gr.Blocks() as demo:
     gr.HTML(TITLE)
     gr.Markdown(DESCRIPTION)
-    gr.ChatInterface(
+    chat_interface = gr.ChatInterface(
         fn=chat_joycaption,
         chatbot=chatbot,
         type="messages",
         fill_height=True,
         multimodal=True,
         textbox=textbox,
-        additional_inputs_accordion=
+        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=True, render=True),
         additional_inputs=[
             gr.Slider(minimum=0,
                       maximum=1,
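The accordion argument explains the render flags scattered through this block: components constructed with render=False inside gr.Blocks are created but not placed in the layout, and ChatInterface renders its additional_inputs inside the accordion passed as additional_inputs_accordion. A stripped-down sketch of that layout pattern:

import gradio as gr

with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=lambda message, history, temperature: f"temperature={temperature}",
        # render=False keeps the slider out of the Blocks layout; ChatInterface
        # places it inside the accordion it was handed instead.
        additional_inputs_accordion=gr.Accordion(label="Parameters", open=True),
        additional_inputs=[
            gr.Slider(0, 1, value=0.6, label="Temperature", render=False),
        ],
    )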
@@ -185,23 +186,27 @@ with gr.Blocks() as demo:
                       value=0.6,
                       label="Temperature",
                       render=False),
-            gr.Slider(minimum=
+            gr.Slider(minimum=0,
+                      maximum=1,
+                      step=0.05,
+                      value=0.9,
+                      label="Top p",
+                      render=False),
+            gr.Slider(minimum=8,
                       maximum=4096,
                       step=1,
                       value=1024,
                       label="Max new tokens",
                       render=False ),
-
-
-            ['How to setup a human base on Mars? Give short answer.'],
-            ['Explain theory of relativity to me like I'm 8 years old.'],
-            ['What is 9,000 * 9,000?'],
-            ['Write a pun-filled happy birthday message to my friend Alex.'],
-            ['Justify why a penguin might make a good king of the jungle.']
-        ],
-        cache_examples=False,
+            gr.Checkbox(label="Help improve JoyCaption by logging your text query", default=True, render=True),
+        ],
     )
 
+    def new_trim_history(self, message, history_with_input):
+        return message, []
+
+    chat_interface._process_msg_and_trim_history = new_trim_history.__get__(chat_interface, chat_interface.__class__)
+
 
 if __name__ == "__main__":
     demo.launch()
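The closing hunk monkey-patches the interface so each turn arrives with an empty history: new_trim_history returns (message, []), and binding it with __get__ replaces the private _process_msg_and_trim_history method on this one instance only. That method name is a Gradio internal (note the leading underscore) and may differ across versions, which the commit title alludes to. The binding trick itself is plain Python, since functions are descriptors; a standalone illustration:

class Counter:
    def __init__(self):
        self.n = 0
    def step(self):
        self.n += 1

def double_step(self):
    self.n += 2

c = Counter()
# Functions are descriptors: __get__ produces a method bound to c, so only
# this instance's behavior changes.
c.step = double_step.__get__(c, Counter)
c.step()
print(c.n)        # 2
other = Counter()
other.step()
print(other.n)    # 1 -- the class and other instances are untouched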