Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -238,6 +238,7 @@ def generate_30(
|
|
238 |
# 2. Encode with the two text encoders
|
239 |
prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
|
240 |
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
|
|
241 |
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
242 |
print('encoder shape: ', prompt_embeds_a.shape)
|
243 |
prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
|
@@ -246,7 +247,7 @@ def generate_30(
|
|
246 |
|
247 |
prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
|
248 |
pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
|
249 |
-
print('pooled shape: ', pooled_prompt_embeds_a2.shape)
|
250 |
prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
|
251 |
print('encoder shape2: ', prompt_embeds_a2.shape)
|
252 |
prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
|
@@ -258,6 +259,8 @@ def generate_30(
|
|
258 |
print('catted shape: ', prompt_embeds.shape)
|
259 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
260 |
print('catted pooled shape: ', prompt_embeds.shape)
|
|
|
|
|
261 |
|
262 |
# 4. (Optional) Average the pooled embeddings
|
263 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
@@ -267,13 +270,12 @@ def generate_30(
|
|
267 |
prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
268 |
print('catted shape2: ', prompt_embeds2.shape)
|
269 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
270 |
-
print('catted pooled shape 2: ',
|
|
|
|
|
271 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
272 |
-
print('catted combined pooled shape: ',
|
273 |
-
|
274 |
-
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
275 |
-
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
276 |
-
|
277 |
options = {
|
278 |
#"prompt": prompt,
|
279 |
"prompt_embeds": prompt_embeds,
|
@@ -385,6 +387,8 @@ def generate_60(
|
|
385 |
print('catted shape: ', prompt_embeds.shape)
|
386 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
387 |
print('catted pooled shape: ', prompt_embeds.shape)
|
|
|
|
|
388 |
|
389 |
# 4. (Optional) Average the pooled embeddings
|
390 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
@@ -394,13 +398,12 @@ def generate_60(
|
|
394 |
prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
395 |
print('catted shape2: ', prompt_embeds2.shape)
|
396 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
397 |
-
print('catted pooled shape 2: ',
|
|
|
|
|
398 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
399 |
-
print('catted combined pooled shape: ',
|
400 |
-
|
401 |
-
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
402 |
-
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
403 |
-
|
404 |
options = {
|
405 |
#"prompt": prompt,
|
406 |
"prompt_embeds": prompt_embeds,
|
@@ -512,6 +515,8 @@ def generate_90(
|
|
512 |
print('catted shape: ', prompt_embeds.shape)
|
513 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
514 |
print('catted pooled shape: ', prompt_embeds.shape)
|
|
|
|
|
515 |
|
516 |
# 4. (Optional) Average the pooled embeddings
|
517 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
@@ -521,13 +526,12 @@ def generate_90(
|
|
521 |
prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
522 |
print('catted shape2: ', prompt_embeds2.shape)
|
523 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
524 |
-
print('catted pooled shape 2: ',
|
|
|
|
|
525 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
526 |
-
print('catted combined pooled shape: ',
|
527 |
-
|
528 |
-
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
529 |
-
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
530 |
-
|
531 |
options = {
|
532 |
#"prompt": prompt,
|
533 |
"prompt_embeds": prompt_embeds,
|
|
|
238 |
# 2. Encode with the two text encoders
|
239 |
prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
|
240 |
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
241 |
+
print('pooled shape 1: ', pooled_prompt_embeds_a.shape)
|
242 |
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
243 |
print('encoder shape: ', prompt_embeds_a.shape)
|
244 |
prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
|
|
|
247 |
|
248 |
prompt_embeds_a2 = pipe.text_encoder_2(text_input_ids1b.to(torch.device('cuda')), output_hidden_states=True)
|
249 |
pooled_prompt_embeds_a2 = prompt_embeds_a2[0] # Pooled output from encoder 1
|
250 |
+
print('pooled shape 2: ', pooled_prompt_embeds_a2.shape)
|
251 |
prompt_embeds_a2 = prompt_embeds_a2.hidden_states[-2] # Penultimate hidden state from encoder 1
|
252 |
print('encoder shape2: ', prompt_embeds_a2.shape)
|
253 |
prompt_embeds_b2 = pipe.text_encoder_2(text_input_ids2b.to(torch.device('cuda')), output_hidden_states=True)
|
|
|
259 |
print('catted shape: ', prompt_embeds.shape)
|
260 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
261 |
print('catted pooled shape: ', prompt_embeds.shape)
|
262 |
+
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
263 |
+
print('meaned pooled shape: ', pooled_prompt_embeds.shape)
|
264 |
|
265 |
# 4. (Optional) Average the pooled embeddings
|
266 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
|
|
270 |
prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
271 |
print('catted shape2: ', prompt_embeds2.shape)
|
272 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
273 |
+
print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
|
274 |
+
pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
|
275 |
+
print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
|
276 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
277 |
+
print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
|
278 |
+
|
|
|
|
|
|
|
279 |
options = {
|
280 |
#"prompt": prompt,
|
281 |
"prompt_embeds": prompt_embeds,
|
|
|
387 |
print('catted shape: ', prompt_embeds.shape)
|
388 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
389 |
print('catted pooled shape: ', prompt_embeds.shape)
|
390 |
+
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
391 |
+
print('meaned pooled shape: ', pooled_prompt_embeds.shape)
|
392 |
|
393 |
# 4. (Optional) Average the pooled embeddings
|
394 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
|
|
398 |
prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
399 |
print('catted shape2: ', prompt_embeds2.shape)
|
400 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
401 |
+
print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
|
402 |
+
pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
|
403 |
+
print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
|
404 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
405 |
+
print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
|
406 |
+
|
|
|
|
|
|
|
407 |
options = {
|
408 |
#"prompt": prompt,
|
409 |
"prompt_embeds": prompt_embeds,
|
|
|
515 |
print('catted shape: ', prompt_embeds.shape)
|
516 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
517 |
print('catted pooled shape: ', prompt_embeds.shape)
|
518 |
+
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
519 |
+
print('meaned pooled shape: ', pooled_prompt_embeds.shape)
|
520 |
|
521 |
# 4. (Optional) Average the pooled embeddings
|
522 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
|
|
526 |
prompt_embeds2 = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
527 |
print('catted shape2: ', prompt_embeds2.shape)
|
528 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
529 |
+
print('catted pooled shape 2: ', pooled_prompt_embeds2.shape)
|
530 |
+
pooled_prompt_embeds2 = torch.mean(pooled_prompt_embeds2,dim=0)
|
531 |
+
print('pooled meaned shape 2: ', pooled_prompt_embeds2.shape)
|
532 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
533 |
+
print('catted combined meaned pooled shape: ', pooled_prompt_embeds.shape)
|
534 |
+
|
|
|
|
|
|
|
535 |
options = {
|
536 |
#"prompt": prompt,
|
537 |
"prompt_embeds": prompt_embeds,
|