Spaces:
Runtime error
Runtime error
| import json | |
| import gradio as gr | |
| import os | |
| from PIL import Image | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import operator | |
TITLE = "Diffusion Faces Cluster Explorer"


def _load_clusters(num_clusters):
    """Load the precomputed clustering stored for *num_clusters* clusters.

    Args:
        num_clusters: Cluster count that appears in the JSON file name.

    Returns:
        The decoded JSON content of
        ``clusters/id_all_blip_clusters_<num_clusters>.json``.
    """
    path = f"clusters/id_all_blip_clusters_{num_clusters}.json"
    # The context manager closes the handle as soon as parsing is done
    # instead of leaving it open until garbage collection; JSON is UTF-8 by
    # specification, so do not depend on the platform's default encoding.
    with open(path, encoding="utf-8") as cluster_file:
        return json.load(cluster_file)


clusters_12 = _load_clusters(12)
clusters_24 = _load_clusters(24)
clusters_48 = _load_clusters(48)

# Maps the cluster count offered in the UI to the matching clustering.
clusters_by_size = {
    12: clusters_12,
    24: clusters_24,
    48: clusters_48,
}
def to_string(label):
    """Return the human-readable display text for a raw label.

    Known model, gender and block identifiers are translated to the wording
    shown in the UI; any other label is returned unchanged.
    """
    display_names = {
        "SD_2": "Stable Diffusion 2.0",
        "SD_14": "Stable Diffusion 1.4",
        "DallE": "Dall-E 2",
        "non-binary": "non-binary person",
        "person": "<i>unmarked</i> (person)",
        "gender": "gender term",
    }
    return display_names.get(label, label)
def describe_cluster(cl_dict, block="label"):
    """Build an HTML summary of how labels are distributed in a cluster.

    Args:
        cl_dict: Mapping from a label to the number of images in the cluster
            that carry it.
        block: Name of the label family being described (for example
            ``"model"`` or ``"gender"``); it is run through ``to_string``
            before being shown.

    Returns:
        An HTML string naming the most represented label with its rounded
        percentage, followed by a paragraph listing the remaining labels in
        descending order. The paragraph is omitted when there is only one
        label. If the mapping is empty or all counts are zero, a short
        placeholder message is returned instead.
    """
    total = float(sum(cl_dict.values()))
    if not total:
        # Nothing to rank, and the percentages below would divide by zero.
        return "<span>No labels are available for this cluster.</span>"

    # Ascending stable sort followed by a reversal: labels with equal counts
    # keep the same relative order this function has always produced.
    ranked = sorted(cl_dict.items(), key=operator.itemgetter(1))
    ranked.reverse()
    percentages = [
        (label, round(count * 100 / total)) for label, count in ranked
    ]
    (top_label, top_percent), runners_up = percentages[0], percentages[1:]

    description = (
        f"<span>The most represented {to_string(block)} is "
        f"<b>{to_string(top_label)}</b>, making up about "
        f"<b>{top_percent}%</b> of the cluster.</span>"
    )
    if runners_up:
        # Only announce followers when there are some; otherwise the sentence
        # would be left dangling in the rendered page.
        rows = "".join(
            f"<BR/><b>{to_string(label)}:</b> {percent}%"
            for label, percent in runners_up
        )
        description += f"<p>This is followed by: {rows}</p>"
    return description
def show_cluster(cl_id, num_clusters):
    """Collect everything the UI displays for one cluster.

    Args:
        cl_id: Index of the cluster to show; a falsy value selects cluster 0.
            An index outside the chosen clustering is clamped into range,
            which happens when the slider still holds a value from a larger
            clustering.
        num_clusters: Key into ``clusters_by_size``; a falsy value selects 12.

    Returns:
        A tuple of: the number of image paths in the cluster, the gender pie
        chart, the gender description, the model pie chart, the model
        description, the ethnicity bar chart, a list of up to six
        ``(image, caption)`` pairs, and a slider update carrying the new
        maximum (plus the corrected value when ``cl_id`` had to be clamped).
    """
    # UI components may hand numbers over as floats or strings; both the
    # dictionary key and the list index must be integers.
    num_clusters = int(num_clusters) if num_clusters else 12
    requested_id = int(cl_id) if cl_id else 0

    # assumes each clustering is a list of per-cluster dicts, as implied by
    # the integer indexing — TODO confirm against the JSON files
    clusters = clusters_by_size[num_clusters]
    cl_id = min(max(requested_id, 0), len(clusters) - 1)
    cl_dct = clusters[cl_id]

    # Applied in order: the longer "an_" prefix must be removed before "a_",
    # and underscores become spaces only after the prefixes were matched.
    caption_replacements = (
        ('Photo_portrait_of_an_', ''),
        ('Photo_portrait_of_a_', ''),
        ('SD_v2_random_seeds_identity_', '(SD v.2) '),
        ('dataset-identities-dalle2_', '(Dall-E 2) '),
        ('SD_v1.4_random_seeds_identity_', '(SD v.1.4) '),
        ('_', ' '),
    )

    images = []
    # Slicing copes with clusters that hold fewer than six image paths.
    for raw_path in cl_dct['img_path_list'][:6]:
        # Split on double slashes, strip stray single slashes from each
        # piece, drop the first three pieces and join the rest with "/".
        pieces = [st.replace("/", "") for st in raw_path.split("//")]
        img_path = "/".join(pieces[3:])

        # Caption = first and last path component, with known prefixes
        # shortened into readable model tags.
        components = img_path.split("/")
        caption = "_".join([components[0], components[-1]])
        for old, new in caption_replacements:
            caption = caption.replace(old, new)

        images.append(
            (Image.open(os.path.join("identities-images", img_path)), caption))

    model_counts = dict(cl_dct["labels_model"])
    gender_counts = dict(cl_dct["labels_gender"])
    ethnicity_counts = dict(cl_dct["labels_ethnicity"])

    model_fig = go.Figure()
    model_fig.add_trace(go.Pie(labels=list(model_counts.keys()),
                               values=list(model_counts.values())))
    model_description = describe_cluster(model_counts, "model")

    gender_fig = go.Figure()
    gender_fig.add_trace(go.Pie(labels=list(gender_counts.keys()),
                                values=list(gender_counts.values())))
    gender_description = describe_cluster(gender_counts, "gender")

    ethnicity_fig = go.Figure()
    ethnicity_fig.add_trace(
        go.Bar(x=list(ethnicity_counts.keys()),
               y=list(ethnicity_counts.values()),
               marker_color=px.colors.qualitative.G10))

    slider_update = {"maximum": num_clusters - 1}
    if cl_id != requested_id:
        # Move the slider onto the cluster that is actually being displayed.
        slider_update["value"] = cl_id

    return (len(cl_dct['img_path_list']),
            gender_fig, gender_description,
            model_fig, model_description,
            ethnicity_fig,
            images,
            gr.update(**slider_update))
# Page layout and event wiring for the cluster explorer.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(
        "## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!")
    gr.Markdown(
        "### This demo showcases patterns in the images generated from different prompts input to Stable Diffusion and Dalle-2 diffusion models.")
    gr.Markdown(
        "### Below, see results on how the images from different prompts cluster together.")
    # Both declarations belong inside the style attribute; previously
    # font-size sat outside the quotes and was ignored by the browser.
    gr.HTML(
        """<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image models and may depict offensive stereotypes or contain explicit content.</span>""")
    num_clusters = gr.Radio([12, 24, 48], value=12,
                            label="How many clusters do you want to make from the data?")

    # Top row: representative images next to the cluster selector.
    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(
                label="Most representative images in cluster").style(
                grid=(3, 3))
        with gr.Column():
            cluster_id = gr.Slider(minimum=0, maximum=num_clusters.value - 1,
                                   step=1, value=0,
                                   label="Click to move between clusters")
            # a: number of images in the selected cluster.
            a = gr.Text(label="Number of images")

    # Bottom row: c = model plot, b = gender plot, d = ethnicity plot.
    with gr.Row():
        with gr.Column(scale=1):
            c = gr.Plot(label="How many images from each model?")
            c_desc = gr.HTML(label="")
        with gr.Column(scale=1):
            b = gr.Plot(label="How many gender terms are represented?")
            b_desc = gr.HTML(label="")
        with gr.Column(scale=2):
            d = gr.Plot(label="Which ethnicity terms are present?")

    gr.Markdown(
        "The 'Model makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2.")
    gr.Markdown(
        'The Gender plot shows the number of images based on the input prompts that used the words man, woman, non-binary person, and unmarked, which we label "person".')
    gr.Markdown(
        "The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity.")

    # Every trigger refreshes the same components; the order of the outputs
    # must match the tuple returned by show_cluster.
    cluster_inputs = [cluster_id, num_clusters]
    cluster_outputs = [a, b, b_desc, c, c_desc, d, gallery, cluster_id]
    demo.load(fn=show_cluster, inputs=cluster_inputs,
              outputs=cluster_outputs)
    num_clusters.change(fn=show_cluster, inputs=cluster_inputs,
                        outputs=cluster_outputs)
    cluster_id.change(fn=show_cluster, inputs=cluster_inputs,
                      outputs=cluster_outputs)

if __name__ == "__main__":
    demo.queue().launch(debug=True)