1inkusFace commited on
Commit
855f65a
·
verified ·
1 Parent(s): 2b2f3b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -44
app.py CHANGED
@@ -221,24 +221,14 @@ def generate_30(
221
  pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
222
  prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
223
 
224
- # 3. Concatenate the embeddings along the sequence dimension (dim=1)
225
- prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
226
  print('catted shape: ', prompt_embeds.shape)
227
- pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
228
-
229
  # 4. (Optional) Average the pooled embeddings
230
- prompt_embeds = prompt_embeds.mean(dim=0)
231
  print('averaged shape: ', prompt_embeds.shape)
232
- test_prompt_embeds = prompt_embeds.mean(dim=0,keepdim=True)
233
- print('averaged shape (keepdim): ', test_prompt_embeds.shape)
234
-
235
- test_prompt_embeds_2 = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=0).mean(dim=1)
236
- print('averaged shape 2: ', test_prompt_embeds_2.shape)
237
- test_prompt_embeds_3 = torch.cat([prompt_embeds_a, prompt_embeds_b]).mean(dim=0,keepdim=True)
238
- print('averaged shape 3(keepdim): ', test_prompt_embeds_3.shape)
239
-
240
- pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=0)
241
-
242
 
243
  options = {
244
  #"prompt": prompt,
@@ -314,23 +304,14 @@ def generate_60(
314
  pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
315
  prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
316
 
317
- # 3. Concatenate the embeddings along the sequence dimension (dim=1)
318
- prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
319
  print('catted shape: ', prompt_embeds.shape)
320
- pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
321
-
322
  # 4. (Optional) Average the pooled embeddings
323
- prompt_embeds = prompt_embeds.mean(dim=0)
324
  print('averaged shape: ', prompt_embeds.shape)
325
- test_prompt_embeds = prompt_embeds.mean(dim=0,keepdim=True)
326
- print('averaged shape (keepdim): ', prompt_embeds.shape)
327
-
328
- test_prompt_embeds_2 = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1).mean(dim=1)
329
- print('averaged shape 2: ', test_prompt_embeds_2.shape)
330
- test_prompt_embeds_3 = torch.cat([prompt_embeds_a, prompt_embeds_b]).mean(dim=1,keepdim=True)
331
- print('averaged shape 3(keepdim): ', test_prompt_embeds_3.shape)
332
-
333
- pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=0)
334
 
335
 
336
  options = {
@@ -407,24 +388,14 @@ def generate_90(
407
  pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
408
  prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
409
 
410
- # 3. Concatenate the embeddings along the sequence dimension (dim=1)
411
- prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
412
  print('catted shape: ', prompt_embeds.shape)
413
- pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
414
-
415
  # 4. (Optional) Average the pooled embeddings
416
- prompt_embeds = prompt_embeds.mean(dim=0)
417
  print('averaged shape: ', prompt_embeds.shape)
418
- test_prompt_embeds = prompt_embeds.mean(dim=0,keepdim=True)
419
- print('averaged shape (keepdim): ', prompt_embeds.shape)
420
-
421
- test_prompt_embeds_2 = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1).mean(dim=1)
422
- print('averaged shape 2: ', test_prompt_embeds_2.shape)
423
- test_prompt_embeds_3 = torch.cat([prompt_embeds_a, prompt_embeds_b]).mean(dim=1,keepdim=True)
424
- print('averaged shape 3(keepdim): ', test_prompt_embeds_3.shape)
425
-
426
- pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=0)
427
-
428
 
429
  options = {
430
  #"prompt": prompt,
 
221
  pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
222
  prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
223
 
224
+ # 3. Concatenate the embeddings
225
+ prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
226
  print('catted shape: ', prompt_embeds.shape)
227
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
 
228
  # 4. (Optional) Average the pooled embeddings
229
+ prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
230
  print('averaged shape: ', prompt_embeds.shape)
231
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0,keepdim=True)
 
 
 
 
 
 
 
 
 
232
 
233
  options = {
234
  #"prompt": prompt,
 
304
  pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
305
  prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
306
 
307
+ # 3. Concatenate the embeddings
308
+ prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
309
  print('catted shape: ', prompt_embeds.shape)
310
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
 
311
  # 4. (Optional) Average the pooled embeddings
312
+ prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
313
  print('averaged shape: ', prompt_embeds.shape)
314
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0,keepdim=True)
 
 
 
 
 
 
 
 
315
 
316
 
317
  options = {
 
388
  pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
389
  prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
390
 
391
+ # 3. Concatenate the embeddings
392
+ prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b])
393
  print('catted shape: ', prompt_embeds.shape)
394
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b])
 
395
  # 4. (Optional) Average the pooled embeddings
396
+ prompt_embeds = torch.mean(prompt_embeds,dim=0,keepdim=True)
397
  print('averaged shape: ', prompt_embeds.shape)
398
+ pooled_prompt_embeds = torch.mean(pooled_prompt_embeds,dim=0,keepdim=True)
 
 
 
 
 
 
 
 
 
399
 
400
  options = {
401
  #"prompt": prompt,