blanchon commited on
Commit
c584662
·
1 Parent(s): d97d8b8

num_images_per_prompt

Browse files
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -12,8 +12,8 @@ from PIL import Image, ImageFilter, ImageOps
12
  DEVICE = "cuda"
13
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
- # FIXED_DIMENSION = 900
16
- FIXED_DIMENSION = 512 + (512 // 2)
17
  FIXED_DIMENSION = (FIXED_DIMENSION // 16) * 16
18
 
19
  SYSTEM_PROMPT = r"""This two-panel split-frame image showcases a furniture in as a product shot versus styled in a room.
@@ -127,20 +127,18 @@ def infer(
127
  seed = secrets.randbelow(MAX_SEED)
128
 
129
  prompt = prompt + ".\n" + SYSTEM_PROMPT if prompt else SYSTEM_PROMPT
130
- batch_size = 4
131
  results_images = pipe(
132
- prompt=[prompt] * batch_size,
133
- image=[image] * batch_size,
134
  mask_image=mask,
135
  height=FIXED_DIMENSION,
136
  width=FIXED_DIMENSION * 2,
137
- guidance_scale=guidance_scale,
138
  num_inference_steps=num_inference_steps,
 
 
139
  generator=torch.Generator("cpu").manual_seed(seed),
140
  )["images"]
141
 
142
- print(len(results_images))
143
-
144
  cropped_images = [
145
  image.crop((FIXED_DIMENSION, 0, FIXED_DIMENSION * 2, FIXED_DIMENSION))
146
  for image in results_images
 
12
  DEVICE = "cuda"
13
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
+ FIXED_DIMENSION = 900
16
+ # FIXED_DIMENSION = 512 + (512 // 2)
17
  FIXED_DIMENSION = (FIXED_DIMENSION // 16) * 16
18
 
19
  SYSTEM_PROMPT = r"""This two-panel split-frame image showcases a furniture in as a product shot versus styled in a room.
 
127
  seed = secrets.randbelow(MAX_SEED)
128
 
129
  prompt = prompt + ".\n" + SYSTEM_PROMPT if prompt else SYSTEM_PROMPT
 
130
  results_images = pipe(
131
+ prompt=prompt,
132
+ image=image,
133
  mask_image=mask,
134
  height=FIXED_DIMENSION,
135
  width=FIXED_DIMENSION * 2,
 
136
  num_inference_steps=num_inference_steps,
137
+ guidance_scale=guidance_scale,
138
+ num_images_per_prompt=4,
139
  generator=torch.Generator("cpu").manual_seed(seed),
140
  )["images"]
141
 
 
 
142
  cropped_images = [
143
  image.crop((FIXED_DIMENSION, 0, FIXED_DIMENSION * 2, FIXED_DIMENSION))
144
  for image in results_images