import io
import os
import torch
import zipfile
import spaces
import numpy as np
import gradio as gr
from PIL import Image
from tqdm.auto import tqdm
from src.util.params import *
from src.util.clip_config import *
import matplotlib.pyplot as plt

def get_text_embeddings(
    prompt,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    torch_device=torch_device,
    batch_size=1,
    negative_prompt="",
):
    # Encode the prompt and the (possibly empty) negative prompt, then stack
    # them [uncond, cond] so a single UNet forward pass can serve
    # classifier-free guidance.
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

    return torch.cat([uncond_embeddings, text_embeddings])
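
# Example usage (a minimal sketch; `tokenizer`, `text_encoder`, and
# `torch_device` are assumed to come from the src.util.params star import):
#
#   emb = get_text_embeddings("an astronaut riding a horse")
#   emb.shape  # (2, tokenizer.model_max_length, hidden_dim):
#              # row 0 is the negative/unconditional prompt, row 1 the prompt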

def generate_latents(
    seed,
    height=imageHeight,
    width=imageWidth,
    torch_device=torch_device,
    unet=unet,
    batch_size=1,
):
    # Draw the initial Gaussian latents from a seeded CPU generator so the
    # same seed reproduces the same image; the VAE downsamples by a factor
    # of 8, hence the // 8 on the spatial dimensions.
    generator = torch.Generator().manual_seed(int(seed))
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)
    return latents
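
# Example (sketch): the same seed always yields the same starting noise:
#
#   a = generate_latents(42)
#   b = generate_latents(42)
#   torch.equal(a.cpu(), b.cpu())  # True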

def generate_modified_latents(
    poke,
    seed,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    imageHeight=imageHeight,
    imageWidth=imageWidth,
):
    # Generate the base latents and, optionally, "poke" a rectangular patch of
    # freshly seeded noise into a copy of them. pokeX/pokeY are the patch
    # centre in latent coordinates; the * 8 below cancels the // 8 inside
    # generate_latents, so the patch latents come out pokeHeight x pokeWidth.
    original_latents = generate_latents(seed, height=imageHeight, width=imageWidth)
    if poke:
        np.random.seed(int(seed))
        poke_latents = generate_latents(
            np.random.randint(0, 100000), height=pokeHeight * 8, width=pokeWidth * 8
        )

        x_origin = pokeX - pokeWidth // 2
        y_origin = pokeY - pokeHeight // 2

        modified_latents = original_latents.clone()
        modified_latents[
            :, :, y_origin : y_origin + pokeHeight, x_origin : x_origin + pokeWidth
        ] = poke_latents
    else:
        modified_latents = None

    return original_latents, modified_latents
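
# Example (sketch): re-noise an 8x8 latent patch centred at (32, 32):
#
#   base, poked = generate_modified_latents(
#       poke=True, seed=42, pokeX=32, pokeY=32, pokeHeight=8, pokeWidth=8
#   )
#   # with poke=False, `poked` is returned as None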

def convert_to_pil_image(image):
    # Map the VAE output from [-1, 1] to [0, 1], then to uint8 HWC for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images[0]

def generate_images(
    latents,
    text_embeddings,
    num_inference_steps,
    unet=unet,
    guidance_scale=guidance_scale,
    vae=vae,
    scheduler=scheduler,
    intermediate=False,
    progress=gr.Progress(),
):
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma

    images = []
    for i, t in enumerate(tqdm(scheduler.timesteps), start=1):
        # Duplicate the latents so one UNet pass predicts both the
        # unconditional and the conditional noise for classifier-free guidance.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        if intermediate:
            # Timesteps count down from the training horizon (1000), so this
            # maps them onto a 0-1 fraction for the Gradio progress bar.
            progress((1000 - t) / 1000)
            # Undo the Stable Diffusion latent scaling factor before decoding.
            scaled_latents = 1 / 0.18215 * latents
            with torch.no_grad():
                image = vae.decode(scaled_latents).sample
            images.append((convert_to_pil_image(image), f"{i}"))

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    if not intermediate:
        scaled_latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = vae.decode(scaled_latents).sample
        images = convert_to_pil_image(image)

    return images
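
# End-to-end sketch (assumes the models loaded by src.util.params):
#
#   emb = get_text_embeddings("a watercolor fox")
#   lat = generate_latents(seed=7)
#   img = generate_images(lat, emb, num_inference_steps=25)  # one PIL image
#   # with intermediate=True it instead returns [(PIL image, "1"), ...]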

def get_word_embeddings(
    prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
):
    # Encode a prompt and flatten the per-token embeddings into a single
    # unit-norm row vector, so dot products between prompts behave like
    # cosine similarity.
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids)[0].reshape(1, -1)

    text_embeddings = text_embeddings.cpu().numpy()
    return text_embeddings / np.linalg.norm(text_embeddings)
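
# Example (sketch): prompt-to-prompt cosine similarity via the unit-norm rows:
#
#   sim = (get_word_embeddings("king") @ get_word_embeddings("queen").T).item()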

def get_concat_embeddings(names, merge=False):
    # Stack one embedding row per name; with merge=True, average the rows
    # into a single row instead.
    embeddings = [get_word_embeddings(name) for name in names]
    embeddings = np.vstack(embeddings)
    if merge:
        embeddings = np.average(embeddings, axis=0).reshape(1, -1)
    return embeddings
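
# Example (sketch):
#
#   rows = get_concat_embeddings(["cat", "dog", "bird"])      # shape (3, D)
#   mean = get_concat_embeddings(["cat", "dog"], merge=True)  # shape (1, D)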

def get_axis_embeddings(A, B):
    # Average the pairwise difference vectors a - b; the result is a single
    # direction in embedding space that separates concept A from concept B.
    emb = []
    for a, b in zip(A, B):
        e = get_word_embeddings(a) - get_word_embeddings(b)
        emb.append(e)

    emb = np.vstack(emb)
    ax = np.average(emb, axis=0).reshape(1, -1)
    return ax
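
# Example (sketch): an age-like axis from hypothetical word pairs:
#
#   age_axis = get_axis_embeddings(["boy", "girl"], ["man", "woman"])  # (1, D)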

def calculate_residual(
    axis, axis_names, from_words=None, to_words=None, residual_axis=1
):
    # Build the residual axis: the part of the merged word embeddings that is
    # orthogonal to the two user-defined axes. Word lists for known axis names
    # are looked up in (and new ones cached into) axis_combinations from
    # src.util.clip_config.
    axis_indices = [0, 1, 2]
    axis_indices.remove(residual_axis)

    if axis_names[axis_indices[0]] in axis_combinations:
        fembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[0]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[0]]] = from_words + to_words
        fembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    if axis_names[axis_indices[1]] in axis_combinations:
        sembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[1]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[1]]] = from_words + to_words
        sembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    fprojections = fembeddings @ axis[axis_indices[0]].T
    sprojections = sembeddings @ axis[axis_indices[1]].T

    # Gram-Schmidt-style removal: subtract the component lying along each
    # axis direction, leaving the part orthogonal to both.
    partial_residual = fembeddings - fprojections.reshape(-1, 1) * axis[axis_indices[0]]
    residual = partial_residual - sprojections.reshape(-1, 1) * axis[axis_indices[1]]
    return residual
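
# Example (sketch; the axis names and word lists here are hypothetical, and
# `axis` is assumed to hold one direction vector per row):
#
#   res = calculate_residual(
#       axis, ["age", "residual", "gender"], from_words=["boy"], to_words=["man"]
#   )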

def calculate_step_size(num_images, start_degree_circular, end_degree_circular):
    return (end_degree_circular - start_degree_circular) / num_images

def generate_seed_vis(seed):
    # Render a small deterministic colour strip for a seed, so users can
    # visually tell seeds apart in the UI.
    np.random.seed(int(seed))
    emb = np.random.rand(15)

    plt.close()
    plt.switch_backend("agg")
    plt.figure(figsize=(10, 0.5))
    plt.imshow([emb], cmap="viridis")
    plt.axis("off")
    return plt

def export_as_gif(images, filename, frames_per_second=2, reverse=False):
    imgs = [img[0] for img in images]
    if reverse:
        # Append the middle frames in reverse order for a back-and-forth loop.
        imgs += imgs[2:-1][::-1]

    os.makedirs("outputs", exist_ok=True)
    imgs[0].save(
        f"outputs/{filename}",
        format="GIF",
        save_all=True,
        append_images=imgs[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )
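
# Example (sketch): `images` holds the (PIL image, label) tuples produced by
# generate_images(..., intermediate=True):
#
#   export_as_gif(frames, "denoising.gif", frames_per_second=4, reverse=True)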

def export_as_zip(images, fname, tab_config=None):
    # Skip regeneration if an archive for this name already exists.
    if not os.path.exists(f"outputs/{fname}.zip"):
        os.makedirs("outputs", exist_ok=True)
        with zipfile.ZipFile(f"outputs/{fname}.zip", "w") as img_zip:
            if tab_config:
                with open("outputs/config.txt", "w") as f:
                    for key, value in tab_config.items():
                        f.write(f"{key}: {value}\n")
                img_zip.write("outputs/config.txt", "config.txt")

            # Zero-pad file names so they sort in generation order.
            pad_width = len(str(len(images)))
            for idx, img in enumerate(images):
                buff = io.BytesIO()
                img[0].save(buff, format="PNG")
                img_zip.writestr(f"{idx + 1:0{pad_width}}.png", buff.getvalue())
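
# Example (sketch): bundle intermediate frames together with their settings:
#
#   export_as_zip(frames, "run_42", tab_config={"seed": 42, "steps": 25})
#   # -> outputs/run_42.zip with config.txt plus zero-padded 1.png, 2.png, ...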

def read_html(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    return content

__all__ = [
    "get_text_embeddings",
    "generate_latents",
    "generate_modified_latents",
    "generate_images",
    "get_word_embeddings",
    "get_concat_embeddings",
    "get_axis_embeddings",
    "calculate_residual",
    "calculate_step_size",
    "generate_seed_vis",
    "export_as_gif",
    "export_as_zip",
    "read_html",
]