1inkusFace commited on
Commit
6c95eb1
·
verified ·
1 Parent(s): dc2becd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -212,6 +212,7 @@ def generate_30(
212
  )
213
  text_input_ids2 = text_inputs2.input_ids
214
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
215
  pooled_prompt_embeds_list.append(prompt_embedsa[0])
216
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
217
  prompt_embeds_list.append(prompt_embedsa)
@@ -220,6 +221,7 @@ def generate_30(
220
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
221
  prompt_embeds_list.append(prompt_embedsb)
222
  prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
 
223
  pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
224
 
225
  options = {
@@ -287,6 +289,7 @@ def generate_60(
287
  )
288
  text_input_ids2 = text_inputs2.input_ids
289
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
290
  pooled_prompt_embeds_list.append(prompt_embedsa[0])
291
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
292
  prompt_embeds_list.append(prompt_embedsa)
@@ -295,6 +298,7 @@ def generate_60(
295
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
296
  prompt_embeds_list.append(prompt_embedsb)
297
  prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
 
298
  pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
299
 
300
  options = {
@@ -362,6 +366,7 @@ def generate_90(
362
  )
363
  text_input_ids2 = text_inputs2.input_ids
364
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
365
  pooled_prompt_embeds_list.append(prompt_embedsa[0])
366
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
367
  prompt_embeds_list.append(prompt_embedsa)
@@ -370,6 +375,7 @@ def generate_90(
370
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
371
  prompt_embeds_list.append(prompt_embedsb)
372
  prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
 
373
  pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
374
 
375
  options = {
 
212
  )
213
  text_input_ids2 = text_inputs2.input_ids
214
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
215
+ print('text_encoder shape: ',prompt_embedsa.shape)
216
  pooled_prompt_embeds_list.append(prompt_embedsa[0])
217
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
218
  prompt_embeds_list.append(prompt_embedsa)
 
221
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
222
  prompt_embeds_list.append(prompt_embedsb)
223
  prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
224
+ print('catted shape: ',prompt_embeds.shape)
225
  pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
226
 
227
  options = {
 
289
  )
290
  text_input_ids2 = text_inputs2.input_ids
291
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
292
+ print('text_encoder shape: ',prompt_embedsa.shape)
293
  pooled_prompt_embeds_list.append(prompt_embedsa[0])
294
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
295
  prompt_embeds_list.append(prompt_embedsa)
 
298
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
299
  prompt_embeds_list.append(prompt_embedsb)
300
  prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
301
+ print('catted shape: ',prompt_embeds.shape)
302
  pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
303
 
304
  options = {
 
366
  )
367
  text_input_ids2 = text_inputs2.input_ids
368
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
369
+ print('text_encoder shape: ',prompt_embedsa.shape)
370
  pooled_prompt_embeds_list.append(prompt_embedsa[0])
371
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
372
  prompt_embeds_list.append(prompt_embedsa)
 
375
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
376
  prompt_embeds_list.append(prompt_embedsb)
377
  prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=1, keepdim=True)
378
+ print('catted shape: ',prompt_embeds.shape)
379
  pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=1, keepdim=True)
380
 
381
  options = {