sxela committed
Commit 964e0fc · 1 Parent(s): bc42778

fix image inputs

Files changed (1)
app.py +10 -16
app.py CHANGED
@@ -80,7 +80,7 @@ def tv_loss(input):
 def range_loss(input):
     return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
 
-def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts,timestep_respacing, cutn):
+def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts,timestep_respacing, cutn, im_prompt_weight):
     # Model settings
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     model_config = model_and_diffusion_defaults()
@@ -121,20 +121,14 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
 
     all_frames = []
     prompts = [text]
-    if image_prompts:
-        image_prompts = [image_prompts.name]
-    else:
-        image_prompts = []
+
     batch_size = 1
     clip_guidance_scale = clip_guidance_scale # Controls how much the image should look like the prompt.
     tv_scale = tv_scale # Controls the smoothness of the final output.
     range_scale = range_scale # Controls how far out of range RGB values are allowed to be.
     cutn = cutn
     n_batches = 1
-    if init_image:
-        init_image = init_image.name
-    else:
-        init_image = None # This can be an URL or Colab local path and must be in quotes.
+
     skip_timesteps = skip_timesteps # This needs to be between approx. 200 and 500 when using an init image.
     # Higher values make the output look more like the init.
     init_scale = init_scale # This enhances the effect of the init image, a good value is 1000.
@@ -149,14 +143,13 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
         txt, weight = parse_prompt(prompt)
         target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
         weights.append(weight)
-    for prompt in image_prompts:
-        path, weight = parse_prompt(prompt)
-        img = Image.open(fetch(path)).convert('RGB')
+    if image_prompts is not None:
+        img = Image.fromarray(image_prompts).convert('RGB')
         img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
         batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
         embed = clip_model.encode_image(normalize(batch)).float()
         target_embeds.append(embed)
-        weights.extend([weight / cutn] * cutn)
+        weights.extend([im_prompt_weight / cutn] * cutn)
     target_embeds = torch.cat(target_embeds)
     weights = torch.tensor(weights, device=device)
     if weights.sum().abs() < 1e-3:
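
Note: the image prompt's weight no longer comes from a `path:weight` string parsed by `parse_prompt`; it now comes from the new `im_prompt_weight` slider, split evenly across the `cutn` cutouts. A minimal sketch of the resulting weight vector, with hypothetical values standing in for the slider and prompt defaults:

import torch

cutn = 4                # number of CLIP cutouts
text_weight = 1.0       # default weight parse_prompt gives a bare text prompt
im_prompt_weight = 2    # hypothetical value from the new slider

# Splitting by cutn keeps the image prompt's total pull on the loss
# equal to im_prompt_weight, regardless of how many cutouts are taken.
weights = [text_weight]
weights.extend([im_prompt_weight / cutn] * cutn)
weights = torch.tensor(weights)
print(weights)            # tensor([1.0000, 0.5000, 0.5000, 0.5000, 0.5000])
print(weights[1:].sum())  # tensor(2.)
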
@@ -165,7 +158,7 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
     init = None
     if init_image is not None:
         lpips_model = lpips.LPIPS(net='vgg').to(device)
-        init = Image.open(fetch(init_image)).convert('RGB')
+        init = Image.fromarray(init_image).convert('RGB')
         init = init.resize((side_x, side_y), Image.LANCZOS)
         init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
     cur_t = None
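
Note: both `Image.open(fetch(...))` calls become `Image.fromarray(...)` because a `gr.Image` component hands the uploaded image to the handler as a numpy array (height x width x 3, uint8) rather than as a file object with a tempfile `.name` path. A minimal sketch of that conversion, with a synthetic array standing in for an upload:

import numpy as np
from PIL import Image

# Stand-in for what gr.Image passes the handler for a 64x64 RGB upload.
upload = np.zeros((64, 64, 3), dtype=np.uint8)

if upload is not None:  # mirrors the None checks added in this commit
    img = Image.fromarray(upload).convert('RGB')
    print(img.size)  # (64, 64); PIL reports (width, height)
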
@@ -253,10 +246,11 @@ with demo:
         with gr.Column():
             init_image = gr.Image(source="upload", label='initial image (optional)')
             init_scale = gr.Slider(minimum=0, maximum=45, step=1, value=10, label="Look like the image above")
+            skip_timesteps = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Style strength")
         # with gr.Group():
         with gr.Column():
             image_prompts = gr.Image(source="upload", label='image prompt (optional)')
-            skip_timesteps = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Look like the image above")
+            im_prompt_weight = gr.Slider(minimum=0, maximum=10, step=1, value=1, label="Look like the image above")
 
     with gr.Group():
         with gr.Row():
@@ -279,6 +273,6 @@ with demo:
 
     outputs=[output_image,output_video]
 
-    run_button.click(inference, inputs=[text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts,timestep_respacing, cutn], outputs=outputs)
+    run_button.click(inference, inputs=[text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts,timestep_respacing, cutn, im_prompt_weight], outputs=outputs)
 
 demo.launch(enable_queue=True)
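
Note: `run_button.click` passes the `inputs` components to `inference` positionally, so appending `im_prompt_weight` at the end of the list must line up with the new trailing parameter of the function. A minimal sketch of that contract, using a hypothetical two-argument handler rather than the app's own:

import gradio as gr

def scale(value, factor):  # parameter order matches the inputs list order
    return value * factor

with gr.Blocks() as demo:
    value = gr.Number(value=3)
    factor = gr.Slider(minimum=0, maximum=10, step=1, value=2)
    out = gr.Number()
    run = gr.Button("Run")
    run.click(scale, inputs=[value, factor], outputs=out)
# demo.launch() would serve the sketch
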
 