Dakerqi committed
Commit 5d67b89 · verified · 1 Parent(s): 99924e1

Update app.py

Files changed (1): app.py +49 -41
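In short, this revision moves the Space from a queue-based multi-GPU worker setup to a single-process flow: the mp.Queue / mp.Barrier / mp.Process plumbing in main() and the queue handling in on_submit are commented out, model_main(args, master_port, rank) now only loads and returns text_encoder, tokenizer, vae, and model, and a new @torch.no_grad() inference(...) function runs generation directly from the submitted arguments, returning (img, metadata) or a ModelFailure, which on_submit calls under @spaces.GPU(duration=200). The parameter-device debug loop in encode_prompt and the model.cpu()/model.cuda() swap around prompt encoding are also dropped.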
app.py CHANGED
@@ -70,8 +70,6 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
         truncation=True,
         return_tensors="pt",
     )
-    for name, param in text_encoder.named_parameters():
-        print(name, param.device)
 
     print(f"Text Encoder Device: {text_encoder.device}")
     text_input_ids = text_inputs.input_ids.cuda()
@@ -90,7 +88,7 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
 
 
 @torch.no_grad()
-def model_main(args, master_port, rank, request_queue, response_queue, mp_barrier):
+def model_main(args, master_port, rank):
     # import here to avoid huggingface Tokenizer parallelism warnings
     from diffusers.models import AutoencoderKL
     from transformers import AutoModel, AutoTokenizer
@@ -106,10 +104,10 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
     # Override the built-in print with the new version
     builtins.print = print
 
-    os.environ["MASTER_PORT"] = str(master_port)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["RANK"] = str(rank)
-    os.environ["WORLD_SIZE"] = str(args.num_gpus)
+    # os.environ["MASTER_PORT"] = str(master_port)
+    # os.environ["MASTER_ADDR"] = "127.0.0.1"
+    # os.environ["RANK"] = str(rank)
+    # os.environ["WORLD_SIZE"] = str(args.num_gpus)
 
 
     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
@@ -159,8 +157,12 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
     ckpt = torch.load(ckpt_path, map_location="cuda")
     model.load_state_dict(ckpt, strict=True)
     print('load model finish')
-    mp_barrier.wait()
+
+    return text_encoder, tokenizer, vae, model
+
 
+@torch.no_grad()
+def inference(args, infer_args, text_encoder, tokenizer, vae, model)
     with torch.autocast("cuda", dtype):
         while True:
             (
@@ -178,7 +180,7 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
                 scaling_method,
                 scaling_watershed,
                 proportional_attn,
-            ) = request_queue.get()
+            ) = infer_args
 
 
             system_prompt = system_type
@@ -243,13 +245,13 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
             torch.random.manual_seed(int(seed))
             z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
             z = z.repeat(2, 1, 1, 1)
-            model.cpu()
+
             with torch.no_grad():
                 if neg_cap != "":
                     cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
                 else:
                     cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
-            model.cuda()
+
             cap_mask = cap_mask.to(cap_feats.device)
 
             model_kwargs = dict(
@@ -297,12 +299,13 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
                 img = to_pil_image(samples[0, :].float())
                 print("> generated image, done.")
 
-                if response_queue is not None:
-                    response_queue.put((img, metadata))
-
+                # if response_queue is not None:
+                #     response_queue.put((img, metadata))
+                return img, metadata
             except Exception:
                 print(traceback.format_exc())
-                response_queue.put(ModelFailure())
+                return ModelFailure()
+                # response_queue.put(ModelFailure())
 
 
 def none_or_str(value):
@@ -389,25 +392,27 @@ def main():
 
     master_port = find_free_port()
     #mp.set_start_method("fork")
-    processes = []
-    request_queues = []
-    response_queue = mp.Queue()
-    mp_barrier = mp.Barrier(args.num_gpus + 1)
-    for i in range(args.num_gpus):
-        request_queues.append(mp.Queue())
-        p = mp.Process(
-            target=model_main,
-            args=(
-                args,
-                master_port,
-                i,
-                request_queues[i],
-                response_queue if i == 0 else None,
-                mp_barrier,
-            ),
-        )
-        p.start()
-        processes.append(p)
+    # processes = []
+    # request_queues = []
+    # response_queue = mp.Queue()
+    # mp_barrier = mp.Barrier(args.num_gpus + 1)
+    # for i in range(args.num_gpus):
+    #     request_queues.append(mp.Queue())
+    #     p = mp.Process(
+    #         target=model_main,
+    #         args=(
+    #             args,
+    #             master_port,
+    #             i,
+    #             request_queues[i],
+    #             response_queue if i == 0 else None,
+    #             mp_barrier,
+    #         ),
+    #     )
+    #     p.start()
+    #     processes.append(p)
+
+    model_main(args, master_port, 0)
 
     description = args.ckpt.split('/')[-1]
     #"""
@@ -552,15 +557,18 @@ def main():
         ) # noqa
 
         @spaces.GPU(duration=200)
-        def on_submit(*args):
-            for q in request_queues:
-                q.put(args)
-            result = response_queue.get()
+        def on_submit(*infer_args):
+            # for q in request_queues:
+            #     q.put(args)
+            # result = response_queue.get()
+            # if isinstance(result, ModelFailure):
+            #     raise RuntimeError
+            # img, metadata = result
+            result = inference(args, infer_args, text_encoder, tokenizer, vae, model)
             if isinstance(result, ModelFailure):
-                raise RuntimeError
-            img, metadata = result
+                raise RuntimeError("Model failed to generate the image.")
 
-            return img, metadata
+            return result
 
         submit_btn.click(
             on_submit,
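For orientation, below is a minimal sketch of the single-process flow this commit moves toward. It is an assumption-laden summary rather than the file's actual code: in particular, it captures model_main's return values (the commit itself calls model_main(args, master_port, 0) without assigning them) and it omits the Gradio input/output wiring.

# Minimal sketch, not the file's exact code: load the models once at startup,
# then serve each request in-process under ZeroGPU. Assumes model_main, inference,
# ModelFailure, find_free_port, and args come from app.py with the signatures
# introduced in this commit, and that model_main's return values are captured.
import spaces

master_port = find_free_port()
text_encoder, tokenizer, vae, model = model_main(args, master_port, 0)

@spaces.GPU(duration=200)
def on_submit(*infer_args):
    # Run generation directly; no request/response queues, barriers, or worker processes.
    result = inference(args, infer_args, text_encoder, tokenizer, vae, model)
    if isinstance(result, ModelFailure):
        raise RuntimeError("Model failed to generate the image.")
    return result  # the (img, metadata) pair returned by inference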