yjernite committed
Commit 0d44baa · 1 parent: 3911108

ethnicity description and dropdown selector

Files changed (1):
  1. app.py +114 -55
app.py CHANGED
@@ -29,23 +29,32 @@ def to_string(label):
         label = "non-binary person"
     elif label == "person":
         label = "<i>unmarked</i> (person)"
+    elif label == "":
+        label = "<i>unmarked</i> ()"
     elif label == "gender":
         label = "gender term"
     return label
 
 
-def describe_cluster(cl_dict, block="label"):
+def describe_cluster(cl_dict, block="label", max_items=4):
     labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
     labels_values.reverse()
     total = float(sum(cl_dict.values()))
     lv_prcnt = list(
-        (item[0], round(item[1] * 100 / total, 0)) for item in labels_values)
+        (item[0], round(item[1] * 100 / total, 0)) for item in labels_values
+    )
     top_label = lv_prcnt[0][0]
-    description_string = "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>" % (
-        to_string(block), to_string(top_label), lv_prcnt[0][1])
+    description_string = (
+        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
+        % (to_string(block), to_string(top_label), lv_prcnt[0][1])
+    )
     description_string += "<p>This is followed by: "
-    for lv in lv_prcnt[1:]:
+    for lv in lv_prcnt[1 : min(len(lv_prcnt), 1 + max_items)]:
         description_string += "<BR/><b>%s:</b> %d%%" % (to_string(lv[0]), lv[1])
+    if len(lv_prcnt) > max_items + 1:
+        description_string += "<BR/><b> - Other terms:</b> %d%%" % (
+            sum(lv[1] for lv in lv_prcnt[max_items + 1 :]),
+        )
     description_string += "</p>"
     return description_string
 
@@ -58,65 +67,94 @@ def show_cluster(cl_id, num_clusters):
     cl_dct = clusters_by_size[num_clusters][cl_id]
     images = []
     for i in range(6):
-        img_path = "/".join([st.replace("/", "") for st in
-                             cl_dct['img_path_list'][i].split("//")][3:])
-        images.append((Image.open(os.path.join("identities-images", img_path)),
-                       "_".join([img_path.split("/")[0],
-                                 img_path.split("/")[-1]]).replace(
-                           'Photo_portrait_of_an_', '').replace(
-                           'Photo_portrait_of_a_', '').replace(
-                           'SD_v2_random_seeds_identity_', '(SD v.2) ').replace(
-                           'dataset-identities-dalle2_', '(Dall-E 2) ').replace(
-                           'SD_v1.4_random_seeds_identity_',
-                           '(SD v.1.4) ').replace('_', ' ')))
+        img_path = "/".join(
+            [st.replace("/", "") for st in cl_dct["img_path_list"][i].split("//")][3:]
+        )
+        images.append(
+            (
+                Image.open(os.path.join("identities-images", img_path)),
+                "_".join([img_path.split("/")[0], img_path.split("/")[-1]])
+                .replace("Photo_portrait_of_an_", "")
+                .replace("Photo_portrait_of_a_", "")
+                .replace("SD_v2_random_seeds_identity_", "(SD v.2) ")
+                .replace("dataset-identities-dalle2_", "(Dall-E 2) ")
+                .replace("SD_v1.4_random_seeds_identity_", "(SD v.1.4) ")
+                .replace("_", " "),
+            )
+        )
     model_fig = go.Figure()
-    model_fig.add_trace(go.Pie(labels=list(dict(cl_dct["labels_model"]).keys()),
-                               values=list(
-                                   dict(cl_dct["labels_model"]).values())))
+    model_fig.add_trace(
+        go.Pie(
+            labels=list(dict(cl_dct["labels_model"]).keys()),
+            values=list(dict(cl_dct["labels_model"]).values()),
+        )
+    )
     model_description = describe_cluster(dict(cl_dct["labels_model"]), "system")
 
     gender_fig = go.Figure()
     gender_fig.add_trace(
-        go.Pie(labels=list(dict(cl_dct["labels_gender"]).keys()),
-               values=list(dict(cl_dct["labels_gender"]).values())))
-    gender_description = describe_cluster(dict(cl_dct["labels_gender"]),
-                                          "gender")
+        go.Pie(
+            labels=list(dict(cl_dct["labels_gender"]).keys()),
+            values=list(dict(cl_dct["labels_gender"]).values()),
+        )
+    )
+    gender_description = describe_cluster(dict(cl_dct["labels_gender"]), "gender")
 
     ethnicity_fig = go.Figure()
     ethnicity_fig.add_trace(
-        go.Bar(x=list(dict(cl_dct["labels_ethnicity"]).keys()),
-               y=list(dict(cl_dct["labels_ethnicity"]).values()),
-               marker_color=px.colors.qualitative.G10))
-    return (len(cl_dct['img_path_list']),
-            gender_fig, gender_description,
-            model_fig, model_description,
-            ethnicity_fig,
-            images,
-            gr.update(maximum=num_clusters - 1))
+        go.Bar(
+            x=list(dict(cl_dct["labels_ethnicity"]).keys()),
+            y=list(dict(cl_dct["labels_ethnicity"]).values()),
+            marker_color=px.colors.qualitative.G10,
+        )
+    )
+    ethnicity_description = describe_cluster(
+        dict(cl_dct["labels_ethnicity"]), "ethnicity"
+    )
+
+    return (
+        len(cl_dct["img_path_list"]),
+        gender_fig,
+        gender_description,
+        model_fig,
+        model_description,
+        ethnicity_fig,
+        ethnicity_description,
+        images,
+    )
 
 
 with gr.Blocks(title=TITLE) as demo:
     gr.Markdown(f"# {TITLE}")
     gr.Markdown(
-        "## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!")
+        "## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!"
+    )
     gr.Markdown(
-        "### This demo showcases patterns in the images generated from different prompts input to Stable Diffusion and Dalle-2 systems.")
+        "### This demo showcases patterns in the images generated from different prompts input to Stable Diffusion and Dalle-2 systems."
+    )
     gr.Markdown(
-        "### Below, see results on how the images from different prompts cluster together.")
+        "### Below, see results on how the images from different prompts cluster together."
+    )
     gr.HTML(
-        """<span style="color:red" font-size:smaller>⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>""")
-    num_clusters = gr.Radio([12, 24, 48], value=12,
-                            label="How many clusters do you want to make from the data?")
+        """<span style="color:red" font-size:smaller>⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>"""
+    )
+    num_clusters = gr.Radio(
+        [12, 24, 48],
+        value=12,
+        label="How many clusters do you want to make from the data?",
+    )
 
     with gr.Row():
         with gr.Column(scale=4):
-            gallery = gr.Gallery(
-                label="Most representative images in cluster").style(
-                grid=(3, 3))
+            gallery = gr.Gallery(label="Most representative images in cluster").style(
+                grid=(3, 3)
+            )
         with gr.Column():
-            cluster_id = gr.Slider(minimum=0, maximum=num_clusters.value - 1,
-                                   step=1, value=0,
-                                   label="Click to move between clusters")
+            cluster_id = gr.Dropdown(
+                choices=[i for i in range(num_clusters.value)],
+                value=0,
+                label="Select cluster to visualize:",
+            )
             a = gr.Text(label="Number of images")
     with gr.Row():
         with gr.Column(scale=1):
@@ -127,20 +165,41 @@ with gr.Blocks(title=TITLE) as demo:
             b_desc = gr.HTML(label="")
         with gr.Column(scale=2):
             d = gr.Plot(label="Which ethnicity terms are present?")
+            d_desc = gr.HTML(label="")
 
     gr.Markdown(
-        f"The 'System makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2.")
+        f"The 'System makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2."
+    )
     gr.Markdown(
-        'The Gender plot shows the number of images based on the input prompts that used the words man, woman, non-binary person, and unmarked, which we label "person".')
+        'The Gender plot shows the number of images based on the input prompts that used the words man, woman, non-binary person, and unmarked, which we label "person".'
+    )
     gr.Markdown(
-        f"The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity.")
-    demo.load(fn=show_cluster, inputs=[cluster_id, num_clusters],
-              outputs=[a, b, b_desc, c, c_desc, d, gallery, cluster_id])
-    num_clusters.change(fn=show_cluster, inputs=[cluster_id, num_clusters],
-                        outputs=[a, b, b_desc, c, c_desc, d, gallery,
-                                 cluster_id])
-    cluster_id.change(fn=show_cluster, inputs=[cluster_id, num_clusters],
-                      outputs=[a, b, b_desc, c, c_desc, d, gallery, cluster_id])
+        f"The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity."
+    )
+    demo.load(
+        fn=show_cluster,
+        inputs=[cluster_id, num_clusters],
+        outputs=[a, b, b_desc, c, c_desc, d, d_desc, gallery],
+    )
+    num_clusters.change(
+        fn=show_cluster,
+        inputs=[cluster_id, num_clusters],
+        outputs=[
+            a,
+            b,
+            b_desc,
+            c,
+            c_desc,
+            d,
+            d_desc,
+            gallery,
+        ],
+    )
+    cluster_id.change(
+        fn=show_cluster,
+        inputs=[cluster_id, num_clusters],
+        outputs=[a, b, b_desc, c, c_desc, d, d_desc, gallery],
+    )
 
 if __name__ == "__main__":
-    demo.queue().launch(debug=True)
+    demo.queue().launch(debug=True)
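
For readers following the diff, a minimal standalone sketch of the summarization rule that the new max_items argument to describe_cluster introduces: the top label is reported on its own, the next max_items labels are listed individually, and the remainder is folded into an "Other terms" bucket. The label counts below are hypothetical, and the output is printed as plain text rather than the HTML string the app builds.

# Illustrative only; counts are made up and not taken from the dataset.
counts = {
    "White": 40, "Black": 25, "Latinx": 15, "East Asian": 10,
    "South Asian": 5, "Indigenous American": 3, "Pacific Islander": 2,
}
max_items = 4

# Rank labels by count and convert to rounded percentages of the total.
ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
total = float(sum(counts.values()))
percents = [(label, round(n * 100 / total)) for label, n in ranked]

top_label, top_pct = percents[0]
runners_up = percents[1 : 1 + max_items]                      # listed one by one
other_pct = sum(pct for _, pct in percents[1 + max_items :])  # aggregated remainder

print("Top %s: %d%%" % (top_label, top_pct))
for label, pct in runners_up:
    print("  %s: %d%%" % (label, pct))
if other_pct:
    print("  Other terms: %d%%" % other_pct)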