1inkusFace commited on
Commit
8a35e1b
·
verified ·
1 Parent(s): f69d021

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -256,15 +256,17 @@ def generate_30(
256
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
257
  print('catted shape: ', prompt_embeds.shape)
258
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
 
 
259
  # 4. (Optional) Average the pooled embeddings
260
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
261
  print('averaged shape: ', prompt_embeds.shape)
262
 
263
  # 3. Concatenate the text_encoder_2 embeddings
264
- prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
265
- print('catted shape2: ', prompt_embeds.shape)
266
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
 
267
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
 
268
  # 4. (Optional) Average the pooled embeddings
269
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
270
  print('pooled averaged shape: ', pooled_prompt_embeds.shape)
@@ -379,15 +381,17 @@ def generate_60(
379
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
380
  print('catted shape: ', prompt_embeds.shape)
381
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
 
 
382
  # 4. (Optional) Average the pooled embeddings
383
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
384
  print('averaged shape: ', prompt_embeds.shape)
385
 
386
  # 3. Concatenate the text_encoder_2 embeddings
387
- prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
388
- print('catted shape2: ', prompt_embeds.shape)
389
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
 
390
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
 
391
  # 4. (Optional) Average the pooled embeddings
392
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
393
  print('pooled averaged shape: ', pooled_prompt_embeds.shape)
@@ -502,15 +506,17 @@ def generate_90(
502
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
503
  print('catted shape: ', prompt_embeds.shape)
504
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
 
 
505
  # 4. (Optional) Average the pooled embeddings
506
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
507
  print('averaged shape: ', prompt_embeds.shape)
508
 
509
  # 3. Concatenate the text_encoder_2 embeddings
510
- prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
511
- print('catted shape2: ', prompt_embeds.shape)
512
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
 
513
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
 
514
  # 4. (Optional) Average the pooled embeddings
515
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
516
  print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
256
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
257
  print('catted shape: ', prompt_embeds.shape)
258
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
259
+ print('pooled shape: ', prompt_embeds.shape)
260
+
261
  # 4. (Optional) Average the pooled embeddings
262
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
263
  print('averaged shape: ', prompt_embeds.shape)
264
 
265
  # 3. Concatenate the text_encoder_2 embeddings
 
 
266
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
267
+ print('catted pooled shape: ', prompt_embeds.shape)
268
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
269
+ print('catted combined pooled shape: ', prompt_embeds.shape)
270
  # 4. (Optional) Average the pooled embeddings
271
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
272
  print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
381
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
382
  print('catted shape: ', prompt_embeds.shape)
383
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
384
+ print('pooled shape: ', prompt_embeds.shape)
385
+
386
  # 4. (Optional) Average the pooled embeddings
387
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
388
  print('averaged shape: ', prompt_embeds.shape)
389
 
390
  # 3. Concatenate the text_encoder_2 embeddings
 
 
391
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
392
+ print('catted pooled shape: ', prompt_embeds.shape)
393
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
394
+ print('catted combined pooled shape: ', prompt_embeds.shape)
395
  # 4. (Optional) Average the pooled embeddings
396
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
397
  print('pooled averaged shape: ', pooled_prompt_embeds.shape)
 
506
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
507
  print('catted shape: ', prompt_embeds.shape)
508
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
509
+ print('pooled shape: ', prompt_embeds.shape)
510
+
511
  # 4. (Optional) Average the pooled embeddings
512
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
513
  print('averaged shape: ', prompt_embeds.shape)
514
 
515
  # 3. Concatenate the text_encoder_2 embeddings
 
 
516
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
517
+ print('catted pooled shape: ', prompt_embeds.shape)
518
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
519
+ print('catted combined pooled shape: ', prompt_embeds.shape)
520
  # 4. (Optional) Average the pooled embeddings
521
  pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
522
  print('pooled averaged shape: ', pooled_prompt_embeds.shape)