wenxiang guo commited on
Commit
060e76c
·
verified ·
1 Parent(s): 78435ed

Update ldm/modules/diffusionmodules/flag_large_dit.py

Browse files
ldm/modules/diffusionmodules/flag_large_dit.py CHANGED
@@ -240,13 +240,14 @@ class TxtFlagLargeDiT(nn.Module):
240
  theta = theta * ntk_factor
241
 
242
  print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
243
-
244
- freqs = 1.0 / (theta ** (
245
  torch.arange(0, dim, 2)[: (dim // 2)].float().cuda() / dim
246
  ))
247
- # freqs = 1.0 / (theta ** (
248
- # torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
249
- # ))
 
250
  t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
251
  t = t / rope_scaling_factor
252
  freqs = torch.outer(t, freqs).float() # type: ignore
 
240
  theta = theta * ntk_factor
241
 
242
  print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
243
+ if torch.cuda.is_available():
244
+ freqs = 1.0 / (theta ** (
245
  torch.arange(0, dim, 2)[: (dim // 2)].float().cuda() / dim
246
  ))
247
+ else:
248
+ freqs = 1.0 / (theta ** (
249
+ torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
250
+ ))
251
  t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
252
  t = t / rope_scaling_factor
253
  freqs = torch.outer(t, freqs).float() # type: ignore