1inkusFace committed on
Commit
7b87899
·
verified ·
1 Parent(s): 91bf97e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -238,6 +238,7 @@ def generate_30(
238
  # 2. Encode with the two text encoders
239
  prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
240
  pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
 
241
  prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
242
  print('encoder shape: ', prompt_embeds_a.shape)
243
  prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
@@ -246,7 +247,7 @@ def generate_30(
246
 
247
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
248
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
249
- print('pooled shape: ', pooled_prompt_embeds_a2.shape)
250
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
251
  print('encoder shape2: ', prompt_embeds_a2.shape)
252
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
@@ -258,6 +259,8 @@ def generate_30(
258
  print('catted shape: ', prompt_embeds.shape)
259
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
260
  print('catted pooled shape: ', prompt_embeds.shape)
 
 
261
 
262
  # 4. (Optional) Average the pooled embeddings
263
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
@@ -267,13 +270,12 @@ def generate_30(
267
  prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
268
  print('catted shape2: ', prompt_embeds2.shape)
269
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
270
- print('catted pooled shape 2: ', prompt_embeds.shape)
 
 
271
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
272
- print('catted combined pooled shape: ', prompt_embeds.shape)
273
- # 4. (Optional) Average the pooled embeddings
274
- pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
275
- print('pooled averaged shape: ', pooled_prompt_embeds.shape)
276
-
277
  options = {
278
  #"prompt": prompt,
279
  "prompt_embeds": prompt_embeds,
@@ -385,6 +387,8 @@ def generate_60(
385
  print('catted shape: ', prompt_embeds.shape)
386
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
387
  print('catted pooled shape: ', prompt_embeds.shape)
 
 
388
 
389
  # 4. (Optional) Average the pooled embeddings
390
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
@@ -394,13 +398,12 @@ def generate_60(
394
  prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
395
  print('catted shape2: ', prompt_embeds2.shape)
396
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
397
- print('catted pooled shape 2: ', prompt_embeds.shape)
 
 
398
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
399
- print('catted combined pooled shape: ', prompt_embeds.shape)
400
- # 4. (Optional) Average the pooled embeddings
401
- pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
402
- print('pooled averaged shape: ', pooled_prompt_embeds.shape)
403
-
404
  options = {
405
  #"prompt": prompt,
406
  "prompt_embeds": prompt_embeds,
@@ -512,6 +515,8 @@ def generate_90(
512
  print('catted shape: ', prompt_embeds.shape)
513
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
514
  print('catted pooled shape: ', prompt_embeds.shape)
 
 
515
 
516
  # 4. (Optional) Average the pooled embeddings
517
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
@@ -521,13 +526,12 @@ def generate_90(
521
  prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
522
  print('catted shape2: ', prompt_embeds2.shape)
523
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
524
- print('catted pooled shape 2: ', prompt_embeds.shape)
 
 
525
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
526
- print('catted combined pooled shape: ', prompt_embeds.shape)
527
- # 4. (Optional) Average the pooled embeddings
528
- pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
529
- print('pooled averaged shape: ', pooled_prompt_embeds.shape)
530
-
531
  options = {
532
  #"prompt": prompt,
533
  "prompt_embeds": prompt_embeds,
 
238
  # 2. Encode with the two text encoders
239
  prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
240
  pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
241
+ print('pooled shape 1: ', pooled_prompt_embeds_a.shape)
242
  prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
243
  print('encoder shape: ', prompt_embeds_a.shape)
244
  prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
 
247
 
248
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
249
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
250
+ print('pooled shape 2: ', pooled_prompt_embeds_a2.shape)
251
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
252
  print('encoder shape2: ', prompt_embeds_a2.shape)
253
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
 
259
  print('catted shape: ', prompt_embeds.shape)
260
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
261
  print('catted pooled shape: ', prompt_embeds.shape)
262
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
263
+ print('meaned pooled shape: ', pooled_prompt_embeds.shape)
264
 
265
  # 4. (Optional) Average the pooled embeddings
266
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
 
270
  prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
271
  print('catted shape2: ', prompt_embeds2.shape)
272
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
273
+ print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
274
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
275
+ print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
276
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
277
+ print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
278
+
 
 
 
279
  options = {
280
  #"prompt": prompt,
281
  "prompt_embeds": prompt_embeds,
 
387
  print('catted shape: ', prompt_embeds.shape)
388
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
389
  print('catted pooled shape: ', prompt_embeds.shape)
390
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
391
+ print('meaned pooled shape: ', pooled_prompt_embeds.shape)
392
 
393
  # 4. (Optional) Average the pooled embeddings
394
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
 
398
  prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
399
  print('catted shape2: ', prompt_embeds2.shape)
400
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
401
+ print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
402
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
403
+ print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
404
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
405
+ print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
406
+
 
 
 
407
  options = {
408
  #"prompt": prompt,
409
  "prompt_embeds": prompt_embeds,
 
515
  print('catted shape: ', prompt_embeds.shape)
516
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
517
  print('catted pooled shape: ', prompt_embeds.shape)
518
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
519
+ print('meaned pooled shape: ', pooled_prompt_embeds.shape)
520
 
521
  # 4. (Optional) Average the pooled embeddings
522
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
 
526
  prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
527
  print('catted shape2: ', prompt_embeds2.shape)
528
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
529
+ print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
530
+ pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
531
+ print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
532
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
533
+ print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
534
+
 
 
 
535
  options = {
536
  #"prompt": prompt,
537
  "prompt_embeds": prompt_embeds,