Update app.py
app.py CHANGED
@@ -258,7 +258,7 @@ def generate_30(
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
     print('catted shape: ', prompt_embeds.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
-    print('catted pooled shape: ',
+    print('catted pooled shape: ', pooled_prompt_embeds.shape)
     pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
     print('meaned pooled shape: ', pooled_prompt_embeds.shape)

@@ -271,7 +271,7 @@ def generate_30(
     print('catted shape2: ', prompt_embeds2.shape)
     pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
     print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
-    pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0
+    pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
     print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
     print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)

@@ -367,6 +367,7 @@ def generate_60(
     # 2. Encode with the two text encoders
     prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_a = prompt_embeds_a[0]  # Pooled output from encoder 1
+    print('pooled shape 1: ', pooled_prompt_embeds_a.shape)
     prompt_embeds_a = prompt_embeds_a.hidden_states[-2]  # Penultimate hidden state from encoder 1
     print('encoder shape: ', prompt_embeds_a.shape)
     prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)

@@ -375,7 +376,7 @@ def generate_60(

     prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_a2 = prompt_embeds_a2[0]  # Pooled output from encoder 1
-    print('pooled shape: ', pooled_prompt_embeds_a2.shape)
+    print('pooled shape 2: ', pooled_prompt_embeds_a2.shape)
     prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2]  # Penultimate hidden state from encoder 1
     print('encoder shape2: ', prompt_embeds_a2.shape)
     prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)

@@ -386,7 +387,7 @@ def generate_60(
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
     print('catted shape: ', prompt_embeds.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
-    print('catted pooled shape: ',
+    print('catted pooled shape: ', pooled_prompt_embeds.shape)
     pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
     print('meaned pooled shape: ', pooled_prompt_embeds.shape)

@@ -495,6 +496,7 @@ def generate_90(
     # 2. Encode with the two text encoders
     prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_a = prompt_embeds_a[0]  # Pooled output from encoder 1
+    print('pooled shape 1: ', pooled_prompt_embeds_a.shape)
     prompt_embeds_a = prompt_embeds_a.hidden_states[-2]  # Penultimate hidden state from encoder 1
     print('encoder shape: ', prompt_embeds_a.shape)
     prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)

@@ -503,7 +505,7 @@ def generate_90(

     prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_a2 = prompt_embeds_a2[0]  # Pooled output from encoder 1
-    print('pooled shape: ', pooled_prompt_embeds_a2.shape)
+    print('pooled shape 2: ', pooled_prompt_embeds_a2.shape)
     prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2]  # Penultimate hidden state from encoder 1
     print('encoder shape2: ', prompt_embeds_a2.shape)
     prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)

@@ -514,7 +516,7 @@ def generate_90(
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
     print('catted shape: ', prompt_embeds.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
-    print('catted pooled shape: ',
+    print('catted pooled shape: ', pooled_prompt_embeds.shape)
     pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
     print('meaned pooled shape: ', pooled_prompt_embeds.shape)

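The commit repairs two syntax errors in generate_30 (a print call that had lost its second argument and closing parenthesis, and a torch.mean call missing its closing parenthesis) and adds matching debug prints to generate_60 and generate_90. The tensor pattern these prints trace, concatenating two prompts' embeddings and averaging the pooled vectors, can be sketched in isolation. Below is a minimal, runnable sketch with dummy tensors; the sizes (77 tokens, 1280 dims) are assumed SDXL-like values, not taken from the Space:

import torch

# Stand-ins for the penultimate hidden states of two encoded prompts;
# (batch=1, seq=77, dim=1280) are assumed SDXL-like sizes.
prompt_embeds_a = torch.randn(1, 77, 1280)
prompt_embeds_b = torch.randn(1, 77, 1280)

# Stand-ins for the corresponding pooled outputs.
pooled_prompt_embeds_a = torch.randn(1, 1280)
pooled_prompt_embeds_b = torch.randn(1, 1280)

# Concatenate along dim 0, as in the patched code above.
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
print('catted shape: ', prompt_embeds.shape)                # torch.Size([2, 77, 1280])

pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
print('catted pooled shape: ', pooled_prompt_embeds.shape)  # torch.Size([2, 1280])

# Average the two pooled vectors into one.
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds, dim=0)
print('meaned pooled shape: ', pooled_prompt_embeds.shape)  # torch.Size([1280])

Note that torch.mean(..., dim=0) without keepdim=True also drops the leading batch dimension (the final shape is [1280], not [1, 1280]), which is exactly what the 'meaned pooled shape' debug print makes visible.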
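For reference, the pipe.text_encoder(...)[0] / .hidden_states[-2] pattern in the hunks above pulls a pooled embedding and the penultimate hidden state out of a CLIP text encoder. Here is a self-contained sketch of the same calls, assuming transformers' CLIPTextModelWithProjection (where index 0 is the projected pooled output) and the openai/clip-vit-large-patch14 checkpoint as a stand-in for the Space's actual encoders, which this diff does not identify:

import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

# Assumed stand-in checkpoint; the Space's real text encoders may differ.
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
encoder = CLIPTextModelWithProjection.from_pretrained('openai/clip-vit-large-patch14')

text_input_ids = tokenizer(
    'a photo of a cat',
    padding='max_length',
    max_length=tokenizer.model_max_length,  # 77 for CLIP
    truncation=True,
    return_tensors='pt',
).input_ids

with torch.no_grad():
    out = encoder(text_input_ids, output_hidden_states=True)

pooled_prompt_embeds = out[0]          # projected pooled output, shape (1, 768)
prompt_embeds = out.hidden_states[-2]  # penultimate hidden state, shape (1, 77, 768)
print('pooled shape: ', pooled_prompt_embeds.shape)
print('encoder shape: ', prompt_embeds.shape)

Taking hidden_states[-2] rather than the final layer follows SDXL's convention of conditioning the UNet on the text encoders' penultimate-layer features.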