1inkusFace committed · verified
Commit 34f3ac2 · Parent(s): 91ac098

Update app.py

Files changed (1):
  1. app.py +4 -2
app.py CHANGED
@@ -216,17 +216,19 @@ def generate_30(
 prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
 pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
 prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
-
+ print('encoder shape: ', prompt_embeds_a.shape)
 prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
 pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
 prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
 
 # 3. Concatenate the embeddings along the sequence dimension (dim=1)
 prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
+ print('catted shape: ', prompt_embeds.shape)
 pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
-
+
 # 4. (Optional) Average the pooled embeddings
 prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
+ print('averaged shape: ', prompt_embeds.shape)
 pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
 
 
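For anyone following the shape debugging this commit adds: below is a minimal, standalone sketch of the tensor bookkeeping the new print() calls inspect, using random stand-in tensors instead of real encoder outputs. The batch size, the 77-token sequence length, and the 768 hidden size are assumptions (typical CLIP text-encoder dimensions), not values taken from app.py.

# Minimal sketch of the shape arithmetic in the patched block above.
# Random stand-ins; dimensions are assumptions, not values from app.py.
import torch

batch, seq_len, hidden = 1, 77, 768

# Stand-ins for the two encoder passes: penultimate hidden states and pooled outputs.
prompt_embeds_a = torch.randn(batch, seq_len, hidden)   # like hidden_states[-2]
prompt_embeds_b = torch.randn(batch, seq_len, hidden)
pooled_prompt_embeds_a = torch.randn(batch, hidden)     # like output[0]
pooled_prompt_embeds_b = torch.randn(batch, hidden)

print('encoder shape: ', prompt_embeds_a.shape)         # (1, 77, 768)

# Concatenating along dim=1 doubles the sequence length of the token embeddings
# and doubles the feature size of the pooled embeddings.
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
print('catted shape: ', prompt_embeds.shape)            # (1, 154, 768)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
print('pooled catted shape: ', pooled_prompt_embeds.shape)  # (1, 1536)

# Averaging over dim=1 with keepdim=True collapses the concatenated sequence
# axis of prompt_embeds to length 1, and the pooled tensor to a single feature.
prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
print('averaged shape: ', prompt_embeds.shape)          # (1, 1, 768)
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
print('averaged pooled shape: ', pooled_prompt_embeds.shape)  # (1, 1)

Under those assumed dimensions, the concatenation grows the sequence axis to 154 and the pooled feature axis to 1536, and the subsequent mean over dim=1 collapses each to length 1, which is exactly what the added print statements would surface.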