Update app.py
app.py
CHANGED
@@ -39,81 +39,48 @@ style_token_dict = {
     "Birb Style": '<birb-style>'
 }

-def apply_guidance(latents, guidance_method, loss_scale):
+def apply_guidance(image, guidance_method, loss_scale):
+    # Convert PIL Image to tensor
+    img_tensor = tfms.ToTensor()(image).unsqueeze(0).to(torch_device)
+
     if guidance_method == 'Grayscale':
-        gray_latents = pil_to_latent(gray.convert('RGB'))
-        return latents + (gray_latents - latents) * loss_scale
+        gray = tfms.Grayscale(3)(img_tensor)
+        guided = img_tensor + (gray - img_tensor) * (loss_scale / 10000)
     elif guidance_method == 'Bright':
+        bright = F.relu(img_tensor)  # Simple brightness increase
+        guided = img_tensor + (bright - img_tensor) * (loss_scale / 10000)
     elif guidance_method == 'Contrast':
-        mean = latents.mean()
+        mean = img_tensor.mean()
+        contrast = (img_tensor - mean) * 2 + mean
+        guided = img_tensor + (contrast - img_tensor) * (loss_scale / 10000)
     elif guidance_method == 'Symmetry':
+        flipped = torch.flip(img_tensor, [3])  # Flip horizontally
+        guided = img_tensor + (flipped - img_tensor) * (loss_scale / 10000)
     elif guidance_method == 'Saturation':
-        saturated_latents = pil_to_latent(tfms.ToPILImage()(saturated))
-        return latents + (saturated_latents - latents) * loss_scale
+        saturated = tfms.functional.adjust_saturation(img_tensor, 2)
+        guided = img_tensor + (saturated - img_tensor) * (loss_scale / 10000)
     else:
-        return latents
+        return image
+
+    # Convert back to PIL Image
+    guided = guided.squeeze(0).clamp(0, 1)
+    guided = (guided * 255).byte().cpu().permute(1, 2, 0).numpy()
+    return Image.fromarray(guided)

 def generate_with_guidance(prompt, num_inference_steps, guidance_scale, seed, guidance_method, loss_scale):
+    # Generate image with pipeline
     generator = torch.Generator(device=torch_device).manual_seed(seed)
-
-    text_input = sd_pipeline.tokenizer(prompt, padding="max_length", max_length=sd_pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt")
-    with torch.no_grad():
-        text_embeddings = sd_pipeline.text_encoder(text_input.input_ids.to(torch_device))[0]
-
-    # Set the timesteps
-    sd_pipeline.scheduler.set_timesteps(num_inference_steps)
-
-    # Prepare latents
-    latents = torch.randn(
-        (1, sd_pipeline.unet.in_channels, 64, 64),
-        generator=generator,
-        device=torch_device
-    )
-    latents = latents * sd_pipeline.scheduler.init_noise_sigma
-
-    # Denoising loop
-    for t in tqdm(sd_pipeline.scheduler.timesteps):
-        # Expand the latents for classifier-free guidance
-        latent_model_input = torch.cat([latents] * 2)
-        latent_model_input = sd_pipeline.scheduler.scale_model_input(latent_model_input, timestep=t)
-
-        # Predict the noise residual
-        with torch.no_grad():
-            noise_pred = sd_pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-        # Perform guidance
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-        # Apply custom guidance
-        latents = apply_guidance(latents, guidance_method, loss_scale / 10000)  # Normalize loss_scale
-
-        # Compute the previous noisy sample x_t -> x_t-1
-        latents = sd_pipeline.scheduler.step(noise_pred, t, latents).prev_sample
-
-    # Scale and decode the image latents with vae
-    latents = 1 / 0.18215 * latents
-    with torch.no_grad():
-        image = sd_pipeline.vae.decode(latents).sample
-
-    # Convert to PIL Image
-    image = (image / 2 + 0.5).clamp(0, 1)
-    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-    image = (image * 255).round().astype("uint8")[0]
-    image = Image.fromarray(image)
-
-    return image
+    image = sd_pipeline(
+        prompt,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+        generator=generator
+    ).images[0]
+
+    # Apply guidance
+    guided_image = apply_guidance(image, guidance_method, loss_scale)
+
+    return guided_image

 def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
     prompt = text + " " + style_token_dict[style]