import logging

import numpy as np
import torch

from .utils import _broadcast_tensor
def cal_medium(oss_steps_all):
    # For each step position k, return the median of that step index across
    # all searched schedules (average of the two middle values).
    ave_steps = []
    for k in range(len(oss_steps_all[0])):
        vals = [steps[k] for steps in oss_steps_all]
        vals.sort()
        ave_steps.append((vals[len(vals) // 2] + vals[len(vals) // 2 - 1]) // 2)
    return ave_steps
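
# Usage sketch (hypothetical values): given ascending schedules searched for
# two samples, cal_medium aggregates them into one shared schedule.
#   cal_medium([[40, 80, 120, 160, 200], [20, 60, 100, 180, 200]])
#   # -> [30, 70, 110, 170, 200]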
def infer_OSS(oss_steps, model, z, class_emb, device, renorm_flag=False, max_amp=None, min_amp=None, float32=True, model_kwargs=None):
    # z: [B, C, H, W]
    N = len(oss_steps)
    B = z.shape[0]
    oss_steps = [0] + oss_steps
    ori_z = z.clone()
    # First entry of timesteps is zero
    timesteps = model.fm_steps
    if model_kwargs is None:
        model_kwargs = {}
    for i in reversed(range(1, N + 1)):
        logging.info(f"steps {i}")
        t = torch.ones((B,), device=device, dtype=torch.long) * oss_steps[i]
        if renorm_flag:
            # Rescale z so its 5%-95% quantile range matches the recorded amplitude range.
            max_s = torch.quantile(z.reshape((z.shape[0], -1)), 0.95, dim=1)
            min_s = torch.quantile(z.reshape((z.shape[0], -1)), 0.05, dim=1)
            max_amp_tmp = max_amp[oss_steps[i] - 1]
            min_amp_tmp = min_amp[oss_steps[i] - 1]
            ak = (max_amp_tmp - min_amp_tmp) / (max_s - min_s)
            ab = min_amp_tmp - ak * min_s
            z = _broadcast_tensor(ak, z.shape) * z + _broadcast_tensor(ab, z.shape)
        vt = model(z, t, class_emb, model_kwargs)
        if float32:
            z = z.to(torch.float32)
        # Euler step of the flow ODE from timestep oss_steps[i] to oss_steps[i-1].
        z = z + (timesteps[oss_steps[i - 1]] - timesteps[oss_steps[i]]) * vt
        if float32:
            z = z.to(ori_z.dtype)
    return z
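
# Usage sketch (hypothetical names): sample with a pre-searched schedule.
# `model`, `z0` (initial noise), `emb`, and `device` are assumed to exist;
# `steps` is an ascending schedule such as one returned by search_OSS below.
#   steps = [40, 80, 120, 160, 200]
#   sample = infer_OSS(steps, model, z0, emb, device)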
def search_OSS_video(model, z, batch_size, class_emb, device, teacher_steps=200, student_steps=5, norm=2, model_kwargs=None, frame_type="6", channel_type="4", random_channel=False, float32=True):
    # z: [B, C, F, H, W]
    # model_kwargs does not contain the class embedding, which is a separate input;
    # the class embedding is shared by all search samples here.
    B = batch_size
    N = teacher_steps
    STEP = student_steps
    assert z.shape[0] == B
    channel_size = z.shape[1]
    frame_size = z.shape[2]
    # First entry of timesteps is zero
    timesteps = model.fm_steps
    if model_kwargs is None:
        model_kwargs = {}
    # Compute the teacher trajectory with N Euler steps.
    traj_tea = torch.stack([torch.ones_like(z)] * (N + 1), dim=0)
    traj_tea[N] = z
    z_tea = z.clone()
    for i in reversed(range(1, N + 1)):
        logging.info(f"teacher step {i}")
        t = torch.ones((B,), device=device, dtype=torch.long) * i
        vt = model(z_tea, t, class_emb, model_kwargs)
        if float32:
            z_tea = z_tea.to(torch.float32)
        z_tea = z_tea + vt * (timesteps[i - 1] - timesteps[i])
        if float32:
            z_tea = z_tea.to(z.dtype)
        traj_tea[i - 1] = z_tea.clone()
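
    # Dynamic programming over the teacher grid: dp[k][j] holds the best cost
    # of reaching teacher timestep j in k student steps, tracestep[k][j] the
    # source timestep of that best jump, and z_prev[i]/z_next[j] the matching
    # states. Each round relaxes dp[k][j] = min over i > j of cost(i -> j).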
    all_steps = []
    for i_batch in range(B):  # process each sample separately
        z_cur = z[i_batch].clone().unsqueeze(0)
        if class_emb is not None:
            class_emb_cur = class_emb[i_batch].clone().unsqueeze(0)
        else:
            class_emb_cur = None
        tracestep = [torch.ones(N + 1, device=device, dtype=torch.long) * N for _ in range(STEP + 1)]
        dp = torch.ones((STEP + 1, N + 1), device=device) * torch.inf
        z_prev = torch.cat([z_cur] * (N + 1), dim=0)
        z_next = z_prev.clone()
        for k in range(STEP):
            logging.info(f"Doing k step solving {k}")
            if random_channel and channel_type != "all":
                # Sample the measured channels once per DP round.
                channel_select = np.random.choice(range(channel_size), size=int(channel_type), replace=False)
            for i in reversed(range(1, N + 1)):
                z_i = z_prev[i].unsqueeze(0)
                t = torch.ones((1,), device=device, dtype=torch.long) * i
                vt = model(z_i, t, class_emb_cur, model_kwargs)
                for j in reversed(range(0, i)):
                    if float32:
                        z_i = z_i.to(torch.float32)
                    z_nxt = z_i + vt * (timesteps[j] - timesteps[i])
                    if float32:
                        z_nxt = z_nxt.to(z.dtype)
                    # Choose the channels/frames on which the cost is measured.
                    if random_channel:
                        pass  # channel_select was sampled above
                    elif channel_type == "all":
                        channel_select = list(range(channel_size))
                    else:
                        channel_select = list(range(int(channel_type)))
                    if frame_type == "all":
                        frame_select = list(range(frame_size))
                    else:
                        frame_select = torch.linspace(0, frame_size - 1, int(frame_type), dtype=torch.long).tolist()
                    channel_select_tensor = torch.tensor(channel_select, dtype=torch.long, device=device)
                    frame_select_tensor = torch.tensor(frame_select, dtype=torch.long, device=device)
                    z_nxt_select = z_nxt[0, channel_select_tensor, ...]
                    traj_tea_select = traj_tea[j, i_batch, channel_select_tensor, ...]
                    z_nxt_select = z_nxt_select[:, frame_select_tensor, ...]
                    traj_tea_select = traj_tea_select[:, frame_select_tensor, ...]
                    cost = torch.abs(z_nxt_select - traj_tea_select) ** norm
                    cost = cost.mean()
                    if cost < dp[k][j]:
                        dp[k][j] = cost
                        tracestep[k][j] = i
                        z_next[j] = z_nxt
            dp[k + 1] = dp[k].clone()
            tracestep[k + 1] = tracestep[k].clone()
            z_prev = z_next.clone()
| logging.info(f"finish {k} steps") | |
| cur_step = [0] | |
| for kk in reversed(range(k+1)): | |
| j = cur_step[-1] | |
| cur_step.append(int(tracestep[kk][j].item())) | |
| logging.info(cur_step) | |
| # trace back | |
| final_step = [0] | |
| for k in reversed(range(STEP)): | |
| j = final_step[-1] | |
| final_step.append(int(tracestep[k][j].item())) | |
| logging.info(final_step) | |
| all_steps.append(final_step[1:]) | |
| return all_steps[0] | |
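
# Usage sketch (hypothetical names): search a 5-step schedule for a video
# latent `z0` of shape [B, C, F, H, W], measuring the cost on 4 channels and
# 6 evenly spaced frames as the defaults above do.
#   steps = search_OSS_video(model, z0, z0.shape[0], emb, device,
#                            teacher_steps=200, student_steps=5)
#   video = infer_OSS(steps, model, z0, emb, device)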
def search_OSS(model, z, batch_size, class_emb, device, teacher_steps=200, student_steps=5, model_kwargs=None):
    # z: [B, C, H, W]
    # model_kwargs does not contain the class embedding, which is a separate input;
    # the class embedding is shared by all search samples here.
    B = batch_size
    N = teacher_steps
    STEP = student_steps
    assert z.shape[0] == B
    # First entry of timesteps is zero
    timesteps = model.fm_steps
    if model_kwargs is None:
        model_kwargs = {}
    # Compute the teacher trajectory with N Euler steps.
    traj_tea = torch.stack([torch.ones_like(z)] * (N + 1), dim=0)
    traj_tea[N] = z
    z_tea = z.clone()
    for i in reversed(range(1, N + 1)):
        t = torch.ones((B,), device=device, dtype=torch.long) * i
        vt = model(z_tea, t, class_emb, model_kwargs)
        z_tea = z_tea + vt * (timesteps[i - 1] - timesteps[i])
        traj_tea[i - 1] = z_tea.clone()
    # Solve for the step schedule by dynamic programming.
    all_steps = []
    for i_batch in range(B):  # process each sample separately
        z_cur = z[i_batch].clone().unsqueeze(0)
        if class_emb is not None:
            class_emb_cur = class_emb[i_batch].clone().unsqueeze(0)
        else:
            class_emb_cur = None
        tracestep = [torch.ones(N + 1, device=device, dtype=torch.long) * N for _ in range(STEP + 1)]
        dp = torch.ones((STEP + 1, N + 1), device=device) * torch.inf
        z_prev = torch.cat([z_cur] * (N + 1), dim=0)
        z_next = z_prev.clone()
        for k in range(STEP):
            logging.info(f"Doing k step solving {k}")
            for i in reversed(range(1, N + 1)):
                z_i = z_prev[i].unsqueeze(0)
                t = torch.ones((1,), device=device, dtype=torch.long) * i
                vt = model(z_i, t, class_emb_cur, model_kwargs)
                for j in reversed(range(0, i)):
                    # Candidate one-step jump from teacher timestep i to j.
                    z_j = z_i + vt * (timesteps[j] - timesteps[i])
                    cost = ((z_j - traj_tea[j, i_batch]) ** 2).mean()
                    if cost < dp[k][j]:
                        dp[k][j] = cost
                        tracestep[k][j] = i
                        z_next[j] = z_j
            dp[k + 1] = dp[k].clone()
            tracestep[k + 1] = tracestep[k].clone()
            z_prev = z_next.clone()
        # Trace back the optimal schedule from timestep 0.
        final_step = [0]
        for k in reversed(range(STEP)):
            j = final_step[-1]
            final_step.append(int(tracestep[k][j].item()))
        logging.info(final_step)
        all_steps.append(final_step[1:])
    return all_steps
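
# Usage sketch (hypothetical names): search per-sample schedules on a small
# calibration batch, then aggregate them into one shared schedule.
#   schedules = search_OSS(model, z0, z0.shape[0], emb, device)
#   shared = cal_medium(schedules)
#   sample = infer_OSS(shared, model, z0, emb, device)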
def search_OSS_batch(model, z, batch_size, class_emb, device, teacher_steps=200, student_steps=5, model_kwargs=None):
    # z: [B, C, H, W]
    # model_kwargs does not contain the class embedding, which is a separate input;
    # the class embedding is shared by all search samples here.
    B = batch_size
    N = teacher_steps
    STEP = student_steps
    # First entry of timesteps is zero
    timesteps = model.fm_steps
    if model_kwargs is None:
        model_kwargs = {}
    # Compute the teacher trajectory with N Euler steps.
    traj_tea = torch.stack([torch.ones_like(z)] * (N + 1), dim=0)
    traj_tea[N] = z
    z_tea = z.clone()
    for i in reversed(range(1, N + 1)):
        t = torch.ones((B,), device=device, dtype=torch.long) * i
        vt = model(z_tea, t, class_emb, model_kwargs)
        z_tea = z_tea + (timesteps[i - 1] - timesteps[i]) * vt
        traj_tea[i - 1] = z_tea.clone()
    # times[:i] are the candidate target timesteps; t_in feeds all N source
    # timesteps to the model in one batch.
    times = torch.linspace(0, N, N + 1, device=device).long()
    t_in = torch.linspace(1, N, N, device=device).long()
    # Solve for the step schedule by dynamic programming.
    all_steps = []
    for i_batch in range(B):  # process each sample separately
        z_cur = z[i_batch].clone().unsqueeze(0)
        if class_emb is not None:
            # Replicate the per-sample embedding for the N batched timesteps.
            class_emb_cur = class_emb[i_batch].clone().unsqueeze(0).expand(N, -1)
        else:
            class_emb_cur = None
        tracestep = [torch.ones(N + 1, device=device, dtype=torch.long) * N for _ in range(STEP + 1)]
        dp = torch.ones((STEP + 1, N + 1), device=device) * torch.inf
        z_prev = torch.cat([z_cur] * (N + 1), dim=0)
        z_next = z_prev.clone()
        for k in range(STEP):
            logging.info(f"Doing k step solving {k}")
            # One batched model call evaluates vt at every teacher timestep.
            vt = model(z_prev[1:], t_in, class_emb_cur, model_kwargs)
            for i in reversed(range(1, N + 1)):
                t = torch.ones((i,), device=device, dtype=torch.long) * i
                z_i_batch = torch.stack([z_prev[i]] * i, dim=0)
                # Jump from timestep i to every earlier timestep j < i at once.
                dt = timesteps[times[:i]] - timesteps[t]
                z_j_batch = z_i_batch + _broadcast_tensor(dt, z_i_batch.shape) * vt[i - 1].unsqueeze(0)
                cost = (z_j_batch - traj_tea[:i, i_batch, ...]) ** 2
                cost = cost.mean(dim=tuple(range(1, len(cost.shape))))
                mask = cost < dp[k, :i]
                dp[k, :i] = torch.where(mask, cost, dp[k, :i])
                tracestep[k][:i] = torch.where(mask, i, tracestep[k][:i])
                expanded_mask = _broadcast_tensor(mask, z_i_batch.shape)
                z_next[:i] = torch.where(expanded_mask, z_j_batch, z_next[:i])
            dp[k + 1] = dp[k].clone()
            tracestep[k + 1] = tracestep[k].clone()
            z_prev = z_next.clone()
        # Trace back the optimal schedule from timestep 0.
        final_step = [0]
        for k in reversed(range(STEP)):
            j = final_step[-1]
            final_step.append(int(tracestep[k][j].item()))
        logging.info(final_step)
        all_steps.append(final_step[1:])
    return all_steps
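
# Design note: search_OSS_batch computes the same schedules as search_OSS but
# batches the velocity evaluations, calling the model once per DP round on all
# N intermediate states instead of once per source timestep.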