Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -193,6 +193,8 @@ def generate_30(
|
|
193 |
pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
|
194 |
pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
|
195 |
|
|
|
|
|
196 |
text_inputs1 = pipe.tokenizer(
|
197 |
prompt,
|
198 |
padding="max_length",
|
@@ -210,11 +212,15 @@ def generate_30(
|
|
210 |
)
|
211 |
text_input_ids2 = text_inputs2.input_ids
|
212 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
|
|
213 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
|
|
214 |
prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
|
|
|
215 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
216 |
-
|
217 |
-
|
|
|
218 |
|
219 |
options = {
|
220 |
#"prompt": prompt,
|
@@ -262,6 +268,8 @@ def generate_60(
|
|
262 |
pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
|
263 |
pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
|
264 |
|
|
|
|
|
265 |
text_inputs1 = pipe.tokenizer(
|
266 |
prompt,
|
267 |
padding="max_length",
|
@@ -279,11 +287,15 @@ def generate_60(
|
|
279 |
)
|
280 |
text_input_ids2 = text_inputs2.input_ids
|
281 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
|
|
282 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
|
|
283 |
prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
|
|
|
284 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
285 |
-
|
286 |
-
|
|
|
287 |
|
288 |
options = {
|
289 |
#"prompt": prompt,
|
@@ -331,6 +343,8 @@ def generate_90(
|
|
331 |
pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
|
332 |
pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
|
333 |
|
|
|
|
|
334 |
text_inputs1 = pipe.tokenizer(
|
335 |
prompt,
|
336 |
padding="max_length",
|
@@ -348,11 +362,15 @@ def generate_90(
|
|
348 |
)
|
349 |
text_input_ids2 = text_inputs2.input_ids
|
350 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
|
|
351 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
|
|
352 |
prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
|
|
|
353 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
354 |
-
|
355 |
-
|
|
|
356 |
|
357 |
options = {
|
358 |
#"prompt": prompt,
|
|
|
193 |
pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
|
194 |
pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
|
195 |
|
196 |
+
pooled_prompt_embeds_list=[]
|
197 |
+
prompt_embeds_list=[]
|
198 |
text_inputs1 = pipe.tokenizer(
|
199 |
prompt,
|
200 |
padding="max_length",
|
|
|
212 |
)
|
213 |
text_input_ids2 = text_inputs2.input_ids
|
214 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
215 |
+
pooled_prompt_embeds_list.append(prompt_embedsa)
|
216 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
217 |
+
prompt_embeds_list.append(prompt_embedsa[0])
|
218 |
prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
|
219 |
+
pooled_prompt_embeds_list.append(prompt_embedsb[0])
|
220 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
221 |
+
prompt_embeds_list.append(prompt_embedsb)
|
222 |
+
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
|
223 |
+
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
|
224 |
|
225 |
options = {
|
226 |
#"prompt": prompt,
|
|
|
268 |
pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
|
269 |
pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
|
270 |
|
271 |
+
pooled_prompt_embeds_list=[]
|
272 |
+
prompt_embeds_list=[]
|
273 |
text_inputs1 = pipe.tokenizer(
|
274 |
prompt,
|
275 |
padding="max_length",
|
|
|
287 |
)
|
288 |
text_input_ids2 = text_inputs2.input_ids
|
289 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
290 |
+
pooled_prompt_embeds_list.append(prompt_embedsa)
|
291 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
292 |
+
prompt_embeds_list.append(prompt_embedsa[0])
|
293 |
prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
|
294 |
+
pooled_prompt_embeds_list.append(prompt_embedsb[0])
|
295 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
296 |
+
prompt_embeds_list.append(prompt_embedsb)
|
297 |
+
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
|
298 |
+
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
|
299 |
|
300 |
options = {
|
301 |
#"prompt": prompt,
|
|
|
343 |
pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
|
344 |
pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
|
345 |
|
346 |
+
pooled_prompt_embeds_list=[]
|
347 |
+
prompt_embeds_list=[]
|
348 |
text_inputs1 = pipe.tokenizer(
|
349 |
prompt,
|
350 |
padding="max_length",
|
|
|
362 |
)
|
363 |
text_input_ids2 = text_inputs2.input_ids
|
364 |
prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
|
365 |
+
pooled_prompt_embeds_list.append(prompt_embedsa)
|
366 |
prompt_embedsa = prompt_embedsa.hidden_states[-2]
|
367 |
+
prompt_embeds_list.append(prompt_embedsa[0])
|
368 |
prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
|
369 |
+
pooled_prompt_embeds_list.append(prompt_embedsb[0])
|
370 |
prompt_embedsb = prompt_embedsb.hidden_states[-2]
|
371 |
+
prompt_embeds_list.append(prompt_embedsb)
|
372 |
+
prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
|
373 |
+
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
|
374 |
|
375 |
options = {
|
376 |
#"prompt": prompt,
|