Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -232,9 +232,9 @@ def generate_30(
|
|
232 |
test_prompt_embeds = prompt_embeds.mean(dim=0,keepdim=True)
|
233 |
print('averaged shape (keepdim): ', test_prompt_embeds.shape)
|
234 |
|
235 |
-
test_prompt_embeds_2 = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=
|
236 |
print('averaged shape 2: ', test_prompt_embeds_2.shape)
|
237 |
-
test_prompt_embeds_3 = torch.cat([prompt_embeds_a, prompt_embeds_b]).mean(dim=
|
238 |
print('averaged shape 3(keepdim): ', test_prompt_embeds_3.shape)
|
239 |
|
240 |
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=0)
|
|
|
232 |
test_prompt_embeds = prompt_embeds.mean(dim=0,keepdim=True)
|
233 |
print('averaged shape (keepdim): ', test_prompt_embeds.shape)
|
234 |
|
235 |
+
test_prompt_embeds_2 = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=0).mean(dim=1)
|
236 |
print('averaged shape 2: ', test_prompt_embeds_2.shape)
|
237 |
+
test_prompt_embeds_3 = torch.cat([prompt_embeds_a, prompt_embeds_b]).mean(dim=0,keepdim=True)
|
238 |
print('averaged shape 3(keepdim): ', test_prompt_embeds_3.shape)
|
239 |
|
240 |
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=0)
|