Spaces:
Runtime error
Runtime error
File size: 7,210 Bytes
5e01e13 ea06e5a 5e01e13 313342d 7adeb33 ea06e5a 602f686 5e01e13 ea06e5a 5e01e13 7df4a81 cbc358c 7df4a81 cbc358c 7df4a81 cbc358c 7df4a81 0d44baa 7df4a81 cbc358c 7df4a81 0d44baa 7adeb33 7df4a81 0d44baa cbc358c 0d44baa cbc358c 0d44baa cbc358c 0d44baa cbc358c 7adeb33 7df4a81 5e01e13 be4f1b2 5e01e13 0d44baa 313342d 0d44baa 3911108 7adeb33 313342d 7df4a81 0d44baa 7adeb33 313342d 7df4a81 0d44baa eda017b 0d44baa 7df4a81 5e01e13 602f686 7df4a81 0d44baa 7df4a81 0d44baa 7df4a81 0d44baa 7df4a81 0d44baa 602f686 be4f1b2 602f686 0d44baa be4f1b2 0d44baa 602f686 313342d 7df4a81 3911108 7df4a81 0d44baa c7632f4 7df4a81 0d44baa 7df4a81 0d44baa 7df4a81 0d44baa eda017b 0d44baa eda017b 0d44baa eda017b 0d44baa 602f686 0d44baa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import json
import gradio as gr
import os
from PIL import Image
import plotly.graph_objects as go
import plotly.express as px
import operator
TITLE = "Diffusion Faces Cluster Explorer"


def _load_clusters(num_clusters):
    """Load ``clusters/id_all_blip_clusters_<num_clusters>.json``.

    The file handle is closed as soon as the JSON has been parsed, instead of
    being left open for the lifetime of the process.
    """
    path = "clusters/id_all_blip_clusters_%d.json" % num_clusters
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Precomputed clusterings, one per cluster count offered in the UI radio.
clusters_12 = _load_clusters(12)
clusters_24 = _load_clusters(24)
clusters_48 = _load_clusters(48)
# Cluster count -> clustering; indexed by ``show_cluster``.
clusters_by_size = {
    12: clusters_12,
    24: clusters_24,
    48: clusters_48,
}
def to_string(label):
    """Translate a raw dataset label into the text displayed in the UI.

    System names, a few gender terms and the empty/"person" placeholders are
    replaced by a human-readable (sometimes HTML-formatted) name. Any label
    not in the table is returned unchanged.
    """
    display_names = (
        ("SD_2", "Stable Diffusion 2.0"),
        ("SD_14", "Stable Diffusion 1.4"),
        ("DallE", "Dall-E 2"),
        ("non-binary", "non-binary person"),
        ("person", "<i>unmarked</i> (person)"),
        ("", "<i>unmarked</i> ()"),
        ("gender", "gender term"),
    )
    # First matching entry wins, exactly like the original if/elif chain.
    for raw, pretty in display_names:
        if label == raw:
            return pretty
    return label
def describe_cluster(cl_dict, block="label", max_items=4):
    """Summarise a cluster's label distribution as an HTML snippet.

    Args:
        cl_dict: mapping from label to its count within the cluster.
        block: name of the label family being described (e.g. "system",
            "gender", "ethnicity"); rendered through ``to_string``.
        max_items: number of labels, after the most represented one, that are
            listed individually; any further labels are folded into a single
            "Other terms" line.

    Returns:
        An HTML string naming the most represented label and its share of the
        cluster. When there is more than one label, this is followed by the
        next ``max_items`` labels and, if any remain, their combined share.
        Percentages are rounded to whole numbers. An empty (or all-zero)
        distribution yields a short "no information" message.
    """
    total = float(sum(cl_dict.values()))
    if not cl_dict or total <= 0:
        # Without this guard an empty distribution raises IndexError (no top
        # label) and an all-zero one raises ZeroDivisionError.
        return "<span>No %s information is available for this cluster.</span>" % (
            to_string(block),
        )

    # Most represented first. Sorting ascending and then reversing leaves
    # ties in reverse insertion order, as before.
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
    labels_values.reverse()

    def pct(count):
        return round(count * 100 / total, 0)

    top_label, top_count = labels_values[0]
    description_string = (
        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
        % (to_string(block), to_string(top_label), pct(top_count))
    )
    if len(labels_values) > 1:
        description_string += "<p>This is followed by: "
        for label, count in labels_values[1 : 1 + max_items]:
            description_string += "<BR/><b>%s:</b> %d%%" % (to_string(label), pct(count))
        remaining = labels_values[max_items + 1 :]
        if remaining:
            # Round the aggregated raw count, not a sum of rounded shares.
            # Otherwise many small labels can each round to 0% and vanish.
            description_string += "<BR/><b> - Other terms:</b> %d%%" % (
                pct(sum(count for _, count in remaining)),
            )
        description_string += "</p>"
    return description_string
_GALLERY_SIZE = 6  # at most this many representative images are shown per cluster


def _relative_image_path(raw_path):
    """Convert an ``img_path_list`` entry into a path under ``identities-images``.

    The entry is split on "//", every remaining "/" is stripped from each
    piece, the first three pieces are dropped and the rest re-joined with "/".
    NOTE(review): presumably the dropped pieces are a URL scheme/host prefix.
    """
    pieces = [piece.replace("/", "") for piece in raw_path.split("//")]
    return "/".join(pieces[3:])


def _image_caption(img_path):
    """Build a short gallery caption from a path relative to ``identities-images``.

    Joins the first and last path components with "_", strips the prompt
    prefix, rewrites the dataset prefixes into short system tags, and turns
    underscores into spaces.
    """
    parts = img_path.split("/")
    return (
        "_".join([parts[0], parts[-1]])
        .replace("Photo_portrait_of_an_", "")
        .replace("Photo_portrait_of_a_", "")
        .replace("SD_v2_random_seeds_identity_", "(SD v.2) ")
        .replace("dataset-identities-dalle2_", "(Dall-E 2) ")
        .replace("SD_v1.4_random_seeds_identity_", "(SD v.1.4) ")
        .replace("_", " ")
    )


def _pie_figure(counts):
    """Return a Plotly figure holding one pie chart of a label -> count mapping."""
    fig = go.Figure()
    fig.add_trace(go.Pie(labels=list(counts.keys()), values=list(counts.values())))
    return fig


def show_cluster(cl_id, num_clusters):
    """Compute every output of the explorer for one cluster.

    Args:
        cl_id: index of the cluster to display; a falsy value means 0. If the
            index does not exist in the selected clustering (e.g. cluster 30
            was selected and the user then switched from 48 to 12 clusters),
            cluster 0 is shown instead of raising IndexError.
        num_clusters: which clustering to use (a key of ``clusters_by_size``);
            a falsy value means 12.

    Returns:
        A 9-tuple in the order of the event ``outputs``:
        - the number of images in the cluster;
        - the gender pie and its description;
        - the system pie and its description;
        - the ethnicity bar chart and its description;
        - the gallery ``(image, caption)`` list;
        - a Dropdown update limiting its choices to ``range(num_clusters)``.
    """
    num_clusters = int(num_clusters) if num_clusters else 12
    cl_id = int(cl_id) if cl_id else 0
    clusters = clusters_by_size[num_clusters]
    # The dropdown keeps its old value when the cluster count shrinks.
    cl_id_reset = not 0 <= cl_id < len(clusters)
    if cl_id_reset:
        cl_id = 0
    cl_dct = clusters[cl_id]

    # Slice rather than index so clusters with fewer images do not raise.
    images = []
    for raw_path in cl_dct["img_path_list"][:_GALLERY_SIZE]:
        img_path = _relative_image_path(raw_path)
        images.append(
            (
                Image.open(os.path.join("identities-images", img_path)),
                _image_caption(img_path),
            )
        )

    model_counts = dict(cl_dct["labels_model"])
    gender_counts = dict(cl_dct["labels_gender"])
    ethnicity_counts = dict(cl_dct["labels_ethnicity"])

    model_fig = _pie_figure(model_counts)
    gender_fig = _pie_figure(gender_counts)
    ethnicity_fig = go.Figure()
    ethnicity_fig.add_trace(
        go.Bar(
            x=list(ethnicity_counts.keys()),
            y=list(ethnicity_counts.values()),
            marker_color=px.colors.qualitative.G10,
        )
    )

    choices = list(range(num_clusters))
    if cl_id_reset:
        # Move the dropdown to the cluster that is actually being shown.
        dropdown_update = gr.update(choices=choices, value=cl_id)
    else:
        dropdown_update = gr.update(choices=choices)

    return (
        len(cl_dct["img_path_list"]),
        gender_fig,
        describe_cluster(gender_counts, "gender"),
        model_fig,
        describe_cluster(model_counts, "system"),
        ethnicity_fig,
        describe_cluster(ethnicity_counts, "ethnicity"),
        images,
        dropdown_update,
    )
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(
        "## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!"
    )
    gr.Markdown(
        "### This demo showcases patterns in the images generated from different prompts input to Stable Diffusion and Dalle-2 systems."
    )
    gr.Markdown(
        "### Below, see results on how the images from different prompts cluster together."
    )
    # Both CSS declarations must live inside the ``style`` attribute; the
    # original placed ``font-size:smaller`` outside it, so it was ignored.
    gr.HTML(
        """<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>"""
    )
    num_clusters = gr.Radio(
        [12, 24, 48],
        value=12,
        label="How many clusters do you want to make from the data?",
    )
    with gr.Row():
        with gr.Column(scale=4):
            # NOTE(review): ``Component.style()`` is deprecated in late Gradio 3.x
            # and removed in Gradio 4; there, use ``gr.Gallery(..., columns=3)``.
            gallery = gr.Gallery(label="Most representative images in cluster").style(
                grid=(3, 3)
            )
        with gr.Column():
            cluster_id = gr.Dropdown(
                choices=list(range(num_clusters.value)),
                value=0,
                label="Select cluster to visualize:",
            )
            a = gr.Text(label="Number of images")
    with gr.Row():
        with gr.Column(scale=1):
            c = gr.Plot(label="How many images from each system?")
            c_desc = gr.HTML(label="")
        with gr.Column(scale=1):
            b = gr.Plot(label="How many gender terms are represented?")
            b_desc = gr.HTML(label="")
        with gr.Column(scale=2):
            d = gr.Plot(label="Which ethnicity terms are present?")
            d_desc = gr.HTML(label="")
    gr.Markdown(
        "The 'System makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2."
    )
    gr.Markdown(
        'The Gender plot shows the number of images based on the input prompts that used the words man, woman, non-binary person, and unmarked, which we label "person".'
    )
    gr.Markdown(
        "The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity."
    )
    # Page load, a new cluster count and a new cluster id all recompute the
    # whole view. This order must match the tuple returned by ``show_cluster``.
    cluster_outputs = [a, b, b_desc, c, c_desc, d, d_desc, gallery, cluster_id]
    for trigger in (demo.load, num_clusters.change, cluster_id.change):
        trigger(
            fn=show_cluster,
            inputs=[cluster_id, num_clusters],
            outputs=cluster_outputs,
        )
if __name__ == "__main__":
    # Turn on Gradio's request queue, then launch the app with debug=True.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)
|