fancyfeast committed on
Commit
2a3d557
·
1 Parent(s): f4d3067

Man the chatinterface is weird #3

Browse files
Files changed (1) hide show
  1. app.py +21 -16
app.py CHANGED
@@ -72,10 +72,10 @@ assert isinstance(end_of_header_id, int) and isinstance(end_of_turn_id, int)
72
 
73
  @spaces.GPU()
74
  @torch.no_grad()
75
- def chat_joycaption(message: dict, history, temperature: float, max_new_tokens: int) -> Generator[str, None, None]:
76
  torch.cuda.empty_cache()
77
 
78
- print(message)
79
 
80
  # Prompts are always stripped in training for now
81
  prompt = message['text'].strip()
@@ -88,7 +88,8 @@ def chat_joycaption(message: dict, history, temperature: float, max_new_tokens:
88
  image = Image.open(message["files"][0])
89
 
90
  # Log the prompt
91
- print(f"Prompt: {prompt}")
 
92
 
93
  # Preprocess image
94
  # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
@@ -148,7 +149,7 @@ def chat_joycaption(message: dict, history, temperature: float, max_new_tokens:
148
  use_cache=True,
149
  temperature=temperature,
150
  top_k=None,
151
- top_p=0.9,
152
  streamer=streamer,
153
  )
154
 
@@ -170,14 +171,14 @@ textbox = gr.MultimodalTextbox(file_types=["image"], file_count="single")
170
  with gr.Blocks() as demo:
171
  gr.HTML(TITLE)
172
  gr.Markdown(DESCRIPTION)
173
- gr.ChatInterface(
174
  fn=chat_joycaption,
175
  chatbot=chatbot,
176
  type="messages",
177
  fill_height=True,
178
  multimodal=True,
179
  textbox=textbox,
180
- additional_inputs_accordion=None,#gr.Accordion(label="⚙️ Parameters", open=False, render=False),
181
  additional_inputs=[
182
  gr.Slider(minimum=0,
183
  maximum=1,
@@ -185,23 +186,27 @@ with gr.Blocks() as demo:
185
  value=0.6,
186
  label="Temperature",
187
  render=False),
188
- gr.Slider(minimum=128,
 
 
 
 
 
 
189
  maximum=4096,
190
  step=1,
191
  value=1024,
192
  label="Max new tokens",
193
  render=False ),
194
- ],
195
- examples=[
196
- ['How to setup a human base on Mars? Give short answer.'],
197
- ['Explain theory of relativity to me like I’m 8 years old.'],
198
- ['What is 9,000 * 9,000?'],
199
- ['Write a pun-filled happy birthday message to my friend Alex.'],
200
- ['Justify why a penguin might make a good king of the jungle.']
201
- ],
202
- cache_examples=False,
203
  )
204
 
 
 
 
 
 
205
 
206
  if __name__ == "__main__":
207
  demo.launch()
 
72
 
73
  @spaces.GPU()
74
  @torch.no_grad()
75
+ def chat_joycaption(message: dict, history, temperature: float, top_p: float, max_new_tokens: int, log_prompt: bool) -> Generator[str, None, None]:
76
  torch.cuda.empty_cache()
77
 
78
+ chat_interface.chatbot_state
79
 
80
  # Prompts are always stripped in training for now
81
  prompt = message['text'].strip()
 
88
  image = Image.open(message["files"][0])
89
 
90
  # Log the prompt
91
+ if log_prompt:
92
+ print(f"Prompt: {prompt}")
93
 
94
  # Preprocess image
95
  # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
 
149
  use_cache=True,
150
  temperature=temperature,
151
  top_k=None,
152
+ top_p=top_p,
153
  streamer=streamer,
154
  )
155
 
 
171
  with gr.Blocks() as demo:
172
  gr.HTML(TITLE)
173
  gr.Markdown(DESCRIPTION)
174
+ chat_interface = gr.ChatInterface(
175
  fn=chat_joycaption,
176
  chatbot=chatbot,
177
  type="messages",
178
  fill_height=True,
179
  multimodal=True,
180
  textbox=textbox,
181
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=True, render=True),
182
  additional_inputs=[
183
  gr.Slider(minimum=0,
184
  maximum=1,
 
186
  value=0.6,
187
  label="Temperature",
188
  render=False),
189
+ gr.Slider(minimum=0,
190
+ maximum=1,
191
+ step=0.05,
192
+ value=0.9,
193
+ label="Top p",
194
+ render=False),
195
+ gr.Slider(minimum=8,
196
  maximum=4096,
197
  step=1,
198
  value=1024,
199
  label="Max new tokens",
200
  render=False ),
201
+ gr.Checkbox(label="Help improve JoyCaption by logging your text query", default=True, render=True),
202
+ ],
 
 
 
 
 
 
 
203
  )
204
 
205
+ def new_trim_history(self, message, history_with_input):
206
+ return message, []
207
+
208
+ chat_interface._process_msg_and_trim_history = new_trim_history.__get__(chat_interface, chat_interface.__class__)
209
+
210
 
211
  if __name__ == "__main__":
212
  demo.launch()