Commit 007806e · update space
Parent(s): fd16ff8
app.py CHANGED
@@ -153,9 +153,14 @@ def main():
     editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']
     # prepare models
     for editing_type in editing_types:
-        tmp_model =
+        tmp_model = model_from_config(load_config('text300M'), device=device)
+        # print(model_name, kwargs)
+        # print(model)
+
+        # xm = load_model('transmitter', de
+        tmp_model = load_model('text300M', device=device)
         with torch.no_grad():
-            new_proj = nn.Linear(1024 * 2, 1024, device=
+            new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=tmp_model.wrapped.input_proj.weight.dtype)
             new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
             new_proj.weight[:, :1024].copy_(tmp_model.wrapped.input_proj.weight) #
             new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
@@ -164,10 +169,13 @@ def main():
 
         ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(editing_type)), map_location='cpu')
         tmp_model.load_state_dict(ckp['model'])
-
+        tmp_model.eval()
+        # print("loaded latent model")
+        tmp_model.to(device)
+        noise_initial = ckp['initial_noise']['noise'].to(device)
         initial_noise[editing_type] = noise_initial
         noise_start_t[editing_type] = ckp['t_start']
-        models[editing_type] = tmp_model
+        models[editing_type] = tmp_model.to(device)
     @torch.no_grad()
     def optimize_all(prompt, instruction,
                      rand_seed):
@@ -279,12 +287,14 @@ def main():
         os.makedirs(general_save_path, exist_ok=True)
         for i, latent in enumerate(state['latent']):
             latent = latent.to(device)
-            text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
+            text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction])).to(device)
             print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
             ref_latent = latent.clone().unsqueeze(0).to(device)
             t_1 = torch.randint(noise_start_t_e_type, noise_start_t_e_type + 1, (1,), device=device).long()
 
             noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
+            print("noise_input:", noise_input.device)
+
             out_1 = diffusion.p_mean_variance(model, noise_input, t_1, clip_denoised=True,
                                               model_kwargs=text_embeddings_clip,
                                               condition_latents=ref_latent)
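
The key change in the first hunk is widening the text-conditioned model's input projection from 1024 to 2048 input features while keeping the pretrained behaviour as the starting point. A minimal standalone sketch of that step, assuming only that input_proj is a pretrained nn.Linear(1024, 1024) as the diff implies (the stand-in name input_proj replaces tmp_model.wrapped.input_proj):

import torch
import torch.nn as nn

# Stand-in for the pretrained tmp_model.wrapped.input_proj (1024 -> 1024).
input_proj = nn.Linear(1024, 1024)

with torch.no_grad():
    # Widened projection over the concatenation [latent, condition_latent] (2048 -> 1024).
    new_proj = nn.Linear(1024 * 2, 1024,
                         device=input_proj.weight.device,
                         dtype=input_proj.weight.dtype)
    new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
    new_proj.weight[:, :1024].copy_(input_proj.weight)   # reuse the pretrained weights
    new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))

# With the condition half zeroed, the widened layer reproduces the pretrained
# projection (minus the dropped bias), so the edited model starts from known behaviour.
x = torch.randn(2, 1024)
widened = new_proj(torch.cat([x, torch.zeros_like(x)], dim=-1))
assert torch.allclose(widened, x @ input_proj.weight.T, atol=1e-5)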
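For context on the third hunk: diffusion.q_sample(ref_latent, t_1, noise=noise_initial) forward-noises the reference latent to the fixed editing timestep t_1 (note that torch.randint(t, t + 1, (1,)) always returns t). Below is a hedged sketch of the standard DDPM forward step that a GaussianDiffusion-style q_sample computes; the schedule, the timestep value 600 (standing in for ckp['t_start']), and the function name are illustrative assumptions, not the Shap-E API:

import torch

def q_sample_sketch(x_start, t, noise, alphas_cumprod):
    # Standard DDPM forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    shape = (-1,) + (1,) * (x_start.dim() - 1)
    a_bar = alphas_cumprod[t].view(shape)
    return a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise

# Toy usage with a linear beta schedule.
betas = torch.linspace(1e-4, 0.02, 1024)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
ref_latent = torch.randn(1, 1024)
noise_initial = torch.randn_like(ref_latent)
t_1 = torch.randint(600, 600 + 1, (1,)).long()   # always 600: a fixed editing timestep
noise_input = q_sample_sketch(ref_latent, t_1, noise_initial, alphas_cumprod)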