# Infinity/models/bitwise_self_correction.py
import os
import os.path as osp

import cv2
import numpy as np
import torch
import torch.nn.functional as F


def labels2image(vae, all_indices, label_type='int_label', scale_schedule=None):
    summed_codes, recons_imgs = vae.decode_from_indices(all_indices, scale_schedule, label_type)
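    # Take the first reconstruction, map it from [-1, 1] to uint8, and swap RGB -> BGR for cv2.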
recons_img = recons_imgs[0]
recons_img = (recons_img + 1) / 2
recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
return recons_img


def features2image(vae, raw_features):
    recons_imgs = vae.decode(raw_features.squeeze(-3))
recons_img = recons_imgs[0]
recons_img = (recons_img + 1) / 2
recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
return recons_img


class BitwiseSelfCorrection(object):
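    # Simulates bit-level errors during training (see flip_requant) so the model can
    # learn to correct mistakes made at earlier scales.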
def __init__(self, vae, args):
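        # Configuration knobs (semantics inferred from flip_requant below):
        #   noise_apply_layers     - number of early scales that receive bit-flip noise
        #   noise_apply_strength   - upper bound on the per-bit flip probability
        #   noise_apply_requant    - if True, re-quantize codes from the flipped bits
        #   apply_spatial_patchify - pack 2x2 spatial patches into the channel dimension
        #   debug_bsc              - if True, dump a debug visualization on each call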
self.noise_apply_layers = args.noise_apply_layers
self.noise_apply_requant = args.noise_apply_requant
self.noise_apply_strength = args.noise_apply_strength
self.apply_spatial_patchify = args.apply_spatial_patchify
self.vae = vae
self.debug_bsc = args.debug_bsc

    def flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
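        # Quantize the VAE features scale by scale, randomly flipping some ground-truth
        # bits on the first `noise_apply_layers` scales, and return the (noised)
        # teacher-forcing input x_BLC_wo_prefix plus the clean per-scale bit labels.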
with torch.amp.autocast('cuda', enabled = False):
B = raw_features.shape[0]
if raw_features.dim() == 4:
codes_out = raw_features.unsqueeze(2)
else:
codes_out = raw_features
cum_var_input = 0
gt_all_bit_indices = []
pred_all_bit_indices = []
x_BLC_wo_prefix = []
for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
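            # Residual quantization: subtract what earlier scales already explain,
            # resize the residual to this scale, and quantize it to bit labels.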
residual = codes_out - cum_var_input
if si != len(vae_scale_schedule)-1:
residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual) # quantized shape: [B, d_vae, 1, h, w], bit_indices shape: [B,1,h,w,d_vae]
gt_all_bit_indices.append(bit_indices)
if si < self.noise_apply_layers:
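                # Randomly flip a fraction (at most noise_apply_strength) of the bits
                # to mimic prediction errors on the early scales.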
noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
pred_bit_indices = bit_indices.clone()
pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
pred_all_bit_indices.append(pred_bit_indices)
if self.noise_apply_requant:
quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label')
else:
pred_all_bit_indices.append(bit_indices)
cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()
if si < len(vae_scale_schedule)-1:
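                # The noised cumulative reconstruction, upsampled to the next scale,
                # becomes the teacher-forcing input for that scale.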
this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=self.vae.quantizer.z_interplote_up).contiguous()
if self.apply_spatial_patchify:
# (B,d,1,H,W) -> (B,d,H,W) -> (B,4d,H/2,W/2)
this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1)) # (B,H/2*W/2,4C) or (B,H*W,C)
if self.apply_spatial_patchify:
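            # Rearrange the ground-truth bit labels into 2x2 spatial patches so they
            # line up with the patchified transformer input.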
gt_ms_idx_Bl = []
for item in gt_all_bit_indices:
# item shape: (B,1,H,W,d)
item = item.squeeze(1).permute(0,3,1,2) # (B,d,H,W)
# (B,d,H,W) -> (B,4d,H/2,W/2)
item = torch.nn.functional.pixel_unshuffle(item, 2)
                    # (B,4d,H/2,W/2) -> (B,H/2,W/2,4d) -> (B,H/2*W/2,4d)
item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim)
gt_ms_idx_Bl.append(item)
else:
gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]
x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1)
if self.debug_bsc:
self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices)
return x_BLC_wo_prefix, gt_ms_idx_Bl

    def visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices):
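        # Debug helper: decode both the clean and the bit-flipped indices back to
        # images and save them side by side with the input for comparison.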
gt_img = (inp_B3HW.squeeze(-3) + 1) / 2 * 255
gt_img = gt_img[0].permute(1,2,0).cpu().numpy().astype(np.uint8)[:,:,::-1]
        recons_img_2 = labels2image(self.vae, gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        recons_img_3 = labels2image(self.vae, pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
cat_image = np.concatenate([gt_img, recons_img_2, recons_img_3], axis=1)
save_path = osp.abspath('non_teacher_force.jpg')
cv2.imwrite(save_path, cat_image)
print(f'Save to {save_path}')
import pdb; pdb.set_trace()
print(cat_image.shape)