1inkusFace committed on
Commit
91bf97e
·
verified ·
1 Parent(s): 8a35e1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -246,6 +246,7 @@ def generate_30(
246
 
247
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
248
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
 
249
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
250
  print('encoder shape2: ', prompt_embeds_a2.shape)
251
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
@@ -256,15 +257,17 @@ def generate_30(
256
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
257
  print('catted shape: ', prompt_embeds.shape)
258
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
259
- print('pooled shape: ', prompt_embeds.shape)
260
 
261
  # 4. (Optional) Average the pooled embeddings
262
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
263
  print('averaged shape: ', prompt_embeds.shape)
264
 
265
  # 3. Concatenate the text_encoder_2 embeddings
 
 
266
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
267
- print('catted pooled shape: ', prompt_embeds.shape)
268
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
269
  print('catted combined pooled shape: ', prompt_embeds.shape)
270
  # 4. (Optional) Average the pooled embeddings
@@ -370,26 +373,28 @@ def generate_60(
370
 
371
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
372
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
 
373
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
374
- print('encoder shape: ', prompt_embeds_a2.shape)
375
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
376
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
377
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
378
 
379
-
380
  # 3. Concatenate the embeddings
381
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
382
  print('catted shape: ', prompt_embeds.shape)
383
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
384
- print('pooled shape: ', prompt_embeds.shape)
385
 
386
  # 4. (Optional) Average the pooled embeddings
387
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
388
  print('averaged shape: ', prompt_embeds.shape)
389
 
390
  # 3. Concatenate the text_encoder_2 embeddings
 
 
391
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
392
- print('catted pooled shape: ', prompt_embeds.shape)
393
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
394
  print('catted combined pooled shape: ', prompt_embeds.shape)
395
  # 4. (Optional) Average the pooled embeddings
@@ -495,26 +500,28 @@ def generate_90(
495
 
496
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
497
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
 
498
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
499
- print('encoder shape: ', prompt_embeds_a2.shape)
500
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
501
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
502
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
503
 
504
-
505
  # 3. Concatenate the embeddings
506
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
507
  print('catted shape: ', prompt_embeds.shape)
508
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
509
- print('pooled shape: ', prompt_embeds.shape)
510
 
511
  # 4. (Optional) Average the pooled embeddings
512
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
513
  print('averaged shape: ', prompt_embeds.shape)
514
 
515
  # 3. Concatenate the text_encoder_2 embeddings
 
 
516
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
517
- print('catted pooled shape: ', prompt_embeds.shape)
518
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
519
  print('catted combined pooled shape: ', prompt_embeds.shape)
520
  # 4. (Optional) Average the pooled embeddings
 
246
 
247
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
248
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
249
+ print('pooled shape: ', pooled_prompt_embeds_a2.shape)
250
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
251
  print('encoder shape2: ', prompt_embeds_a2.shape)
252
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
 
257
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
258
  print('catted shape: ', prompt_embeds.shape)
259
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
260
+ print('catted pooled shape: ', prompt_embeds.shape)
261
 
262
  # 4. (Optional) Average the pooled embeddings
263
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
264
  print('averaged shape: ', prompt_embeds.shape)
265
 
266
  # 3. Concatenate the text_encoder_2 embeddings
267
+ prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
268
+ print('catted shape2: ', prompt_embeds2.shape)
269
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
270
+ print('catted pooled shape 2: ', prompt_embeds.shape)
271
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
272
  print('catted combined pooled shape: ', prompt_embeds.shape)
273
  # 4. (Optional) Average the pooled embeddings
 
373
 
374
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
375
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
376
+ print('pooled shape: ', pooled_prompt_embeds_a2.shape)
377
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
378
+ print('encoder shape2: ', prompt_embeds_a2.shape)
379
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
380
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
381
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
382
 
 
383
  # 3. Concatenate the embeddings
384
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
385
  print('catted shape: ', prompt_embeds.shape)
386
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
387
+ print('catted pooled shape: ', prompt_embeds.shape)
388
 
389
  # 4. (Optional) Average the pooled embeddings
390
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
391
  print('averaged shape: ', prompt_embeds.shape)
392
 
393
  # 3. Concatenate the text_encoder_2 embeddings
394
+ prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
395
+ print('catted shape2: ', prompt_embeds2.shape)
396
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
397
+ print('catted pooled shape 2: ', prompt_embeds.shape)
398
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
399
  print('catted combined pooled shape: ', prompt_embeds.shape)
400
  # 4. (Optional) Average the pooled embeddings
 
500
 
501
  prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
502
  pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
503
+ print('pooled shape: ', pooled_prompt_embeds_a2.shape)
504
  prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
505
+ print('encoder shape2: ', prompt_embeds_a2.shape)
506
  prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
507
  pooled_prompt_embeds_b2 = prompt_embeds_b2[0] # Pooled output from encoder 2
508
  prompt_embeds_b2 = prompt_embeds_b2.hidden_states[-2] # Penultimate hidden state from encoder 2
509
 
 
510
  # 3. Concatenate the embeddings
511
  prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
512
  print('catted shape: ', prompt_embeds.shape)
513
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
514
+ print('catted pooled shape: ', prompt_embeds.shape)
515
 
516
  # 4. (Optional) Average the pooled embeddings
517
  prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
518
  print('averaged shape: ', prompt_embeds.shape)
519
 
520
  # 3. Concatenate the text_encoder_2 embeddings
521
+ prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
522
+ print('catted shape2: ', prompt_embeds2.shape)
523
  pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
524
+ print('catted pooled shape 2: ', prompt_embeds.shape)
525
  pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
526
  print('catted combined pooled shape: ', prompt_embeds.shape)
527
  # 4. (Optional) Average the pooled embeddings