wenxiang guo commited on
Commit
dbaf2bf
·
verified ·
1 Parent(s): fac679f

Update ldm/modules/diffusionmodules/flag_large_dit.py

Browse files
ldm/modules/diffusionmodules/flag_large_dit.py CHANGED
@@ -241,8 +241,11 @@ class TxtFlagLargeDiT(nn.Module):
241
 
242
  print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
243
 
 
 
 
244
  freqs = 1.0 / (theta ** (
245
- torch.arange(0, dim, 2)[: (dim // 2)].float().cuda() / dim
246
  ))
247
  t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
248
  t = t / rope_scaling_factor
 
241
 
242
  print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
243
 
244
+ # freqs = 1.0 / (theta ** (
245
+ # torch.arange(0, dim, 2)[: (dim // 2)].float().cuda() / dim
246
+ # ))
247
  freqs = 1.0 / (theta ** (
248
+ torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
249
  ))
250
  t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
251
  t = t / rope_scaling_factor