Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -212,6 +212,7 @@ def generate_30(
|
|
212 |
)
|
213 |
text_input_ids2 = text_inputs2.input_ids
|
214 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
|
|
215 |
pooled_prompt_embeds_list.append(prompt_embedsa[0])
|
216 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
217 |
prompt_embeds_list.append(prompt_embedsa)
|
@@ -220,6 +221,7 @@ def generate_30(
|
|
220 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
221 |
prompt_embeds_list.append(prompt_embedsb)
|
222 |
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
|
|
|
223 |
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
|
224 |
|
225 |
options = {
|
@@ -287,6 +289,7 @@ def generate_60(
|
|
287 |
)
|
288 |
text_input_ids2 = text_inputs2.input_ids
|
289 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
|
|
290 |
pooled_prompt_embeds_list.append(prompt_embedsa[0])
|
291 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
292 |
prompt_embeds_list.append(prompt_embedsa)
|
@@ -295,6 +298,7 @@ def generate_60(
|
|
295 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
296 |
prompt_embeds_list.append(prompt_embedsb)
|
297 |
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
|
|
|
298 |
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
|
299 |
|
300 |
options = {
|
@@ -362,6 +366,7 @@ def generate_90(
|
|
362 |
)
|
363 |
text_input_ids2 = text_inputs2.input_ids
|
364 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
|
|
365 |
pooled_prompt_embeds_list.append(prompt_embedsa[0])
|
366 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
367 |
prompt_embeds_list.append(prompt_embedsa)
|
@@ -370,6 +375,7 @@ def generate_90(
|
|
370 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
371 |
prompt_embeds_list.append(prompt_embedsb)
|
372 |
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
|
|
|
373 |
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
|
374 |
|
375 |
options = {
|
|
|
212 |
)
|
213 |
text_input_ids2 = text_inputs2.input_ids
|
214 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
215 |
+
print('text_encoder shape: ',prompt_embedsa.shape)
|
216 |
pooled_prompt_embeds_list.append(prompt_embedsa[0])
|
217 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
218 |
prompt_embeds_list.append(prompt_embedsa)
|
|
|
221 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
222 |
prompt_embeds_list.append(prompt_embedsb)
|
223 |
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
|
224 |
+
print('catted shape: ',prompt_embeds.shape)
|
225 |
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
|
226 |
|
227 |
options = {
|
|
|
289 |
)
|
290 |
text_input_ids2 = text_inputs2.input_ids
|
291 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
292 |
+
print('text_encoder shape: ',prompt_embedsa.shape)
|
293 |
pooled_prompt_embeds_list.append(prompt_embedsa[0])
|
294 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
295 |
prompt_embeds_list.append(prompt_embedsa)
|
|
|
298 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
299 |
prompt_embeds_list.append(prompt_embedsb)
|
300 |
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
|
301 |
+
print('catted shape: ',prompt_embeds.shape)
|
302 |
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
|
303 |
|
304 |
options = {
|
|
|
366 |
)
|
367 |
text_input_ids2 = text_inputs2.input_ids
|
368 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
369 |
+
print('text_encoder shape: ',prompt_embedsa.shape)
|
370 |
pooled_prompt_embeds_list.append(prompt_embedsa[0])
|
371 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
372 |
prompt_embeds_list.append(prompt_embedsa)
|
|
|
375 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
376 |
prompt_embeds_list.append(prompt_embedsb)
|
377 |
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
|
378 |
+
print('catted shape: ',prompt_embeds.shape)
|
379 |
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
|
380 |
|
381 |
options = {
|