1inkusFace commited on
Commit
b43e85a
·
verified ·
1 Parent(s): 4e1fa3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -298,7 +298,7 @@ def expand_prompt(prompt):
298
  outputs = model.generate(
299
  input_ids=input_ids,
300
  attention_mask=attention_mask,
301
- max_new_tokens=256,
302
  temperature=0.2,
303
  top_p=0.9,
304
  do_sample=True,
@@ -312,7 +312,7 @@ def expand_prompt(prompt):
312
  outputs_2 = model.generate(
313
  input_ids=input_ids_2,
314
  attention_mask=attention_mask_2,
315
- max_new_tokens=384,
316
  temperature=0.2,
317
  top_p=0.9,
318
  do_sample=True,
@@ -423,13 +423,10 @@ def generate_30(
423
  del processor5
424
  gc.collect()
425
  torch.cuda.empty_cache()
426
- expanded = expand_prompt(prompt+caption+caption_2)
427
  expanded_1 = expanded[0]
428
  expanded_2 = expanded[1]
429
-
430
- prompt = flatten_and_stringify(prompt+expanded_1+expanded_2)
431
- prompt = " ".join(prompt)
432
-
433
  global model
434
  global txt_tokenizer
435
  del model
@@ -448,7 +445,7 @@ def generate_30(
448
  pil_image_3=sd_image_c,
449
  pil_image_4=sd_image_d,
450
  pil_image_5=sd_image_e,
451
- prompt=prompt+' '+expanded_1,
452
  negative_prompt=negative_prompt,
453
  text_scale=text_scale,
454
  ip_scale=ip_scale,
 
298
  outputs = model.generate(
299
  input_ids=input_ids,
300
  attention_mask=attention_mask,
301
+ max_new_tokens=128,
302
  temperature=0.2,
303
  top_p=0.9,
304
  do_sample=True,
 
312
  outputs_2 = model.generate(
313
  input_ids=input_ids_2,
314
  attention_mask=attention_mask_2,
315
+ max_new_tokens=128,
316
  temperature=0.2,
317
  top_p=0.9,
318
  do_sample=True,
 
423
  del processor5
424
  gc.collect()
425
  torch.cuda.empty_cache()
426
+ expanded = expand_prompt(caption)
427
  expanded_1 = expanded[0]
428
  expanded_2 = expanded[1]
429
+ new_prompt = prompt+' '+expanded_1
 
 
 
430
  global model
431
  global txt_tokenizer
432
  del model
 
445
  pil_image_3=sd_image_c,
446
  pil_image_4=sd_image_d,
447
  pil_image_5=sd_image_e,
448
+ prompt=new_prompt,
449
  negative_prompt=negative_prompt,
450
  text_scale=text_scale,
451
  ip_scale=ip_scale,