1inkusFace commited on
Commit
cd83b88
·
verified ·
1 Parent(s): 3a8e322

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -6
app.py CHANGED
@@ -193,6 +193,8 @@ def generate_30(
193
  pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
194
  pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
195
 
 
 
196
  text_inputs1 = pipe.tokenizer(
197
  prompt,
198
  padding="max_length",
@@ -210,11 +212,15 @@ def generate_30(
210
  )
211
  text_input_ids2 = text_inputs2.input_ids
212
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
213
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
 
214
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
 
215
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
216
- prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
217
- pooled_prompt_embeds = prompt_embeds[0]
 
218
 
219
  options = {
220
  #"prompt": prompt,
@@ -262,6 +268,8 @@ def generate_60(
262
  pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
263
  pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
264
 
 
 
265
  text_inputs1 = pipe.tokenizer(
266
  prompt,
267
  padding="max_length",
@@ -279,11 +287,15 @@ def generate_60(
279
  )
280
  text_input_ids2 = text_inputs2.input_ids
281
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
282
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
 
283
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
 
284
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
285
- prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
286
- pooled_prompt_embeds = prompt_embeds[0]
 
287
 
288
  options = {
289
  #"prompt": prompt,
@@ -331,6 +343,8 @@ def generate_90(
331
  pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
332
  pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
333
 
 
 
334
  text_inputs1 = pipe.tokenizer(
335
  prompt,
336
  padding="max_length",
@@ -348,11 +362,15 @@ def generate_90(
348
  )
349
  text_input_ids2 = text_inputs2.input_ids
350
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
 
351
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
 
352
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
 
353
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
354
- prompt_embeds = torch.cat([prompt_embedsa,prompt_embedsb]).mean(dim=-1)
355
- pooled_prompt_embeds = prompt_embeds[0]
 
356
 
357
  options = {
358
  #"prompt": prompt,
 
193
  pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
194
  pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
195
 
196
+ pooled_prompt_embeds_list=[]
197
+ prompt_embeds_list=[]
198
  text_inputs1 = pipe.tokenizer(
199
  prompt,
200
  padding="max_length",
 
212
  )
213
  text_input_ids2 = text_inputs2.input_ids
214
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
215
+ pooled_prompt_embeds_list.append(prompt_embedsa)
216
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
217
+ prompt_embeds_list.append(prompt_embedsa[0])
218
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
219
+ pooled_prompt_embeds_list.append(prompt_embedsb[0])
220
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
221
+ prompt_embeds_list.append(prompt_embedsb)
222
+ prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
223
+ pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
224
 
225
  options = {
226
  #"prompt": prompt,
 
268
  pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
269
  pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
270
 
271
+ pooled_prompt_embeds_list=[]
272
+ prompt_embeds_list=[]
273
  text_inputs1 = pipe.tokenizer(
274
  prompt,
275
  padding="max_length",
 
287
  )
288
  text_input_ids2 = text_inputs2.input_ids
289
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
290
+ pooled_prompt_embeds_list.append(prompt_embedsa)
291
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
292
+ prompt_embeds_list.append(prompt_embedsa[0])
293
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
294
+ pooled_prompt_embeds_list.append(prompt_embedsb[0])
295
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
296
+ prompt_embeds_list.append(prompt_embedsb)
297
+ prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
298
+ pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
299
 
300
  options = {
301
  #"prompt": prompt,
 
343
  pipe.text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16)
344
  pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
345
 
346
+ pooled_prompt_embeds_list=[]
347
+ prompt_embeds_list=[]
348
  text_inputs1 = pipe.tokenizer(
349
  prompt,
350
  padding="max_length",
 
362
  )
363
  text_input_ids2 = text_inputs2.input_ids
364
  prompt_embedsa = pipe.text_encoder(text_input_ids1.to(device), output_hidden_states=True)
365
+ pooled_prompt_embeds_list.append(prompt_embedsa)
366
  prompt_embedsa = prompt_embedsa.hidden_states[-2]
367
+ prompt_embeds_list.append(prompt_embedsa[0])
368
  prompt_embedsb = pipe.text_encoder(text_input_ids2.to(device), output_hidden_states=True)
369
+ pooled_prompt_embeds_list.append(prompt_embedsb[0])
370
  prompt_embedsb = prompt_embedsb.hidden_states[-2]
371
+ prompt_embeds_list.append(prompt_embedsb)
372
+ prompt_embeds = torch.cat(prompt_embeds_list).mean(dim=-1)
373
+ pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_list).mean(dim=-1)
374
 
375
  options = {
376
  #"prompt": prompt,