1inkusFace commited on
Commit
bc6ba86
·
verified ·
1 Parent(s): 9f1950c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -258,7 +258,7 @@ def generate_30(
258
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
259
  print('catted shape: ', prompt_embeds.shape)
260
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
261
- print('catted pooled shape: ', prompt_embeds.shape)
262
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
263
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
264
 
@@ -271,7 +271,7 @@ def generate_30(
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
- pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0,keepdim=True)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
@@ -367,6 +367,7 @@ def generate_60(
367
  # 2. Encode with the two text encoders
368
  prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
369
  pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
 
370
  prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
371
  print('encoder shape: ', prompt_embeds_a.shape)
372
  prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
@@ -375,7 +376,7 @@ def generate_60(
375
 
376
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
377
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
378
- print('pooled shape: ', pooled_prompt_embeds_a2.shape)
379
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
380
  print('encoder shape2: ', prompt_embeds_a2.shape)
381
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
@@ -386,7 +387,7 @@ def generate_60(
386
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
387
  print('catted shape: ', prompt_embeds.shape)
388
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
389
- print('catted pooled shape: ', prompt_embeds.shape)
390
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
391
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
392
 
@@ -495,6 +496,7 @@ def generate_90(
495
  # 2. Encode with the two text encoders
496
  prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
497
  pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
 
498
  prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
499
  print('encoder shape: ', prompt_embeds_a.shape)
500
  prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
@@ -503,7 +505,7 @@ def generate_90(
503
 
504
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
505
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
506
- print('pooled shape: ', pooled_prompt_embeds_a2.shape)
507
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
508
  print('encoder shape2: ', prompt_embeds_a2.shape)
509
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
@@ -514,7 +516,7 @@ def generate_90(
514
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
515
  print('catted shape: ', prompt_embeds.shape)
516
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
517
- print('catted pooled shape: ', prompt_embeds.shape)
518
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
519
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
520
 
 
258
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
259
  print('catted shape: ', prompt_embeds.shape)
260
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
261
+ print('catted pooled shape: ', pooled_prompt_embeds.shape)
262
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
263
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
264
 
 
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
  print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
275
  print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
277
  print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
 
367
  # 2. Encode with the two text encoders
368
  prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
369
  pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
370
+ print('pooled shape 1: ', pooled_prompt_embeds_a.shape)
371
  prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
372
  print('encoder shape: ', prompt_embeds_a.shape)
373
  prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
 
376
 
377
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
378
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
379
+ print('pooled shape 2: ', pooled_prompt_embeds_a2.shape)
380
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
381
  print('encoder shape2: ', prompt_embeds_a2.shape)
382
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
 
387
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
388
  print('catted shape: ', prompt_embeds.shape)
389
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
390
+ print('catted pooled shape: ', pooled_prompt_embeds.shape)
391
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
392
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
393
 
 
496
  # 2. Encode with the two text encoders
497
  prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
498
  pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
499
+ print('pooled shape 1: ', pooled_prompt_embeds_a.shape)
500
  prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
501
  print('encoder shape: ', prompt_embeds_a.shape)
502
  prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
 
505
 
506
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
507
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
508
+ print('pooled shape 2: ', pooled_prompt_embeds_a2.shape)
509
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
510
  print('encoder shape2: ', prompt_embeds_a2.shape)
511
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
 
516
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
517
  print('catted shape: ', prompt_embeds.shape)
518
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
519
+ print('catted pooled shape: ', pooled_prompt_embeds.shape)
520
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
521
  print('meaned pooled shape: ', pooled_prompt_embeds.shape)
522