Add fix

- app.py +25 -16
- fromage/models.py +1 -0
app.py
CHANGED
@@ -27,7 +27,7 @@ class FromageChatBot:


     def upload_image(self, state, image_input):
-        state += [(f"", "
+        state += [(f"", "(Image received. Type or ask something to continue.)")]
         self.input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
         return state, state

@@ -42,7 +42,7 @@ class FromageChatBot:
     def generate_for_prompt(self, input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
         input_prompt = 'Q: ' + input_text + '\nA:'
         self.chat_history += input_prompt
-        print('Generating for', self.chat_history)
+        print('Generating for', self.chat_history, flush=True)

         # If an image was uploaded, prepend it to the model.
         model_inputs = None
@@ -57,6 +57,7 @@ class FromageChatBot:
         model_outputs = self.model.generate_for_images_and_texts(model_inputs,
             num_words=num_words, ret_scale_factor=ret_scale_factor, top_p=top_p,
             temperature=temperature, max_num_rets=max_nm_rets)
+        print('model_outputs', model_outputs, flush=True)

         im_names = []
         response = ''
@@ -73,6 +74,7 @@ class FromageChatBot:
                 filename = self.save_image_to_local(output)
                 response += f'<img src="/file={filename}">'

+        # TODO(jykoh): Persist image inputs.
         self.chat_history += ' '.join(text_outputs)
         if self.chat_history[-1] != '\n':
             self.chat_history += '\n'
@@ -88,26 +90,33 @@ class FromageChatBot:
             '### Grounding Language Models to Images for Multimodal Generation'
         )

-
-
-            max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
-            gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
-            gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
+            chatbot = gr.Chatbot()
+            gr_state = gr.State([])

-            with gr.
-
-
-
-
-
+            with gr.Row():
+                with gr.Column(scale=0.3, min_width=0):
+                    ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
+                    max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
+                    gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
+                    gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
+
+                with gr.Column(scale=0.7, min_width=0):
+                    image_btn = gr.UploadButton("Image Input", file_types=["image"])
+                    text_input = gr.Textbox(label="Text Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
+                    clear_btn = gr.Button("Clear History")

             text_input.submit(self.generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
-            text_input.submit(lambda :"", None, text_input)
             image_btn.upload(self.upload_image, [gr_state, image_btn], [gr_state, chatbot])
             clear_btn.click(self.reset, [], [gr_state, chatbot])

         demo.launch(share=False, server_name="0.0.0.0")


-
-chatbot
+def main():
+    chatbot = FromageChatBot()
+    chatbot.launch()
+
+
+if __name__ == "__main__":
+    chatbot = FromageChatBot()
+    chatbot.launch()
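For context, the restructured `launch()` body above follows the standard Gradio Blocks pattern: a `Chatbot` plus a `State` list for chat history, a `Row` split into two `Column`s (generation settings on the left, inputs on the right), and events bound with `.submit()`, `.upload()`, and `.click()`. The sketch below reproduces only that layout against stub callbacks so it runs on its own; the callback names (`respond`, `receive_image`, `reset_history`) and the integer column scales are illustrative, not part of the app, and it assumes Gradio 3.x behavior, where an `UploadButton` handler receives a temp-file object whose `.name` is the uploaded file's path (which is why the app calls `Image.open(image_input.name)`).

```python
# Minimal, self-contained sketch of the same two-column gr.Blocks layout.
# Stub callbacks only: respond, receive_image, and reset_history are
# illustrative stand-ins for the app's real FromageChatBot methods.
import gradio as gr
from PIL import Image


def respond(text, state):
    # Append a (user, bot) pair to the chat history and echo it back.
    state = state + [(text, f"(stub reply to: {text})")]
    return state, state


def receive_image(state, image_file):
    # Gradio 3.x hands UploadButton callbacks a temp-file object; .name is its path.
    _ = Image.open(image_file.name).resize((224, 224)).convert('RGB')
    state = state + [("", "(Image received. Type or ask something to continue.)")]
    return state, state


def reset_history():
    return [], []


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    gr_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1, min_width=0):
            ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1,
                                         interactive=True, label="Multiplier for returning images")
            gr_max_len = gr.Number(value=32, precision=1, interactive=True,
                                   label="Max # of words returned")
        with gr.Column(scale=2, min_width=0):
            image_btn = gr.UploadButton("Image Input", file_types=["image"])
            text_input = gr.Textbox(label="Text Input", lines=1,
                                    placeholder="Upload an image [optional], then type and press enter.")
            clear_btn = gr.Button("Clear History")

    # Wire events exactly as the app does: handlers return (state, state) so the
    # same list updates both the State and the visible Chatbot.
    text_input.submit(respond, [text_input, gr_state], [gr_state, chatbot])
    image_btn.upload(receive_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_btn.click(reset_history, [], [gr_state, chatbot])

if __name__ == "__main__":
    demo.launch(share=False, server_name="0.0.0.0")
```

The `__main__` guard here mirrors the one this commit adds at the bottom of app.py, which constructs `FromageChatBot()` and calls its `launch()` when the Space runs `python app.py`.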
fromage/models.py
CHANGED
@@ -651,6 +651,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
     emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
     emb_matrix = logit_scale * emb_matrix
     model.emb_matrix = emb_matrix
+    print('Done loading FROMAGe!')

     return model

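The one-line change to `load_fromage` only adds a completion message, but checkpoint loading is typically the slow part of a Space's startup, so the print makes it easier to tell a hung build from a model that is still loading. A hedged sketch of how the demo presumably calls it at startup; the import follows the repo layout (fromage/models.py), the three-argument signature is taken from the hunk header above, and the paths are placeholders rather than the Space's actual asset names:

```python
# Hedged sketch: loading FROMAGe once at startup, as the demo presumably does.
# The three paths below are placeholders, not the Space's real files.
from fromage import models

embeddings_dir = './fromage_model/embeddings.pkl'     # placeholder
model_args_path = './fromage_model/model_args.json'   # placeholder
model_ckpt_path = './fromage_model/ckpt.pth.tar'      # placeholder

model = models.load_fromage(embeddings_dir, model_args_path, model_ckpt_path)
# After this commit, 'Done loading FROMAGe!' is printed at the end of the call,
# so the Space logs show clearly when startup has gotten past the model load.
```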