Spaces:
Configuration error
Configuration error
update config and rm discriminator
Browse files- app.py +1 -1
- app_backend.py +18 -19
- configs.py +1 -1
app.py
CHANGED
@@ -139,7 +139,7 @@ with gr.Blocks(css="styles.css") as demo:
|
|
139 |
with gr.Column():
|
140 |
major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
|
141 |
iterations = gr.Slider(minimum=10,
|
142 |
-
maximum=
|
143 |
step=1,
|
144 |
value=20,
|
145 |
label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
|
|
|
139 |
with gr.Column():
|
140 |
major_global = gr.Button(value="Major Global Changes (e.g. change race / gender").style(full_width=False)
|
141 |
iterations = gr.Slider(minimum=10,
|
142 |
+
maximum=60,
|
143 |
step=1,
|
144 |
value=20,
|
145 |
label="Iterations: How many steps the model will take to modify the image. Try starting small and seeing how the results turn out, you can always resume with afterwards",)
|
app_backend.py
CHANGED
@@ -81,7 +81,6 @@ class ImagePromptOptimizer(nn.Module):
|
|
81 |
self.make_grid = make_grid
|
82 |
self.return_val = return_val
|
83 |
self.quantize = quantize
|
84 |
-
self.disc = load_disc(self.device)
|
85 |
self.lpips_weight = lpips_weight
|
86 |
self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
|
87 |
def disc_loss_fn(self, logits):
|
@@ -175,19 +174,19 @@ class ImagePromptOptimizer(nn.Module):
|
|
175 |
clip_clone = processed_img.clone()
|
176 |
clip_clone.register_hook(self.attn_masking)
|
177 |
clip_clone.retain_grad()
|
178 |
-
with torch.autocast("cuda"):
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
with torch.no_grad():
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
if log:
|
189 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
190 |
-
wandb.log({"Discriminator Loss": disc_loss})
|
191 |
wandb.log({"CLIP Loss": clip_loss})
|
192 |
clip_loss.backward(retain_graph=True)
|
193 |
perceptual_loss.backward(retain_graph=True)
|
@@ -209,13 +208,13 @@ class ImagePromptOptimizer(nn.Module):
|
|
209 |
lpips_input = processed_img.clone()
|
210 |
lpips_input.register_hook(self.attn_masking2)
|
211 |
lpips_input.retain_grad()
|
212 |
-
with torch.autocast("cuda"):
|
213 |
-
|
214 |
-
with torch.no_grad():
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
# print(f"disc_loss2 = {disc_loss2}")
|
220 |
if log:
|
221 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
|
|
81 |
self.make_grid = make_grid
|
82 |
self.return_val = return_val
|
83 |
self.quantize = quantize
|
|
|
84 |
self.lpips_weight = lpips_weight
|
85 |
self.perceptual_loss = lpips.LPIPS(net='vgg').to(self.device)
|
86 |
def disc_loss_fn(self, logits):
|
|
|
174 |
clip_clone = processed_img.clone()
|
175 |
clip_clone.register_hook(self.attn_masking)
|
176 |
clip_clone.retain_grad()
|
177 |
+
# with torch.autocast("cuda"):
|
178 |
+
clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
|
179 |
+
print("CLIP loss", clip_loss)
|
180 |
+
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
181 |
+
print("LPIPS loss: ", perceptual_loss)
|
182 |
+
# with torch.no_grad():
|
183 |
+
# disc_logits = self.disc(transformed_img)
|
184 |
+
# disc_loss = self.disc_loss_fn(disc_logits)
|
185 |
+
# print(f"disc_loss = {disc_loss}")
|
186 |
+
# disc_loss2 = self.disc(processed_img)
|
187 |
if log:
|
188 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
189 |
+
# wandb.log({"Discriminator Loss": disc_loss})
|
190 |
wandb.log({"CLIP Loss": clip_loss})
|
191 |
clip_loss.backward(retain_graph=True)
|
192 |
perceptual_loss.backward(retain_graph=True)
|
|
|
208 |
lpips_input = processed_img.clone()
|
209 |
lpips_input.register_hook(self.attn_masking2)
|
210 |
lpips_input.retain_grad()
|
211 |
+
# with torch.autocast("cuda"):
|
212 |
+
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
213 |
+
# with torch.no_grad():
|
214 |
+
# disc_logits = self.disc(transformed_img)
|
215 |
+
# disc_loss = self.disc_loss_fn(disc_logits)
|
216 |
+
# print(f"disc_loss = {disc_loss}")
|
217 |
+
# disc_loss2 = self.disc(processed_img)
|
218 |
# print(f"disc_loss2 = {disc_loss2}")
|
219 |
if log:
|
220 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
configs.py
CHANGED
@@ -2,6 +2,6 @@ import gradio as gr
|
|
2 |
def set_small_local():
|
3 |
return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
|
4 |
def set_major_local():
|
5 |
-
return (gr.Slider.update(value=25), gr.Slider.update(value=0.
|
6 |
def set_major_global():
|
7 |
return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
|
|
|
2 |
def set_small_local():
|
3 |
return (gr.Slider.update(value=25), gr.Slider.update(value=0.15), gr.Slider.update(value=1), gr.Slider.update(value=4))
|
4 |
def set_major_local():
|
5 |
+
return (gr.Slider.update(value=25), gr.Slider.update(value=0.2), gr.Slider.update(value=36.6), gr.Slider.update(value=10))
|
6 |
def set_major_global():
|
7 |
return (gr.Slider.update(value=30), gr.Slider.update(value=0.1), gr.Slider.update(value=2), gr.Slider.update(value=0.2))
|