Dakerqi commited on
Commit
5566586
·
verified ·
1 Parent(s): c7ee98c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -163,6 +163,11 @@ def model_main(args, master_port, rank):
163
 
164
  @torch.no_grad()
165
  def inference(args, infer_args, text_encoder, tokenizer, vae, model):
 
 
 
 
 
166
  with torch.autocast("cuda", dtype):
167
  while True:
168
  (
 
163
 
164
  @torch.no_grad()
165
  def inference(args, infer_args, text_encoder, tokenizer, vae, model):
166
+ dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
167
+ args.precision
168
+ ]
169
+ train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
170
+ torch.cuda.set_device(0)
171
  with torch.autocast("cuda", dtype):
172
  while True:
173
  (