# NOTE(review): the lines below were web-page scrape artifacts (Hugging Face
# Spaces status header, file size, commit hash, and a line-number gutter),
# not Python source. Preserved as a comment so the file parses:
#   Spaces: Sleeping | File size: 5,315 Bytes | commit 04ef268
import random
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from diffusers import StableDiffusionPipeline
import base64
from io import BytesIO
import plotly.express as px
from src.util.base import *
from src.util.params import *
from src.util.clip_config import *
# Semantic axis directions, each computed from two opposing word groups
# (word lists come from src.util.clip_config).
age = get_axis_embeddings(young, old)
gender = get_axis_embeddings(masculine, feminine)
royalty = get_axis_embeddings(common, elite)

# Generate one Stable Diffusion image per example word and keep it as a
# base64 JPEG data URI (used as hover artwork for the 3-D plot).
images = []
for example in examples:
    image = pipe(
        prompt=example,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
    images.append("data:image/jpeg;base64, " + encoded_image)

# Stack the three axes into a matrix; row 1 is replaced by its residual
# so that axis is decorrelated from the other two.
axis = np.vstack([gender, royalty, age])
axis[1] = calculate_residual(axis, axis_names)

# Project every example embedding onto the axes. The y coordinate is
# flipped and rescaled — presumably for nicer plot orientation; the same
# transform is repeated wherever coords are recomputed.
coords = get_concat_embeddings(examples) @ axis.T
coords[:, 1] = 5 * (1.0 - coords[:, 1])
def update_fig():
    """Sync the scatter trace with the current coords/examples and return
    HTML that reloads the embedded Dash iframe.

    Returns:
        str: an ``<iframe>`` pointing at ``dash_tunnel`` plus a script that
        appends a random query string so the browser bypasses its cache.
    """
    global coords, examples, fig
    # Update the existing trace in place rather than rebuilding the figure.
    fig.data[0].x = coords[:, 0]
    fig.data[0].y = coords[:, 1]
    fig.data[0].z = coords[:, 2]
    fig.data[0].text = examples
    # Fixes vs. original: the iframe is emitted BEFORE the script, so
    # getElementById("html") actually finds the element (it previously ran
    # while #html did not exist yet), and the src attribute is quoted so
    # the URL survives HTML attribute parsing.
    return f"""
    <iframe id="html" src="{dash_tunnel}" style="width:100%; height:725px;"></iframe>
    <script>
        document.getElementById("html").src += "?rand={random.random()}"
    </script>
    """
def add_word(new_example):
    """Project *new_example* onto the axes, render its image, and append
    both to the module-level plot state, then refresh the figure HTML."""
    global coords, images, examples

    # Project the new word onto the current semantic axes, applying the
    # same y-axis flip/scale used for the module-level coords.
    projected = get_concat_embeddings([new_example]) @ axis.T
    projected[:, 1] = 5 * (1.0 - projected[:, 1])
    coords = np.vstack([coords, projected])

    # Render the word with Stable Diffusion and store it as a data URI.
    generated = pipe(
        prompt=new_example,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    with BytesIO() as buffer:
        generated.save(buffer, format="JPEG")
        payload = base64.b64encode(buffer.getvalue()).decode("utf-8")
    images.append("data:image/jpeg;base64, " + payload)

    examples.append(new_example)
    return update_fig()
def remove_word(new_example):
    """Drop *new_example* (coords row, image, and label) from the plot
    state and refresh the figure HTML.

    Raises:
        KeyError: if the word is not currently plotted.
    """
    global coords, images, examples
    # Word -> position lookup; a missing word raises KeyError, matching
    # the original map-based semantics.
    position = {word: i for i, word in enumerate(examples)}[new_example]
    coords = np.delete(coords, position, 0)
    images.pop(position)
    examples.pop(position)
    return update_fig()
def add_rem_word(new_examples):
    """Toggle each comma/space-separated word: remove it if plotted,
    otherwise add it when it encodes to a single CLIP token."""
    global examples
    for word in new_examples.replace(",", " ").split():
        if word in examples:
            remove_word(word)
            gr.Info("Removed {}".format(word))
            continue
        # A single-token word encodes to exactly 3 ids (BOS, token, EOS) —
        # presumably the CLIP tokenizer layout; anything longer is rejected.
        if len(tokenizer.encode(word)) == 3:
            add_word(word)
            gr.Info("Added {}".format(word))
        else:
            gr.Warning(f"{word} not found in embeddings")
    return update_fig()
def set_axis(axis_name, which_axis, from_words, to_words):
    """Replace one semantic axis and re-project all example words.

    Args:
        axis_name: label for the new axis, or "residual" to make the
            selected axis the residual of the others.
        which_axis: key into ``axisMap`` selecting which axis to replace.
        from_words: comma/space-separated words for one pole of the axis.
        to_words: comma/space-separated words for the opposite pole.
    """
    global coords, examples, fig, axis_names
    if axis_name != "residual":
        from_words, to_words = (
            from_words.replace(",", " ").split(),
            to_words.replace(",", " ").split(),
        )
        axis_emb = get_axis_embeddings(from_words, to_words)
        axis[axisMap[which_axis]] = axis_emb
        axis_names[axisMap[which_axis]] = axis_name

        # Any axis currently marked "residual" is defined relative to the
        # other axes, so it must be recomputed after this change.
        for i, name in enumerate(axis_names):
            if name == "residual":
                axis[i] = calculate_residual(axis, axis_names, from_words, to_words, i)
                axis_names[i] = "residual"
    else:
        residual = calculate_residual(
            axis, axis_names, residual_axis=axisMap[which_axis]
        )
        axis[axisMap[which_axis]] = residual
        axis_names[axisMap[which_axis]] = axis_name

    # Re-project every example onto the updated axes; the y coordinate is
    # flipped/scaled the same way as at module load.
    coords = get_concat_embeddings(examples) @ axis.T
    coords[:, 1] = 5 * (1.0 - coords[:, 1])

    fig.update_layout(
        scene=dict(
            xaxis_title=axis_names[0],
            yaxis_title=axis_names[1],
            zaxis_title=axis_names[2],
        )
    )
    return update_fig()
def change_word(examples):
    """Regenerate the image for each listed word by removing and re-adding it.

    Args:
        examples: comma/space-separated words. (Note: this parameter shadows
            the module-level ``examples`` list; ``remove_word``/``add_word``
            still mutate the global state.)

    Returns:
        str: refreshed iframe HTML from ``update_fig``.
    """
    words = examples.replace(",", " ").split()
    for word in words:
        # Robustness fix: the original crashed with an uncaught KeyError
        # (from remove_word) when a word was not currently plotted; warn
        # and skip instead, matching add_rem_word's messaging.
        try:
            remove_word(word)
        except KeyError:
            gr.Warning(f"{word} not found in embeddings")
            continue
        add_word(word)
        gr.Info("Changed image for {}".format(word))
    return update_fig()
def clear_words():
    """Remove every plotted word (and its image/coords), last-first."""
    # Iterate a reversed snapshot so mutation of the module-level list by
    # remove_word cannot disturb the traversal; removal order (tail first)
    # matches the original while-loop.
    for word in reversed(list(examples)):
        remove_word(word)
    return update_fig()
def generate_word_emb_vis(prompt):
    """Render a 1×768 heat-map strip of *prompt*'s CLIP word embedding.

    Args:
        prompt: text whose first-token embedding is visualised.

    Returns:
        str: a base64 PNG data URI of the rendered strip.
    """
    buf = BytesIO()
    # Row 1 is the first real token (row 0 is BOS) — assumes the embedding
    # reshapes to the CLIP (77, 768) layout; TODO confirm with get_word_embeddings.
    emb = get_word_embeddings(prompt).reshape(77, 768)[1]
    # Fix: pass the format explicitly. The original wrote matplotlib's
    # default (PNG) into the buffer while labelling the data URI image/jpeg;
    # format and MIME type now agree.
    plt.imsave(buf, [emb], cmap="inferno", format="png")
    img = "data:image/png;base64, " + base64.b64encode(buf.getvalue()).decode("utf-8")
    return img
# Initial 3-D scatter of the example words in semantic-axis space.
fig = px.scatter_3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    labels={
        "x": axis_names[0],
        "y": axis_names[1],
        "z": axis_names[2],
    },
    text=examples,
    height=750,
)
# Zero margins and an oblique camera so the scene fills the panel.
fig.update_layout(
    margin=dict(l=0, r=0, b=0, t=0), scene_camera=dict(eye=dict(x=2, y=2, z=0.1))
)
# Suppress plotly's built-in hover; presumably the Dash app supplies its
# own image tooltips — confirm against the Dash callback code.
fig.update_traces(hoverinfo="none", hovertemplate=None)
# Explicit public API of this module (figure, plot state, and the Gradio
# callback functions defined above).
__all__ = [
    "fig",
    "update_fig",
    "coords",
    "images",
    "examples",
    "add_word",
    "remove_word",
    "add_rem_word",
    "change_word",
    "clear_words",
    "generate_word_emb_vis",
    "set_axis",
    "axis",
]