Spaces: Runtime error

Update app.py

app.py CHANGED
@@ -193,6 +193,8 @@ def generate_30(
     pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
     pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
 
+    pooled_prompt_embeds_list=[]
+    prompt_embeds_list=[]
     text_inputs1 = pipe.tokenizer(
         prompt,
         padding="max_length",
@@ -210,11 +212,15 @@ def generate_30(
     )
     text_input_ids2 = text_inputs2.input_ids
     prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
+    pooled_prompt_embeds_list.append(prompt_embedsa)
     prompt_embedsa = prompt_embedsa.hidden_states[-2]
+    prompt_embeds_list.append(prompt_embedsa[0])
     prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
+    pooled_prompt_embeds_list.append(prompt_embedsb[0])
     prompt_embedsb = prompt_embedsb.hidden_states[-2]
-
-
+    prompt_embeds_list.append(prompt_embedsb)
+    prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
+    pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
 
     options = {
         #"prompt": prompt,
@@ -262,6 +268,8 @@ def generate_60(
     pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
     pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
 
+    pooled_prompt_embeds_list=[]
+    prompt_embeds_list=[]
     text_inputs1 = pipe.tokenizer(
         prompt,
         padding="max_length",
@@ -279,11 +287,15 @@ def generate_60(
     )
     text_input_ids2 = text_inputs2.input_ids
     prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
+    pooled_prompt_embeds_list.append(prompt_embedsa)
     prompt_embedsa = prompt_embedsa.hidden_states[-2]
+    prompt_embeds_list.append(prompt_embedsa[0])
     prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
+    pooled_prompt_embeds_list.append(prompt_embedsb[0])
     prompt_embedsb = prompt_embedsb.hidden_states[-2]
-
-
+    prompt_embeds_list.append(prompt_embedsb)
+    prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
+    pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
 
     options = {
         #"prompt": prompt,
@@ -331,6 +343,8 @@ def generate_90(
     pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
     pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
 
+    pooled_prompt_embeds_list=[]
+    prompt_embeds_list=[]
     text_inputs1 = pipe.tokenizer(
         prompt,
         padding="max_length",
@@ -348,11 +362,15 @@ def generate_90(
     )
     text_input_ids2 = text_inputs2.input_ids
     prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
+    pooled_prompt_embeds_list.append(prompt_embedsa)
     prompt_embedsa = prompt_embedsa.hidden_states[-2]
+    prompt_embeds_list.append(prompt_embedsa[0])
     prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
+    pooled_prompt_embeds_list.append(prompt_embedsb[0])
     prompt_embedsb = prompt_embedsb.hidden_states[-2]
-
-
+    prompt_embeds_list.append(prompt_embedsb)
+    prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
+    pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
 
     options = {
         #"prompt": prompt,
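
For context, the commit collects both encoder passes into lists and then averages them. The usual SDXL convention in diffusers' StableDiffusionXLPipeline.encode_prompt is different: each encoder's penultimate hidden states are concatenated along the feature dimension, and the pooled embedding is taken from the second (projection) encoder. The sketch below illustrates that convention; the helper name encode_sdxl_prompt and its signature are invented for illustration and are not code from this Space.

# A minimal sketch of the usual dual-encoder combination for SDXL,
# following diffusers' StableDiffusionXLPipeline.encode_prompt.
# `pipe`, `prompt`, and `device` mirror the names used in the diff above;
# the helper itself is hypothetical.
import torch

def encode_sdxl_prompt(pipe, prompt, device):
    prompt_embeds_list = []
    pooled_prompt_embeds = None
    for tokenizer, text_encoder in zip(
        [pipe.tokenizer, pipe.tokenizer_2],
        [pipe.text_encoder, pipe.text_encoder_2],
    ):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        output = text_encoder(text_inputs.input_ids.to(device), output_hidden_states=True)
        # After the loop this holds text_encoder_2's projected (pooled) output,
        # since the second encoder is a CLIPTextModelWithProjection in SDXL.
        pooled_prompt_embeds = output[0]
        # Per-token embeddings come from the penultimate hidden layer,
        # matching the hidden_states[-2] indexing in the diff.
        prompt_embeds_list.append(output.hidden_states[-2])
    # Concatenate along the feature axis (e.g. 768 + 1280 = 2048 for SDXL),
    # preserving the sequence length for the UNet's cross-attention.
    prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
    return prompt_embeds, pooled_prompt_embeds

By contrast, the commit's torch.cat(prompt_embeds_list).mean(dim=-1) concatenates along the batch axis and then averages away the feature axis, so the result no longer has the hidden size the UNet expects; the pooled list also mixes a full encoder output object (prompt_embedsa) with a tensor (prompt_embedsb[0]), which torch.cat cannot combine, and as far as the visible hunks show both passes run through pipe.text_encoder rather than one through pipe.text_encoder_2. Any of these may relate to the "Runtime error" status shown above.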