Upload diffusion.py
scripts/diffusion.py CHANGED (+8 -6)
@@ -110,6 +110,7 @@ class Diffusion(L.LightningModule):
 
     ############ FORWARD DIFFUSION #########
     def subs_parameterization(self, logits, noised_latents):
+        print(logits.size())  # [bsz x bsz x seq_len]
         logits = logits.float()
         logits[:, :, self.mask_index] += self.neg_infinity
 
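For context, `subs_parameterization` implements the SUBS ("zero masking probabilities") step used in masked diffusion language models: adding a large negative constant to the mask token's logit drives its probability to zero after the softmax. A minimal standalone sketch, assuming a hypothetical 5-token vocabulary with `mask_index = 4` and `neg_infinity = -1e6`:

import torch

# Sketch of SUBS zero-masking: the mask token's logit is pushed toward
# -infinity so it receives ~zero probability mass. Sizes are hypothetical.
neg_infinity = -1_000_000.0
mask_index = 4

logits = torch.randn(2, 3, 5)  # [bsz, seq_len, vocab_size]
logits = logits.float()
logits[:, :, mask_index] += neg_infinity

log_probs = logits.log_softmax(dim=-1)
print(log_probs[..., mask_index].exp())  # ~0 everywhere: no mass on the mask token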
@@ -147,7 +148,7 @@ class Diffusion(L.LightningModule):
         x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input.
         move_chance: float torch.Tensor with shape (batch_size, 1).
         """
-
+        latents = torch.mean(latents, dim=2)  # [bsz x seq_len x 1280] --> [bsz x seq_len] as per markdown
         move_indices = torch.rand(*latents.shape, device=latents.device) < move_chance
         noised_latents = torch.where(move_indices, self.mask_index, latents)
         return noised_latents
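`q_xt` is the absorbing-state forward process: each position is independently replaced by `mask_index` with probability `move_chance`. A minimal sketch of the same masking step on integer tokens (shapes and values hypothetical); note that after the newly added `torch.mean` line, `latents` is float, so the result of `torch.where` stays float until cast:

import torch

# Sketch of q(x_t | x_0) for an absorbing-state diffusion: every position is
# independently replaced by the mask index with probability move_chance.
mask_index = 4
latents = torch.randint(0, 4, (2, 8))  # [bsz, seq_len] integer tokens
move_chance = torch.full((2, 1), 0.3)  # per-example masking probability

move_indices = torch.rand(*latents.shape, device=latents.device) < move_chance
noised_latents = torch.where(move_indices, mask_index, latents)
print(noised_latents)  # roughly 30% of entries replaced by 4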
@@ -172,13 +173,14 @@ class Diffusion(L.LightningModule):
 
         xt = self.q_xt(x0, move_chance)
         model_output = self.forward(xt, unet_conditioning)
+        print(f'model out: {model_output}')
+        print(f'model out dim: {model_output.size()}')  # [bsz x bsz x seq_len]
 
         # SUBS parameterization, continuous time.
-        idx = x0.long()
-        print(f'idx: {idx}')
-        print(f'idx dim: {idx.size()}')
-
-        print(f'model out: {model_output.size()}')
+        idx = torch.mean(x0, dim=2).long()[:, :, None]
+        print(f'idx: {idx}')
+        print(f'idx dim: {idx.size()}')  # [bsz x seq_len x 1]
+
         log_p_theta = torch.gather(input=model_output, dim=-1, index=idx).squeeze(-1)
         scale = (dsigma / torch.expm1(sigma))[:, None]
         return -log_p_theta * scale
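The loss picks out log p_theta(x0) at each position with `torch.gather` and applies the continuous-time weight dsigma / (e^sigma - 1), which `torch.expm1` computes stably. A self-contained sketch of the same lookup and scaling, with hypothetical shapes and noise levels:

import torch

# Gather-based likelihood lookup: select log p_theta(x0) per position from
# [bsz, seq_len, vocab] log-probs, then apply the continuous-time weighting.
bsz, seq_len, vocab = 2, 4, 6
log_probs = torch.randn(bsz, seq_len, vocab).log_softmax(dim=-1)
x0 = torch.randint(0, vocab, (bsz, seq_len))

idx = x0.long()[:, :, None]  # [bsz, seq_len, 1]
log_p_theta = torch.gather(input=log_probs, dim=-1, index=idx).squeeze(-1)

sigma = torch.rand(bsz)   # hypothetical noise levels
dsigma = torch.rand(bsz)  # and their time derivatives
scale = (dsigma / torch.expm1(sigma))[:, None]  # [bsz, 1]
loss = -log_p_theta * scale
print(loss.shape)  # torch.Size([2, 4])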
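One caveat suggested by the debug prints: the model output is annotated as [bsz x bsz x seq_len], while `torch.gather(dim=-1)` expects [bsz, seq_len, vocab]-shaped log-probs against the [bsz x seq_len x 1] index. A hypothetical shape guard (not from the repo) that would fail fast before the gather call:

import torch

# torch.gather needs index and input to have equal rank, with
# index.size(d) <= input.size(d) on every dim except the gather dim.
def check_gather_shapes(model_output, idx, dim=-1):
    assert model_output.dim() == idx.dim(), (
        f'rank mismatch: {model_output.dim()} vs {idx.dim()}')
    gather_dim = dim % model_output.dim()
    for d in range(model_output.dim()):
        if d != gather_dim:
            assert idx.size(d) <= model_output.size(d), (
                f'dim {d}: index {idx.size(d)} exceeds input {model_output.size(d)}')

# A [bsz, seq_len, vocab] output with a [bsz, seq_len, 1] index passes;
# with these sizes, a transposed [bsz, bsz, seq_len] output would trip the assert.
check_gather_shapes(torch.randn(2, 4, 6), torch.zeros(2, 4, 1, dtype=torch.long))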