JiminHeo commited on
Commit
1e39b03
·
1 Parent(s): 85ecc11
Files changed (1) hide show
  1. ldm/guided_diffusion/loss_vq.py +1 -1
ldm/guided_diffusion/loss_vq.py CHANGED
@@ -54,7 +54,7 @@ class VQLPIPSWithDiscriminator(nn.Module):
54
  self.pixel_weight = pixelloss_weight
55
  if perceptual_loss == "lpips":
56
  print(f"{self.__class__.__name__}: Running with LPIPS.")
57
- self.perceptual_loss = LPIPS().eval().to(device="cuda")
58
  else:
59
  raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60
  self.perceptual_weight = perceptual_weight
 
54
  self.pixel_weight = pixelloss_weight
55
  if perceptual_loss == "lpips":
56
  print(f"{self.__class__.__name__}: Running with LPIPS.")
57
+ self.perceptual_loss = LPIPS().eval().cuda()
58
  else:
59
  raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60
  self.perceptual_weight = perceptual_weight