Spaces:
Runtime error
Runtime error
Update
Browse files
app.py
CHANGED
@@ -53,10 +53,6 @@ def get_cluster_center_image_markdown(model_name: str) -> str:
|
|
53 |
return f''
|
54 |
|
55 |
|
56 |
-
def update_distance_type(multimodal_truncation: bool) -> dict:
|
57 |
-
return gr.Dropdown.update(visible=multimodal_truncation)
|
58 |
-
|
59 |
-
|
60 |
def main():
|
61 |
args = parse_args()
|
62 |
|
@@ -85,14 +81,14 @@ def main():
|
|
85 |
step=0.05,
|
86 |
value=0.7,
|
87 |
label='Truncation psi')
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
run_button = gr.Button('Run')
|
97 |
with gr.Column():
|
98 |
result = gr.Image(label='Result', elem_id='result')
|
@@ -116,16 +112,12 @@ def main():
|
|
116 |
gr.Markdown(FOOTER)
|
117 |
|
118 |
model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
|
119 |
-
multimodal_truncation.change(fn=update_distance_type,
|
120 |
-
inputs=multimodal_truncation,
|
121 |
-
outputs=distance_type)
|
122 |
run_button.click(fn=model.set_model_and_generate_image,
|
123 |
inputs=[
|
124 |
model_name,
|
125 |
seed,
|
126 |
psi,
|
127 |
-
|
128 |
-
distance_type,
|
129 |
],
|
130 |
outputs=result)
|
131 |
model_name2.change(fn=get_sample_image_markdown,
|
|
|
53 |
return f''
|
54 |
|
55 |
|
|
|
|
|
|
|
|
|
56 |
def main():
|
57 |
args = parse_args()
|
58 |
|
|
|
81 |
step=0.05,
|
82 |
value=0.7,
|
83 |
label='Truncation psi')
|
84 |
+
truncation_type = gr.Dropdown(
|
85 |
+
[
|
86 |
+
'Multimodal (LPIPS)',
|
87 |
+
'Multimodal (L2)',
|
88 |
+
'Global',
|
89 |
+
],
|
90 |
+
value='Multimodal (LPIPS)',
|
91 |
+
label='Truncation Type')
|
92 |
run_button = gr.Button('Run')
|
93 |
with gr.Column():
|
94 |
result = gr.Image(label='Result', elem_id='result')
|
|
|
112 |
gr.Markdown(FOOTER)
|
113 |
|
114 |
model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
|
|
|
|
|
|
|
115 |
run_button.click(fn=model.set_model_and_generate_image,
|
116 |
inputs=[
|
117 |
model_name,
|
118 |
seed,
|
119 |
psi,
|
120 |
+
truncation_type,
|
|
|
121 |
],
|
122 |
outputs=result)
|
123 |
model_name2.change(fn=get_sample_image_markdown,
|
model.py
CHANGED
@@ -190,15 +190,20 @@ class Model:
|
|
190 |
return int(np.argmin(distances))
|
191 |
|
192 |
def generate_image(self, seed: int, truncation_psi: float,
|
193 |
-
|
194 |
-
distance_type: str) -> np.ndarray:
|
195 |
z = self.generate_z(seed)
|
196 |
ws = self.compute_w(z)
|
197 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
cluster_index = self.find_nearest_cluster_center(ws, distance_type)
|
199 |
w0 = self.cluster_centers[cluster_index]
|
200 |
-
else:
|
201 |
-
w0 = self.model.mapping.w_avg
|
202 |
new_ws = self.truncate_w(w0, ws, truncation_psi)
|
203 |
out = self.synthesize(new_ws)
|
204 |
out = self.postprocess(out)
|
@@ -206,8 +211,6 @@ class Model:
|
|
206 |
|
207 |
def set_model_and_generate_image(self, model_name: str, seed: int,
|
208 |
truncation_psi: float,
|
209 |
-
|
210 |
-
distance_type: str) -> np.ndarray:
|
211 |
self.set_model(model_name)
|
212 |
-
return self.generate_image(seed, truncation_psi,
|
213 |
-
distance_type)
|
|
|
190 |
return int(np.argmin(distances))
|
191 |
|
192 |
def generate_image(self, seed: int, truncation_psi: float,
|
193 |
+
truncation_type: str) -> np.ndarray:
|
|
|
194 |
z = self.generate_z(seed)
|
195 |
ws = self.compute_w(z)
|
196 |
+
if truncation_type == 'Global':
|
197 |
+
w0 = self.model.mapping.w_avg
|
198 |
+
else:
|
199 |
+
if truncation_type == 'Multimodal (LPIPS)':
|
200 |
+
distance_type = 'lpips'
|
201 |
+
elif truncation_type == 'Multimodal (L2)':
|
202 |
+
distance_type = 'l2'
|
203 |
+
else:
|
204 |
+
raise ValueError
|
205 |
cluster_index = self.find_nearest_cluster_center(ws, distance_type)
|
206 |
w0 = self.cluster_centers[cluster_index]
|
|
|
|
|
207 |
new_ws = self.truncate_w(w0, ws, truncation_psi)
|
208 |
out = self.synthesize(new_ws)
|
209 |
out = self.postprocess(out)
|
|
|
211 |
|
212 |
def set_model_and_generate_image(self, model_name: str, seed: int,
|
213 |
truncation_psi: float,
|
214 |
+
truncation_type: str) -> np.ndarray:
|
|
|
215 |
self.set_model(model_name)
|
216 |
+
return self.generate_image(seed, truncation_psi, truncation_type)
|
|