Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -256,15 +256,17 @@ def generate_30(
|
|
256 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
257 |
print('catted shape: ', prompt_embeds.shape)
|
258 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
|
|
|
|
259 |
# 4. (Optional) Average the pooled embeddings
|
260 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
261 |
print('averaged shape: ', prompt_embeds.shape)
|
262 |
|
263 |
# 3. Concatenate the text_encoder_2 embeddings
|
264 |
-
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
265 |
-
print('catted shape2: ', prompt_embeds.shape)
|
266 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
|
|
267 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
|
|
268 |
# 4. (Optional) Average the pooled embeddings
|
269 |
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
270 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
@@ -379,15 +381,17 @@ def generate_60(
|
|
379 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
380 |
print('catted shape: ', prompt_embeds.shape)
|
381 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
|
|
|
|
382 |
# 4. (Optional) Average the pooled embeddings
|
383 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
384 |
print('averaged shape: ', prompt_embeds.shape)
|
385 |
|
386 |
# 3. Concatenate the text_encoder_2 embeddings
|
387 |
-
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
388 |
-
print('catted shape2: ', prompt_embeds.shape)
|
389 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
|
|
390 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
|
|
391 |
# 4. (Optional) Average the pooled embeddings
|
392 |
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
393 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
@@ -502,15 +506,17 @@ def generate_90(
|
|
502 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
503 |
print('catted shape: ', prompt_embeds.shape)
|
504 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
|
|
|
|
505 |
# 4. (Optional) Average the pooled embeddings
|
506 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
507 |
print('averaged shape: ', prompt_embeds.shape)
|
508 |
|
509 |
# 3. Concatenate the text_encoder_2 embeddings
|
510 |
-
prompt_embeds2 = torch.cat([prompt_embeds_a2, prompt_embeds_b2])
|
511 |
-
print('catted shape2: ', prompt_embeds.shape)
|
512 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
|
|
513 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
|
|
514 |
# 4. (Optional) Average the pooled embeddings
|
515 |
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
516 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
|
|
256 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
257 |
print('catted shape: ', prompt_embeds.shape)
|
258 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
259 |
+
print('pooled shape: ', prompt_embeds.shape)
|
260 |
+
|
261 |
# 4. (Optional) Average the pooled embeddings
|
262 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
263 |
print('averaged shape: ', prompt_embeds.shape)
|
264 |
|
265 |
# 3. Concatenate the text_encoder_2 embeddings
|
|
|
|
|
266 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
267 |
+
print('catted pooled shape: ', prompt_embeds.shape)
|
268 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
269 |
+
print('catted combined pooled shape: ', prompt_embeds.shape)
|
270 |
# 4. (Optional) Average the pooled embeddings
|
271 |
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
272 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
|
|
381 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
382 |
print('catted shape: ', prompt_embeds.shape)
|
383 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
384 |
+
print('pooled shape: ', prompt_embeds.shape)
|
385 |
+
|
386 |
# 4. (Optional) Average the pooled embeddings
|
387 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
388 |
print('averaged shape: ', prompt_embeds.shape)
|
389 |
|
390 |
# 3. Concatenate the text_encoder_2 embeddings
|
|
|
|
|
391 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
392 |
+
print('catted pooled shape: ', prompt_embeds.shape)
|
393 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
394 |
+
print('catted combined pooled shape: ', prompt_embeds.shape)
|
395 |
# 4. (Optional) Average the pooled embeddings
|
396 |
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
397 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|
|
|
506 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
|
507 |
print('catted shape: ', prompt_embeds.shape)
|
508 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
|
509 |
+
print('pooled shape: ', prompt_embeds.shape)
|
510 |
+
|
511 |
# 4. (Optional) Average the pooled embeddings
|
512 |
prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
|
513 |
print('averaged shape: ', prompt_embeds.shape)
|
514 |
|
515 |
# 3. Concatenate the text_encoder_2 embeddings
|
|
|
|
|
516 |
pooled_prompt_embeds2 = torch.cat([pooled_prompt_embeds_a2, pooled_prompt_embeds_b2])
|
517 |
+
print('catted pooled shape: ', prompt_embeds.shape)
|
518 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds2])
|
519 |
+
print('catted combined pooled shape: ', prompt_embeds.shape)
|
520 |
# 4. (Optional) Average the pooled embeddings
|
521 |
pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0)
|
522 |
print('pooled averaged shape: ', pooled_prompt_embeds.shape)
|