sxela committed on
Commit 03667b4 · 1 Parent(s): d8edde0

bring back inits

Files changed (1)
  1. app.py +25 -26
app.py CHANGED
@@ -118,26 +118,25 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
                               std=[0.26862954, 0.26130258, 0.27577711])
 
 
-    #def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompt):
     all_frames = []
     prompts = [text]
-    # if image_prompts:
-    #     image_prompts = [image_prompts.name]
-    # else:
-    #     image_prompts = []
+    if image_prompts:
+        image_prompts = [image_prompts.name]
+    else:
+        image_prompts = []
     batch_size = 1
     clip_guidance_scale = clip_guidance_scale  # Controls how much the image should look like the prompt.
     tv_scale = tv_scale  # Controls the smoothness of the final output.
     range_scale = range_scale  # Controls how far out of range RGB values are allowed to be.
     cutn = cutn
     n_batches = 1
-    # if init_image:
-    #     init_image = init_image.name
-    # else:
-    #     init_image = None  # This can be an URL or Colab local path and must be in quotes.
+    if init_image:
+        init_image = init_image.name
+    else:
+        init_image = None  # This can be an URL or Colab local path and must be in quotes.
     skip_timesteps = skip_timesteps  # This needs to be between approx. 200 and 500 when using an init image.
     # Higher values make the output look more like the init.
-    # init_scale = init_scale  # This enhances the effect of the init image, a good value is 1000.
+    init_scale = init_scale  # This enhances the effect of the init image, a good value is 1000.
     seed = seed
 
     if seed is not None:
@@ -149,25 +148,25 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
         txt, weight = parse_prompt(prompt)
         target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
         weights.append(weight)
-    # for prompt in image_prompts:
-    #     path, weight = parse_prompt(prompt)
-    #     img = Image.open(fetch(path)).convert('RGB')
-    #     img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
-    #     batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
-    #     embed = clip_model.encode_image(normalize(batch)).float()
-    #     target_embeds.append(embed)
-    #     weights.extend([weight / cutn] * cutn)
+    for prompt in image_prompts:
+        path, weight = parse_prompt(prompt)
+        img = Image.open(fetch(path)).convert('RGB')
+        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
+        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
+        embed = clip_model.encode_image(normalize(batch)).float()
+        target_embeds.append(embed)
+        weights.extend([weight / cutn] * cutn)
     target_embeds = torch.cat(target_embeds)
     weights = torch.tensor(weights, device=device)
     if weights.sum().abs() < 1e-3:
         raise RuntimeError('The weights must not sum to 0.')
     weights /= weights.sum().abs()
     init = None
-    # if init_image is not None:
-    #     lpips_model = lpips.LPIPS(net='vgg').to(device)
-    #     init = Image.open(fetch(init_image)).convert('RGB')
-    #     init = init.resize((side_x, side_y), Image.LANCZOS)
-    #     init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
+    if init_image is not None:
+        lpips_model = lpips.LPIPS(net='vgg').to(device)
+        init = Image.open(fetch(init_image)).convert('RGB')
+        init = init.resize((side_x, side_y), Image.LANCZOS)
+        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
     cur_t = None
     def cond_fn(x, t, y=None):
         with torch.enable_grad():
@@ -185,10 +184,10 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
             tv_losses = tv_loss(x_in)
             range_losses = range_loss(out['pred_xstart'])
            loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
-            # if init is not None and init_scale:
+            if init is not None and init_scale:
 
-            #     init_losses = lpips_model(x_in, init)
-            #     loss = loss + init_losses.sum() * init_scale
+                init_losses = lpips_model(x_in, init)
+                loss = loss + init_losses.sum() * init_scale
             return -torch.autograd.grad(loss, x)[0]
        if model_config['timestep_respacing'].startswith('ddim'):
            sample_fn = diffusion.ddim_sample_loop_progressive
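
The commit re-enables the init-image path: the uploaded init_image is resized to the model resolution, mapped to [-1, 1], and an LPIPS term scaled by init_scale is added to the guidance loss inside cond_fn so samples stay perceptually close to the init. Below is a minimal, self-contained sketch of that restored loss term only, not the app's exact code; the file path, the 256x256 resolution, the init_scale value, and the random stand-in for x_in are illustrative assumptions.

# Sketch of the LPIPS init-image loss this commit brings back (placeholder values).
import torch
import lpips
from PIL import Image
import torchvision.transforms.functional as TF

device = 'cuda' if torch.cuda.is_available() else 'cpu'
side_x, side_y = 256, 256   # placeholder model resolution
init_scale = 1000           # "a good value is 1000" per the inline comment in app.py

lpips_model = lpips.LPIPS(net='vgg').to(device)

# Load the init image and map it from [0, 1] to [-1, 1], as in the diff.
init = Image.open('example_init.png').convert('RGB')   # placeholder path
init = init.resize((side_x, side_y), Image.LANCZOS)
init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

# In the app, x_in is the current denoised estimate in [-1, 1] inside cond_fn;
# a random tensor stands in for it here just to show the loss term.
x_in = torch.rand(1, 3, side_y, side_x, device=device).mul(2).sub(1)
init_losses = lpips_model(x_in, init)
loss_term = init_losses.sum() * init_scale   # added to the CLIP/TV/range losses before the gradient step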