File size: 4,222 Bytes
5e01e13
ea06e5a
5e01e13
 
313342d
 
ea06e5a
602f686
5e01e13
 
 
ea06e5a
5e01e13
 
 
 
 
 
 
be4f1b2
 
 
 
5e01e13
 
 
 
313342d
 
c6ba6c5
 
313342d
c6ba6c5
 
313342d
 
 
 
5e01e13
313342d
 
 
5e01e13
 
602f686
 
939debd
602f686
 
 
be4f1b2
602f686
 
be4f1b2
602f686
 
313342d
c6ba6c5
 
 
 
 
 
939debd
 
 
602f686
 
 
 
 
c6ba6c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import json
import gradio as gr
import os
from PIL import Image
import plotly.graph_objects as go
import plotly.express as px

TITLE = "Diffusion Faces Cluster Explorer"

# Path template of the pre-computed cluster files, one per clustering granularity.
CLUSTER_FILE_TEMPLATE = "clusters/id_all_blip_clusters_{}.json"


def _load_clusters(num_clusters):
    """Return the parsed JSON cluster file for *num_clusters* clusters.

    The file is opened in a ``with`` block so the handle is closed right after
    parsing (the previous ``json.load(open(...))`` leaked it), and decoded as
    UTF-8 rather than with the platform's locale encoding.
    """
    with open(CLUSTER_FILE_TEMPLATE.format(num_clusters), encoding="utf-8") as fh:
        return json.load(fh)


clusters_12 = _load_clusters(12)
clusters_24 = _load_clusters(24)
clusters_48 = _load_clusters(48)

# Number of clusters -> parsed cluster data, indexed by integer cluster id
# in show_cluster.
clusters_by_size = {
    12: clusters_12,
    24: clusters_24,
    48: clusters_48,
}

# Maximum number of representative images shown in the gallery per cluster.
_GALLERY_SIZE = 6

# Ordered (old, new) substitutions turning an image path into a short caption.
# The final "_" -> " " substitution must stay last: the earlier patterns
# contain underscores.
_CAPTION_REPLACEMENTS = (
    ('Photo_portrait_of_an_', ''),
    ('Photo_portrait_of_a_', ''),
    ('SD_v2_random_seeds_identity_', '(SD v.2) '),
    ('dataset-identities-dalle2_', '(Dall-E 2) '),
    ('SD_v1.4_random_seeds_identity_', '(SD v.1.4) '),
    ('_', ' '),
)


def _relative_image_path(raw_path):
    """Map an ``img_path_list`` entry to a path relative to ``identities-images``.

    The raw path is split on ``//``, every remaining ``/`` is removed from each
    piece, the first three pieces are dropped and the rest are joined with ``/``.
    """
    pieces = [piece.replace("/", "") for piece in raw_path.split("//")]
    return "/".join(pieces[3:])


def _image_caption(rel_path):
    """Build a gallery caption from the first and last components of *rel_path*."""
    parts = rel_path.split("/")
    caption = "_".join([parts[0], parts[-1]])
    for old, new in _CAPTION_REPLACEMENTS:
        caption = caption.replace(old, new)
    return caption


def _label_counts(pairs):
    """Split label/count data into parallel ``(labels, values)`` lists.

    *pairs* is anything ``dict()`` accepts (a mapping or an iterable of
    key/value pairs); insertion order is preserved.
    """
    counts = dict(pairs)
    return list(counts.keys()), list(counts.values())


def show_cluster(cl_id, num_clusters):
    """Build the UI outputs for one cluster.

    Args:
        cl_id: Index of the cluster to display. Falsy values mean cluster 0.
            The value is converted to ``int`` (slider values may be floats)
            and clamped into the valid range, so a stale id left over from a
            larger clustering does not raise IndexError.
        num_clusters: Clustering granularity, used as a key of
            ``clusters_by_size``. Falsy values mean 12.

    Returns:
        A tuple ``(num_images, gender_fig, model_fig, ethnicity_fig, images)``.
        ``images`` holds ``(PIL image, caption)`` pairs for at most the first
        ``_GALLERY_SIZE`` images of the cluster.
    """
    if not num_clusters:
        num_clusters = 12
    clusters = clusters_by_size[int(num_clusters)]

    if not cl_id:
        cl_id = 0
    cl_id = min(max(int(cl_id), 0), len(clusters) - 1)
    cl_dct = clusters[cl_id]

    images = []
    # Slice instead of range(6): small clusters may have fewer images.
    for raw_path in cl_dct['img_path_list'][:_GALLERY_SIZE]:
        rel_path = _relative_image_path(raw_path)
        images.append((Image.open(os.path.join("identities-images", rel_path)),
                       _image_caption(rel_path)))

    model_labels, model_values = _label_counts(cl_dct["labels_model"])
    model_fig = go.Figure()
    model_fig.add_trace(go.Pie(labels=model_labels, values=model_values))

    gender_labels, gender_values = _label_counts(cl_dct["labels_gender"])
    gender_fig = go.Figure()
    gender_fig.add_trace(go.Pie(labels=gender_labels, values=gender_values))

    ethnicity_labels, ethnicity_values = _label_counts(cl_dct["labels_ethnicity"])
    ethnicity_fig = go.Figure()
    ethnicity_fig.add_trace(go.Bar(x=ethnicity_labels,
                                   y=ethnicity_values,
                                   marker_color=px.colors.qualitative.G10))

    return (len(cl_dct['img_path_list']),
            gender_fig,
            model_fig,
            ethnicity_fig,
            images)

with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown("## This Space lets you explore the clusters based on the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer).")
    # Both CSS declarations belong inside the style attribute; previously
    # font-size:smaller sat outside it as a bogus attribute and was ignored.
    gr.HTML("""<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image models and may depict offensive stereotypes or contain explicit content.</span>""")
    num_clusters = gr.Radio([12,24,48], value=12, label="How many clusters do you want to make from the data?")

    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(label="Most representative images in cluster").style(grid=(3,3))
        with gr.Column():
            cluster_id = gr.Slider(minimum=0, maximum=num_clusters.value-1, step=1, value=0, label="Click to move between clusters")
            a = gr.Text(label="Number of images")
    with gr.Row():
        with gr.Column(scale=1):
            c = gr.Plot(label="Model makeup of cluster")
        with gr.Column(scale=1):
            b = gr.Plot(label="Gender label makeup of cluster")
        with gr.Column(scale=2):
            d = gr.Plot(label="Ethnicity label makeup of cluster")
    gr.Markdown("The 'Model makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2.")
    gr.Markdown("The 'Gender label makeup' plot corresponds to the number of images from each of the genders used in the prompts: male, female, non-binary and unmarked ('person')")
    gr.Markdown("The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity")

    def _on_num_clusters_change(n_clusters):
        """Resize the cluster slider to the new cluster count and show cluster 0.

        Without this, the slider kept the maximum computed from the initial
        radio value (11), so clusters 12 and up were unreachable for the 24- and
        48-cluster views. Returning the slider update together with the
        outputs for cluster 0 keeps the slider and the displayed cluster in sync.
        """
        n_clusters = int(n_clusters) if n_clusters else 12
        return (gr.Slider.update(maximum=n_clusters - 1, value=0),
                *show_cluster(0, n_clusters))

    # Components filled by show_cluster, in the order of its return tuple.
    cluster_outputs = [a, b, c, d, gallery]
    demo.load(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=cluster_outputs)
    num_clusters.change(fn=_on_num_clusters_change, inputs=num_clusters, outputs=[cluster_id] + cluster_outputs)
    cluster_id.change(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=cluster_outputs)

if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the server with
    # debug=True passed to launch().
    app = demo.queue()
    app.launch(debug=True)