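"""Diffusion Faces Cluster Explorer.

Gradio demo for exploring clusters of identity images generated by
Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2 (companion to the
DiffusionBiasExplorer space). It loads precomputed clusterings of the images
at three granularities (12, 24, or 48 clusters) and, for a selected cluster,
shows representative images along with its system, gender-term, and
ethnicity-term makeup.
"""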
import json
import operator
import os

import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from PIL import Image
TITLE = "Diffusion Faces Cluster Explorer"

# Precomputed cluster assignments at three granularities.
with open("clusters/id_all_blip_clusters_12.json") as f:
    clusters_12 = json.load(f)
with open("clusters/id_all_blip_clusters_24.json") as f:
    clusters_24 = json.load(f)
with open("clusters/id_all_blip_clusters_48.json") as f:
    clusters_48 = json.load(f)

clusters_by_size = {
    12: clusters_12,
    24: clusters_24,
    48: clusters_48,
}
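
# Each cluster record is consumed below as: an "img_path_list" of generated
# image paths, plus "labels_model", "labels_gender", and "labels_ethnicity"
# entries that `dict(...)` turns into label -> count mappings.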


def to_string(label):
    """Map short metadata labels to display-friendly names."""
    if label == "SD_2":
        label = "Stable Diffusion 2.0"
    elif label == "SD_14":
        label = "Stable Diffusion 1.4"
    elif label == "DallE":
        label = "Dall-E 2"
    elif label == "non-binary":
        label = "non-binary person"
    elif label == "person":
        label = "<i>unmarked</i> (person)"
    elif label == "":
        label = "<i>unmarked</i> ()"
    elif label == "gender":
        label = "gender term"
    return label


def describe_cluster(cl_dict, block="label", max_items=4):
    """Summarize a label -> count mapping as an HTML description string."""
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1), reverse=True)
    total = float(sum(cl_dict.values()))
    lv_prcnt = [(label, round(value * 100 / total, 0)) for label, value in labels_values]
    top_label = lv_prcnt[0][0]
    description_string = (
        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
        % (to_string(block), to_string(top_label), lv_prcnt[0][1])
    )
    description_string += "<p>This is followed by: "
    for lv in lv_prcnt[1 : min(len(lv_prcnt), 1 + max_items)]:
        description_string += "<BR/><b>%s:</b> %d%%" % (to_string(lv[0]), lv[1])
    if len(lv_prcnt) > max_items + 1:
        description_string += "<BR/><b> - Other terms:</b> %d%%" % (
            sum(lv[1] for lv in lv_prcnt[max_items + 1 :]),
        )
    description_string += "</p>"
    return description_string
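
# For example, with hypothetical counts, describe_cluster({"man": 30, "woman": 20}, "gender")
# reports "man" as the most represented gender term at about 60% of the cluster,
# followed by "woman" at 40%.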


def show_cluster(cl_id, num_clusters):
    """Collect the images, plots, and descriptions for one cluster."""
    if not cl_id:
        cl_id = 0
    if not num_clusters:
        num_clusters = 12
    cl_dct = clusters_by_size[num_clusters][cl_id]
    # Take the first six images in the cluster, turning each stored URL-style
    # path into a local file path and a short human-readable caption.
    images = []
    for i in range(6):
        img_path = "/".join(
            [st.replace("/", "") for st in cl_dct["img_path_list"][i].split("//")][3:]
        )
        images.append(
            (
                Image.open(os.path.join("identities-images", img_path)),
                "_".join([img_path.split("/")[0], img_path.split("/")[-1]])
                .replace("Photo_portrait_of_an_", "")
                .replace("Photo_portrait_of_a_", "")
                .replace("SD_v2_random_seeds_identity_", "(SD v.2) ")
                .replace("dataset-identities-dalle2_", "(Dall-E 2) ")
                .replace("SD_v1.4_random_seeds_identity_", "(SD v.1.4) ")
                .replace("_", " "),
            )
        )
    # Pie chart: how many of the cluster's images come from each TTI system.
    model_fig = go.Figure()
    model_fig.add_trace(
        go.Pie(
            labels=list(dict(cl_dct["labels_model"]).keys()),
            values=list(dict(cl_dct["labels_model"]).values()),
        )
    )
    model_description = describe_cluster(dict(cl_dct["labels_model"]), "system")
    # Pie chart: which gender terms the cluster's prompts used.
    gender_fig = go.Figure()
    gender_fig.add_trace(
        go.Pie(
            labels=list(dict(cl_dct["labels_gender"]).keys()),
            values=list(dict(cl_dct["labels_gender"]).values()),
        )
    )
    gender_description = describe_cluster(dict(cl_dct["labels_gender"]), "gender")
    # Bar chart: which ethnicity terms the cluster's prompts used.
    ethnicity_fig = go.Figure()
    ethnicity_fig.add_trace(
        go.Bar(
            x=list(dict(cl_dct["labels_ethnicity"]).keys()),
            y=list(dict(cl_dct["labels_ethnicity"]).values()),
            marker_color=px.colors.qualitative.G10,
        )
    )
    ethnicity_description = describe_cluster(
        dict(cl_dct["labels_ethnicity"]), "ethnicity"
    )
    # The last output refreshes the dropdown so its choices match num_clusters.
    return (
        len(cl_dct["img_path_list"]),
        gender_fig,
        gender_description,
        model_fig,
        model_description,
        ethnicity_fig,
        ethnicity_description,
        images,
        gr.update(choices=list(range(num_clusters))),
    )
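
# UI: a cluster-count radio and a cluster dropdown drive the image gallery,
# the three makeup plots, and their HTML descriptions.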
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(
        "## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!"
    )
    gr.Markdown(
        "### This demo showcases patterns in the images generated from different prompts input to the Stable Diffusion and Dall-E 2 systems."
    )
    gr.Markdown(
        "### Below, see how the images generated from different prompts cluster together."
    )
    gr.HTML(
        """<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>"""
    )
    num_clusters = gr.Radio(
        [12, 24, 48],
        value=12,
        label="How many clusters do you want to make from the data?",
    )
    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(label="Most representative images in cluster").style(
                grid=(3, 3)
            )
        with gr.Column():
            cluster_id = gr.Dropdown(
                choices=list(range(num_clusters.value)),
                value=0,
                label="Select cluster to visualize:",
            )
            a = gr.Text(label="Number of images")
    with gr.Row():
        with gr.Column(scale=1):
            c = gr.Plot(label="How many images from each system?")
            c_desc = gr.HTML(label="")
        with gr.Column(scale=1):
            b = gr.Plot(label="How many gender terms are represented?")
            b_desc = gr.HTML(label="")
        with gr.Column(scale=2):
            d = gr.Plot(label="Which ethnicity terms are present?")
            d_desc = gr.HTML(label="")
    gr.Markdown(
        "The 'System makeup' plot shows how many of the cluster's images come from each of the TTI systems we compare: Dall-E 2, Stable Diffusion v.1.4, and Stable Diffusion v.2."
    )
    gr.Markdown(
        'The gender term plot shows the number of images whose input prompts used the words man, woman, or non-binary person, or left gender unmarked, which we label "person".'
    )
    gr.Markdown(
        "The 'Ethnicity label makeup' plot shows the number of images for each of the 18 ethnicity terms used in the prompts. A blank value means unmarked ethnicity."
    )
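
    # The three event hooks below all call show_cluster: once on page load,
    # and again whenever the cluster count or the selected cluster changes.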
    demo.load(
        fn=show_cluster,
        inputs=[cluster_id, num_clusters],
        outputs=[a, b, b_desc, c, c_desc, d, d_desc, gallery, cluster_id],
    )
    num_clusters.change(
        fn=show_cluster,
        inputs=[cluster_id, num_clusters],
        outputs=[a, b, b_desc, c, c_desc, d, d_desc, gallery, cluster_id],
    )
    cluster_id.change(
        fn=show_cluster,
        inputs=[cluster_id, num_clusters],
        outputs=[a, b, b_desc, c, c_desc, d, d_desc, gallery, cluster_id],
    )
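
# `queue()` enables request queuing for concurrent visitors; `debug=True`
# surfaces tracebacks in the console while the demo runs.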
if __name__ == "__main__":
    demo.queue().launch(debug=True)