File size: 5,726 Bytes
5e01e13
ea06e5a
5e01e13
 
313342d
 
7adeb33
ea06e5a
602f686
5e01e13
 
 
ea06e5a
5e01e13
 
 
 
 
 
cbc358c
 
 
 
 
 
 
 
 
7adeb33
 
 
 
 
cbc358c
 
 
 
 
 
7adeb33
 
5e01e13
be4f1b2
 
 
 
5e01e13
 
 
 
313342d
 
c6ba6c5
 
7adeb33
 
313342d
c6ba6c5
 
7adeb33
 
313342d
 
 
 
5e01e13
7adeb33
 
313342d
5e01e13
 
602f686
 
cbc358c
 
 
602f686
cbc358c
602f686
7adeb33
be4f1b2
602f686
 
be4f1b2
602f686
 
313342d
c6ba6c5
9e9abb9
cbc358c
c6ba6c5
9e9abb9
cbc358c
c6ba6c5
9e9abb9
c7632f4
939debd
fb1b078
 
7adeb33
 
 
602f686
 
c6ba6c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import json
import gradio as gr
import os
from PIL import Image
import plotly.graph_objects as go
import plotly.express as px
import operator

TITLE = "Diffusion Faces Cluster Explorer"


def _load_clusters(path):
    """Parse the JSON clustering stored at *path* and return the decoded object.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes, instead of lingering until garbage collection.
    """
    with open(path, encoding="utf-8") as cluster_file:
        return json.load(cluster_file)


# One pre-computed clustering per supported cluster count.
clusters_12 = _load_clusters("clusters/id_all_blip_clusters_12.json")
clusters_24 = _load_clusters("clusters/id_all_blip_clusters_24.json")
clusters_48 = _load_clusters("clusters/id_all_blip_clusters_48.json")

# Lookup from the number of clusters to the matching clustering.
clusters_by_size = {
    12: clusters_12,
    24: clusters_24,
    48: clusters_48,
}

def to_string(label):
    """Return the human-readable display name for *label*.

    Known model identifiers are expanded to their full names; any other
    label is returned unchanged.

    Args:
        label: A label string, e.g. a model identifier such as ``"SD_14"``.

    Returns:
        The display name for known model identifiers, otherwise *label*.
    """
    display_names = {
        "SD_2": "Stable Diffusion 2",
        # Version 1.4 of the model, not "14".
        "SD_14": "Stable Diffusion 1.4",
        "DallE": "Dall-E 2",
    }
    return display_names.get(label, label)

def describe_cluster(cl_dict, block="label"):
    """Build an HTML summary of how labels are distributed within a cluster.

    Args:
        cl_dict: Mapping from label to its numeric count in the cluster.
        block: Name of the kind of label being described (e.g. "model"),
            inserted into the generated sentence.

    Returns:
        An HTML string naming the most frequent label with its percentage
        share, followed by the remaining labels in decreasing order of
        count. If the mapping is empty or its counts sum to zero, a short
        placeholder sentence is returned instead.
    """
    total = float(sum(cl_dict.values()))
    if not total:
        # Nothing to rank, and percentages would divide by zero.
        return "<span>No %s information is available for this cluster.</span>" % block

    # Ascending stable sort followed by a reversal, so that labels with equal
    # counts keep the same relative order this function has always produced.
    ranked = sorted(cl_dict.items(), key=operator.itemgetter(1))
    ranked.reverse()
    shares = [(label, round(count * 100 / total, 0)) for label, count in ranked]

    top_label, top_share = shares[0]
    parts = [
        "<span>The most represented %s is <b>%s</b>, making up about %d%% of the cluster.</span>"
        % (block, to_string(top_label), top_share),
        "<p>This is followed by: ",
    ]
    parts.extend(
        "<BR/><b>%s:</b> %d%%" % (to_string(label), share)
        for label, share in shares[1:]
    )
    parts.append("</p>")
    return "".join(parts)

# Upper bound on the number of example images shown for a cluster.
_NUM_PREVIEW_IMAGES = 6

# Ordered (old, new) substitutions that turn a path-derived name into a
# gallery caption. Order matters: the final entry replaces every remaining
# underscore, so the prefix rewrites must run first.
_CAPTION_REPLACEMENTS = (
    ('Photo_portrait_of_an_', ''),
    ('Photo_portrait_of_a_', ''),
    ('SD_v2_random_seeds_identity_', '(SD v.2) '),
    ('dataset-identities-dalle2_', '(Dall-E 2) '),
    ('SD_v1.4_random_seeds_identity_', '(SD v.1.4) '),
    ('_', ' '),
)


def _relative_image_path(raw_path):
    """Convert a stored image path into a path relative to the image folder.

    The stored path is split on "//", stray "/" characters are removed from
    each piece, the first three pieces are dropped, and the rest are joined
    with "/".
    """
    pieces = [piece.replace("/", "") for piece in raw_path.split("//")]
    return "/".join(pieces[3:])


def _image_caption(img_path):
    """Build a gallery caption from the first and last components of *img_path*."""
    components = img_path.split("/")
    caption = "_".join([components[0], components[-1]])
    for old, new in _CAPTION_REPLACEMENTS:
        caption = caption.replace(old, new)
    return caption


def show_cluster(cl_id, num_clusters):
    """Collect everything the UI displays for one cluster.

    Args:
        cl_id: Index of the cluster to show; a falsy value means cluster 0.
            Values beyond the last cluster are clamped to the last cluster.
        num_clusters: Which clustering to read from ``clusters_by_size``;
            a falsy value means 12.

    Returns:
        A 7-tuple ``(image_count, gender_fig, gender_description, model_fig,
        model_description, ethnicity_fig, images)`` where ``images`` is a
        list of ``(PIL image, caption)`` pairs for at most
        ``_NUM_PREVIEW_IMAGES`` images of the cluster.
    """
    if not cl_id:
        cl_id = 0
    if not num_clusters:
        num_clusters = 12
    clusters = clusters_by_size[num_clusters]
    # The slider may deliver a float or an index left over from a larger
    # clustering; coerce and clamp so the lookup cannot fail.
    # NOTE(review): assumes each clustering is a list indexed by position.
    cl_id = min(int(cl_id), len(clusters) - 1)
    cl_dct = clusters[cl_id]

    # Slice instead of indexing a fixed range so clusters with fewer than
    # _NUM_PREVIEW_IMAGES images do not raise IndexError.
    images = []
    for raw_path in cl_dct['img_path_list'][:_NUM_PREVIEW_IMAGES]:
        img_path = _relative_image_path(raw_path)
        images.append((Image.open(os.path.join("identities-images", img_path)),
                       _image_caption(img_path)))

    model_counts = dict(cl_dct["labels_model"])
    model_fig = go.Figure()
    model_fig.add_trace(go.Pie(labels=list(model_counts.keys()),
                               values=list(model_counts.values())))
    model_description = describe_cluster(model_counts, "model")

    gender_counts = dict(cl_dct["labels_gender"])
    gender_fig = go.Figure()
    gender_fig.add_trace(go.Pie(labels=list(gender_counts.keys()),
                                values=list(gender_counts.values())))
    gender_description = describe_cluster(gender_counts, "gender")

    ethnicity_counts = dict(cl_dct["labels_ethnicity"])
    ethnicity_fig = go.Figure()
    ethnicity_fig.add_trace(go.Bar(x=list(ethnicity_counts.keys()),
                                   y=list(ethnicity_counts.values()),
                                   marker_color=px.colors.qualitative.G10))
    return (len(cl_dct['img_path_list']),
            gender_fig, gender_description,
            model_fig, model_description,
            ethnicity_fig,
            images)

def switch_cluster_count(num_clusters):
    """Handle a change of the cluster-count radio button.

    The slider range is rescaled to the newly selected clustering and reset
    to cluster 0, and that cluster is rendered. Without the rescale, the
    slider keeps the maximum it was built with and later clusters of the
    larger clusterings can never be selected.

    Returns:
        A slider update followed by the seven values from ``show_cluster``.
    """
    count = num_clusters or 12
    return (gr.update(maximum=count - 1, value=0), *show_cluster(0, count))


with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown("## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!")
    gr.Markdown("### This demo showcases patterns in the images generated from different prompts input to Stable Diffusion and Dalle-2 diffusion models.")
    gr.Markdown("### Below, see results on how the images from different prompts cluster together.")
    # font-size has to live inside the style attribute to take effect.
    gr.HTML("""<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image models and may depict offensive stereotypes or contain explicit content.</span>""")
    num_clusters = gr.Radio([12,24,48], value=12, label="How many clusters do you want to make from the data?")

    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(label="Most representative images in cluster").style(grid=(3,3))
        with gr.Column():
            # Initial range matches the default radio value; it is rescaled
            # by switch_cluster_count whenever the radio selection changes.
            cluster_id = gr.Slider(minimum=0, maximum=num_clusters.value-1, step=1, value=0, label="Click to move between clusters")
            a = gr.Text(label="Number of images")
    with gr.Row():
        with gr.Column(scale=1):
            c = gr.Plot(label="How many images from each model?")
            c_desc = gr.HTML(label="")
        with gr.Column(scale=1):
            b = gr.Plot(label="How many genders are represented?")
            b_desc = gr.HTML(label="")
        with gr.Column(scale=2):
            d = gr.Plot(label="Which ethnicities are present?")

    gr.Markdown("The 'Model makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2.")
    gr.Markdown('The Gender plot shows the number of images based on the input prompts that used the words male, female, non-binary, and unmarked, which we label "person".')
    gr.Markdown("The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity.")

    # Output order mirrors the tuple returned by show_cluster:
    # count, gender plot/description, model plot/description, ethnicity plot, images.
    cluster_outputs = [a, b, b_desc, c, c_desc, d, gallery]
    demo.load(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=cluster_outputs)
    num_clusters.change(fn=switch_cluster_count, inputs=[num_clusters], outputs=[cluster_id] + cluster_outputs)
    cluster_id.change(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=cluster_outputs)

if __name__ == "__main__":
    demo.queue().launch(debug=True)