bring back inits
app.py CHANGED
@@ -118,26 +118,25 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
                                      std=[0.26862954, 0.26130258, 0.27577711])
 
 
-    #def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompt):
     all_frames = []
     prompts = [text]
-
-
-
-
+    if image_prompts:
+        image_prompts = [image_prompts.name]
+    else:
+        image_prompts = []
     batch_size = 1
     clip_guidance_scale = clip_guidance_scale # Controls how much the image should look like the prompt.
     tv_scale = tv_scale # Controls the smoothness of the final output.
     range_scale = range_scale # Controls how far out of range RGB values are allowed to be.
     cutn = cutn
     n_batches = 1
-
-
-
-
+    if init_image:
+        init_image = init_image.name
+    else:
+        init_image = None # This can be an URL or Colab local path and must be in quotes.
     skip_timesteps = skip_timesteps # This needs to be between approx. 200 and 500 when using an init image.
     # Higher values make the output look more like the init.
-
+    init_scale = init_scale # This enhances the effect of the init image, a good value is 1000.
     seed = seed
 
     if seed is not None:
@@ -149,25 +148,25 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
         txt, weight = parse_prompt(prompt)
         target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
         weights.append(weight)
-
-
-
-
-
-
-
-
+    for prompt in image_prompts:
+        path, weight = parse_prompt(prompt)
+        img = Image.open(fetch(path)).convert('RGB')
+        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
+        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
+        embed = clip_model.encode_image(normalize(batch)).float()
+        target_embeds.append(embed)
+        weights.extend([weight / cutn] * cutn)
     target_embeds = torch.cat(target_embeds)
     weights = torch.tensor(weights, device=device)
     if weights.sum().abs() < 1e-3:
         raise RuntimeError('The weights must not sum to 0.')
     weights /= weights.sum().abs()
     init = None
-
-
-
-
-
+    if init_image is not None:
+        lpips_model = lpips.LPIPS(net='vgg').to(device)
+        init = Image.open(fetch(init_image)).convert('RGB')
+        init = init.resize((side_x, side_y), Image.LANCZOS)
+        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
     cur_t = None
     def cond_fn(x, t, y=None):
         with torch.enable_grad():
@@ -185,10 +184,10 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
             tv_losses = tv_loss(x_in)
             range_losses = range_loss(out['pred_xstart'])
             loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
-
+            if init is not None and init_scale:
 
-
-
+                init_losses = lpips_model(x_in, init)
+                loss = loss + init_losses.sum() * init_scale
             return -torch.autograd.grad(loss, x)[0]
     if model_config['timestep_respacing'].startswith('ddim'):
         sample_fn = diffusion.ddim_sample_loop_progressive
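A note on the first hunk: Gradio image/file inputs of the kind this Space was written against pass uploads to inference as tempfile-like objects, so the restored branches unwrap them to plain paths via .name and fall back to an empty image-prompt list or init_image = None when nothing was uploaded. A small sketch of that unwrapping; to_path and FakeUpload are hypothetical names, not from app.py:

def to_path(upload):
    # Gradio uploads (in the Gradio version this Space targets) arrive as
    # tempfile-like objects whose .name attribute is the path on disk;
    # None means the user left the input empty.
    return upload.name if upload else None


class FakeUpload:
    # Hypothetical stand-in for an uploaded file object, only for this demo.
    name = '/tmp/example.png'


print(to_path(FakeUpload()))  # /tmp/example.png
print(to_path(None))          # None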
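The restored image-prompt and init-image code also calls two helpers that live elsewhere in app.py and are untouched by this commit: fetch, which opens either an HTTP(S) URL or a local path as a binary stream, and parse_prompt, which splits an optional :weight suffix off a prompt string. The sketch below shows the versions commonly used in the CLIP-guided-diffusion notebooks this app derives from; the exact definitions in app.py may differ.

import io

import requests


def fetch(url_or_path):
    # Open an HTTP(S) URL or a local file path and return a binary file-like object.
    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
        r = requests.get(url_or_path)
        r.raise_for_status()
        fd = io.BytesIO()
        fd.write(r.content)
        fd.seek(0)
        return fd
    return open(url_or_path, 'rb')


def parse_prompt(prompt):
    # Split an optional trailing ':weight' off a prompt, taking care not to split
    # the '://' inside URLs; the weight defaults to 1.
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', '1'][len(vals):]
    return vals[0], float(vals[1])

For example, parse_prompt('watercolor of a fox:3') returns ('watercolor of a fox', 3.0); the weights.extend([weight / cutn] * cutn) line in the second hunk then spreads that weight over the cutn cutouts, so an image prompt carries the same total weight as a text prompt.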
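The third hunk is the heart of "bring back inits": when an init image is present, the LPIPS perceptual distance between the current denoised estimate x_in and the init is scaled by init_scale and added to the guidance loss, so the gradient returned by cond_fn also pulls each sampling step toward the init. Below is a minimal, self-contained sketch of just that term; the random stand-in tensors, their shapes, and the init_scale value are illustrative assumptions, not values from app.py.

import lpips
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
lpips_model = lpips.LPIPS(net='vgg').to(device)

# Stand-ins for the denoised estimate (x_in) and the init image, both scaled to
# [-1, 1] with shape (batch, 3, H, W), which is the range LPIPS expects.
x_in = (torch.rand(1, 3, 256, 256, device=device) * 2 - 1).requires_grad_()
init = torch.rand(1, 3, 256, 256, device=device) * 2 - 1

init_scale = 1000  # illustrative; in the Space this arrives as a function argument

# Perceptual distance between the current estimate and the init image, scaled by
# init_scale -- the same term the hunk adds onto the CLIP/TV/range guidance loss.
init_losses = lpips_model(x_in, init)
loss = init_losses.sum() * init_scale

# cond_fn returns the negative gradient of the total loss w.r.t. x; taking it
# w.r.t. x_in here isolates the init-image contribution to that gradient.
grad = -torch.autograd.grad(loss, x_in)[0]
print(grad.shape)  # torch.Size([1, 3, 256, 256])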