Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -216,17 +216,19 @@ def generate_30(
|
|
216 |
prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
|
217 |
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
218 |
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
219 |
-
|
220 |
prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
|
221 |
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
222 |
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
223 |
|
224 |
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
225 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
|
|
226 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
227 |
-
|
228 |
# 4. (Optional) Average the pooled embeddings
|
229 |
prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
|
|
|
230 |
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
231 |
|
232 |
|
|
|
216 |
prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
|
217 |
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
218 |
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
219 |
+
print('encoder shape: ', prompt_embeds_a.shape)
|
220 |
prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
|
221 |
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
222 |
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
223 |
|
224 |
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
225 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
226 |
+
print('catted shape: ', prompt_embeds.shape)
|
227 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
228 |
+
|
229 |
# 4. (Optional) Average the pooled embeddings
|
230 |
prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
|
231 |
+
print('averaged shape: ', prompt_embeds.shape)
|
232 |
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
233 |
|
234 |
|