sxela commited on
Commit
f9507c6
Β·
1 Parent(s): 03667b4
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -106,13 +106,14 @@ def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, r
106
  print('Using device:', device)
107
  model, diffusion = create_model_and_diffusion(**model_config)
108
  model.load_state_dict(torch.load('256x256_openai_comics_faces_by_alex_spirin_084000.pt', map_location='cpu'))
109
- model.requires_grad_(False).eval().to(device)
110
  for name, param in model.named_parameters():
111
  if 'qkv' in name or 'norm' in name or 'proj' in name:
112
  param.requires_grad_()
113
  if model_config['use_fp16']:
114
  model.convert_to_fp16()
115
- clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
 
116
  clip_size = clip_model.visual.input_resolution
117
  normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
118
  std=[0.26862954, 0.26130258, 0.27577711])
 
106
  print('Using device:', device)
107
  model, diffusion = create_model_and_diffusion(**model_config)
108
  model.load_state_dict(torch.load('256x256_openai_comics_faces_by_alex_spirin_084000.pt', map_location='cpu'))
109
+ model.requires_grad_(False).eval().to(device).float()
110
  for name, param in model.named_parameters():
111
  if 'qkv' in name or 'norm' in name or 'proj' in name:
112
  param.requires_grad_()
113
  if model_config['use_fp16']:
114
  model.convert_to_fp16()
115
+ else: model.convert_to_fp32()
116
+ clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device).float()
117
  clip_size = clip_model.visual.input_resolution
118
  normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
119
  std=[0.26862954, 0.26130258, 0.27577711])