1inkusFace committed
Commit 0e750cc (verified) · Parent: 34f3ac2

Update app.py

Files changed (1):
  app.py (+8, -4)
app.py CHANGED
@@ -301,17 +301,19 @@ def generate_60(
     prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
     prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
-
+    print('encoder shape: ', prompt_embeds_a.shape)
     prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
     prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
 
     # 3. Concatenate the embeddings along the sequence dimension (dim=1)
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
+    print('catted shape: ', prompt_embeds.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
-
+
     # 4. (Optional) Average the pooled embeddings
     prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
+    print('averaged shape: ', prompt_embeds.shape)
     pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
 
 
@@ -384,17 +386,19 @@ def generate_90(
     prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
     prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
-
+    print('encoder shape: ', prompt_embeds_a.shape)
     prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
     pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
     prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
 
     # 3. Concatenate the embeddings along the sequence dimension (dim=1)
     prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
+    print('catted shape: ', prompt_embeds.shape)
     pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
-
+
     # 4. (Optional) Average the pooled embeddings
     prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
+    print('averaged shape: ', prompt_embeds.shape)
     pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
 
 
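The added print() calls trace tensor shapes through the concatenate/average steps, identically in generate_60 and generate_90. For reference, below is a minimal, self-contained sketch (not code from app.py) that reproduces the same cat/mean sequence with dummy tensors, so the output of the new logging can be anticipated. The sizes are assumptions, not values confirmed by the diff: batch 1, 77 tokens, hidden size 768 (CLIP-style), and the pooled outputs are taken at face value as 2-D [batch, hidden] tensors, as the code comments describe.

import torch

# Assumed CLIP-style sizes; the real values depend on pipe.text_encoder.
B, T, D = 1, 77, 768

# Stand-ins for the penultimate hidden states of encoders 1 and 2.
prompt_embeds_a = torch.randn(B, T, D)
prompt_embeds_b = torch.randn(B, T, D)
# Stand-ins for the pooled outputs (assumed [batch, hidden]).
pooled_prompt_embeds_a = torch.randn(B, D)
pooled_prompt_embeds_b = torch.randn(B, D)

print('encoder shape: ', prompt_embeds_a.shape)             # torch.Size([1, 77, 768])

# Step 3: concatenating along dim=1 doubles the sequence length ...
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
print('catted shape: ', prompt_embeds.shape)                # torch.Size([1, 154, 768])

# ... while for the 2-D pooled tensors, dim=1 is the feature dimension.
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
print('pooled catted shape: ', pooled_prompt_embeds.shape)  # torch.Size([1, 1536])

# Step 4: the mean over dim=1 collapses the sequence axis to length 1 ...
prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
print('averaged shape: ', prompt_embeds.shape)              # torch.Size([1, 1, 768])

# ... and reduces each pooled vector to a single value per batch item.
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
print('pooled averaged shape: ', pooled_prompt_embeds.shape)  # torch.Size([1, 1])

Under these assumptions the three new prints would report [1, 77, 768], [1, 154, 768] and [1, 1, 768]: the mean over dim=1 collapses the concatenated sequence back to a single token-like vector, which is presumably what this debug logging is meant to confirm.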