File size: 6,607 Bytes
5e01e13
ea06e5a
5e01e13
 
313342d
 
7adeb33
ea06e5a
602f686
5e01e13
 
 
ea06e5a
5e01e13
 
 
 
 
 
bb54091
cbc358c
 
bb54091
cbc358c
bb54091
cbc358c
 
bb54091
 
 
 
 
 
cbc358c
 
bb54091
7adeb33
 
 
 
bb54091
 
cbc358c
bb54091
 
cbc358c
 
 
 
7adeb33
 
bb54091
5e01e13
be4f1b2
 
 
 
5e01e13
 
 
bb54091
 
 
 
 
 
 
 
 
 
 
313342d
bb54091
 
 
7adeb33
 
313342d
bb54091
 
 
 
 
7adeb33
313342d
bb54091
 
 
 
5e01e13
bb54091
7adeb33
313342d
931f769
bb54091
 
5e01e13
602f686
 
bb54091
 
 
 
 
 
 
 
 
 
602f686
be4f1b2
602f686
bb54091
 
 
be4f1b2
bb54091
 
 
602f686
313342d
bb54091
 
 
 
 
 
 
 
c7632f4
bb54091
 
 
 
 
 
 
 
 
 
 
 
 
602f686
 
bb54091
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import json
import gradio as gr
import os
from PIL import Image
import plotly.graph_objects as go
import plotly.express as px
import operator

TITLE = "Diffusion Faces Cluster Explorer"


def _load_clusters(num_clusters):
    """Load the pre-computed BLIP cluster assignments for one granularity.

    Uses a context manager so the JSON file handle is closed promptly
    (the previous bare ``open`` call leaked the handle).
    """
    with open(f"clusters/id_all_blip_clusters_{num_clusters}.json") as f:
        return json.load(f)


# Individual names kept for backward compatibility with any external users.
clusters_12 = _load_clusters(12)
clusters_24 = _load_clusters(24)
clusters_48 = _load_clusters(48)

# Granularity -> cluster data, as selected by the UI radio button below.
clusters_by_size = {
    12: clusters_12,
    24: clusters_24,
    48: clusters_48,
}


def to_string(label):
    """Return the human-readable display string for a shorthand label.

    Labels without a known mapping are returned unchanged.
    """
    display_names = {
        "SD_2": "Stable Diffusion 2.0",
        "SD_14": "Stable Diffusion 1.4",
        "DallE": "Dall-E 2",
        "non-binary": "non-binary person",
        "person": "<i>unmarked</i> (person)",
        "gender": "gender term",
    }
    return display_names.get(label, label)


def describe_cluster(cl_dict, block="label"):
    """Render an HTML summary of how often each label occurs in a cluster.

    ``cl_dict`` maps label -> count; ``block`` names the label category
    (e.g. "model" or "gender") used in the lead sentence.
    """
    total = float(sum(cl_dict.values()))
    # Descending by count; reversing a stable ascending sort keeps the exact
    # tie ordering of the original sort-then-reverse approach.
    ranked = sorted(cl_dict.items(), key=operator.itemgetter(1))[::-1]
    shares = [(name, round(count * 100 / total, 0)) for name, count in ranked]
    top_name, top_share = shares[0]
    parts = [
        "<span>The most represented %s is <b>%s</b>, making up about "
        "<b>%d%%</b> of the cluster.</span>" % (
            to_string(block), to_string(top_name), top_share),
        "<p>This is followed by: ",
    ]
    for name, share in shares[1:]:
        parts.append("<BR/><b>%s:</b> %d%%" % (to_string(name), share))
    parts.append("</p>")
    return "".join(parts)


def _caption_for(img_path):
    """Build a short display caption for one cluster image.

    Combines the prompt folder name with the file name, strips the
    boilerplate prompt prefixes, tags the generating model, and turns the
    remaining underscores into spaces.
    """
    caption = "_".join([img_path.split("/")[0], img_path.split("/")[-1]])
    # Order matters: prompt prefixes are removed before underscores become
    # spaces, mirroring the original chained .replace() calls exactly.
    for old, new in (
            ('Photo_portrait_of_an_', ''),
            ('Photo_portrait_of_a_', ''),
            ('SD_v2_random_seeds_identity_', '(SD v.2) '),
            ('dataset-identities-dalle2_', '(Dall-E 2) '),
            ('SD_v1.4_random_seeds_identity_', '(SD v.1.4) '),
    ):
        caption = caption.replace(old, new)
    return caption.replace('_', ' ')


def show_cluster(cl_id, num_clusters):
    """Collect everything the UI needs to display one cluster.

    Parameters
    ----------
    cl_id : int or falsy
        Index of the cluster to show; falsy values fall back to 0.
    num_clusters : int or falsy
        Clustering granularity (12, 24 or 48); falsy values fall back to 12.

    Returns
    -------
    tuple
        (image count, gender pie figure, gender description HTML,
        model pie figure, model description HTML, ethnicity bar figure,
        captioned images, slider-range update) — ordered to match the
        Gradio output component lists wired up below.
    """
    if not cl_id:
        cl_id = 0
    if not num_clusters:
        num_clusters = 12
    cl_dct = clusters_by_size[num_clusters][cl_id]
    images = []
    # Show up to six representative images; slicing (rather than a fixed
    # range(6)) avoids an IndexError on clusters with fewer images.
    for raw_path in cl_dct['img_path_list'][:6]:
        # Normalize the stored path: drop the first three '//'-separated
        # segments and strip stray slashes from the remainder.
        img_path = "/".join(
            [st.replace("/", "") for st in raw_path.split("//")][3:])
        images.append((Image.open(os.path.join("identities-images", img_path)),
                       _caption_for(img_path)))

    model_fig = go.Figure()
    model_fig.add_trace(go.Pie(labels=list(dict(cl_dct["labels_model"]).keys()),
                               values=list(
                                   dict(cl_dct["labels_model"]).values())))
    model_description = describe_cluster(dict(cl_dct["labels_model"]), "model")

    gender_fig = go.Figure()
    gender_fig.add_trace(
        go.Pie(labels=list(dict(cl_dct["labels_gender"]).keys()),
               values=list(dict(cl_dct["labels_gender"]).values())))
    gender_description = describe_cluster(dict(cl_dct["labels_gender"]),
                                          "gender")

    ethnicity_fig = go.Figure()
    ethnicity_fig.add_trace(
        go.Bar(x=list(dict(cl_dct["labels_ethnicity"]).keys()),
               y=list(dict(cl_dct["labels_ethnicity"]).values()),
               marker_color=px.colors.qualitative.G10))
    # gr.update keeps the cluster slider's range valid for the selected
    # granularity (num_clusters clusters -> max index num_clusters - 1).
    return (len(cl_dct['img_path_list']),
            gender_fig, gender_description,
            model_fig, model_description,
            ethnicity_fig,
            images,
            gr.update(maximum=num_clusters - 1))


# ---- Gradio UI ----------------------------------------------------------
# Declarative layout; all result panels are refreshed by show_cluster().
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(
        "## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!")
    gr.Markdown(
        "### This demo showcases patterns in the images generated from different prompts input to Stable Diffusion and Dalle-2 diffusion models.")
    gr.Markdown(
        "### Below, see results on how the images from different prompts cluster together.")
    gr.HTML(
        """<span style="color:red" font-size:smaller>⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image models and may depict offensive stereotypes or contain explicit content.</span>""")
    # Clustering granularity selector; choices must match clusters_by_size keys.
    num_clusters = gr.Radio([12, 24, 48], value=12,
                            label="How many clusters do you want to make from the data?")

    with gr.Row():
        with gr.Column(scale=4):
            # Representative images for the selected cluster.
            gallery = gr.Gallery(
                label="Most representative images in cluster").style(
                grid=(3, 3))
        with gr.Column():
            # The slider's maximum is updated by show_cluster() whenever the
            # granularity changes (see the gr.update in its return value).
            cluster_id = gr.Slider(minimum=0, maximum=num_clusters.value - 1,
                                   step=1, value=0,
                                   label="Click to move between clusters")
            a = gr.Text(label="Number of images")
    # Per-cluster breakdown plots: model makeup, gender terms, ethnicity terms.
    with gr.Row():
        with gr.Column(scale=1):
            c = gr.Plot(label="How many images from each model?")
            c_desc = gr.HTML(label="")
        with gr.Column(scale=1):
            b = gr.Plot(label="How many gender terms are represented?")
            b_desc = gr.HTML(label="")
        with gr.Column(scale=2):
            d = gr.Plot(label="Which ethnicity terms are present?")

    gr.Markdown(
        f"The 'Model makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2.")
    gr.Markdown(
        'The Gender plot shows the number of images based on the input prompts that used the words man, woman, non-binary person, and unmarked, which we label "person".')
    gr.Markdown(
        f"The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity.")
    # Refresh every output panel on initial page load and whenever either
    # control changes; all three events share the same handler and wiring.
    demo.load(fn=show_cluster, inputs=[cluster_id, num_clusters],
              outputs=[a, b, b_desc, c, c_desc, d, gallery, cluster_id])
    num_clusters.change(fn=show_cluster, inputs=[cluster_id, num_clusters],
                        outputs=[a, b, b_desc, c, c_desc, d, gallery,
                                 cluster_id])
    cluster_id.change(fn=show_cluster, inputs=[cluster_id, num_clusters],
                      outputs=[a, b, b_desc, c, c_desc, d, gallery, cluster_id])

if __name__ == "__main__":
    # queue() serializes requests; debug=True surfaces errors in the console.
    demo.queue().launch(debug=True)