Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -301,17 +301,19 @@ def generate_60(
|
|
301 |
prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
|
302 |
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
303 |
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
304 |
-
|
305 |
prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
|
306 |
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
307 |
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
308 |
|
309 |
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
310 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
|
|
311 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
312 |
-
|
313 |
# 4. (Optional) Average the pooled embeddings
|
314 |
prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
|
|
|
315 |
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
316 |
|
317 |
|
@@ -384,17 +386,19 @@ def generate_90(
|
|
384 |
prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
|
385 |
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
386 |
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
387 |
-
|
388 |
prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
|
389 |
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
390 |
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
391 |
|
392 |
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
393 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
|
|
394 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
395 |
-
|
396 |
# 4. (Optional) Average the pooled embeddings
|
397 |
prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
|
|
|
398 |
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
399 |
|
400 |
|
|
|
301 |
prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
|
302 |
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
303 |
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
304 |
+
print('encoder shape: ', prompt_embeds_a.shape)
|
305 |
prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
|
306 |
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
307 |
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
308 |
|
309 |
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
310 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
311 |
+
print('catted shape: ', prompt_embeds.shape)
|
312 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
313 |
+
|
314 |
# 4. (Optional) Average the pooled embeddings
|
315 |
prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
|
316 |
+
print('averaged shape: ', prompt_embeds.shape)
|
317 |
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
318 |
|
319 |
|
|
|
386 |
prompt_embeds_a = pipe.text_encoder(text_input_ids1.to(torch.device('cuda')), output_hidden_states=True)
|
387 |
pooled_prompt_embeds_a = prompt_embeds_a[0] # Pooled output from encoder 1
|
388 |
prompt_embeds_a = prompt_embeds_a.hidden_states[-2] # Penultimate hidden state from encoder 1
|
389 |
+
print('encoder shape: ', prompt_embeds_a.shape)
|
390 |
prompt_embeds_b = pipe.text_encoder(text_input_ids2.to(torch.device('cuda')), output_hidden_states=True)
|
391 |
pooled_prompt_embeds_b = prompt_embeds_b[0] # Pooled output from encoder 2
|
392 |
prompt_embeds_b = prompt_embeds_b.hidden_states[-2] # Penultimate hidden state from encoder 2
|
393 |
|
394 |
# 3. Concatenate the embeddings along the sequence dimension (dim=1)
|
395 |
prompt_embeds = torch.cat([prompt_embeds_a, prompt_embeds_b], dim=1)
|
396 |
+
print('catted shape: ', prompt_embeds.shape)
|
397 |
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_a, pooled_prompt_embeds_b], dim=1)
|
398 |
+
|
399 |
# 4. (Optional) Average the pooled embeddings
|
400 |
prompt_embeds = prompt_embeds.mean(dim=1, keepdim=True)
|
401 |
+
print('averaged shape: ', prompt_embeds.shape)
|
402 |
pooled_prompt_embeds = pooled_prompt_embeds.mean(dim=1, keepdim=True)
|
403 |
|
404 |
|