Update app.py
app.py CHANGED
@@ -70,8 +70,6 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
             truncation=True,
             return_tensors="pt",
         )
-        for name, param in text_encoder.named_parameters():
-            print(name, param.device)

         print(f"Text Encoder Device: {text_encoder.device}")
         text_input_ids = text_inputs.input_ids.cuda()
@@ -90,7 +88,7 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt


 @torch.no_grad()
-def model_main(args, master_port, rank, request_queue, response_queue, mp_barrier):
+def model_main(args, master_port, rank):
     # import here to avoid huggingface Tokenizer parallelism warnings
     from diffusers.models import AutoencoderKL
     from transformers import AutoModel, AutoTokenizer
@@ -106,10 +104,10 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
     # Override the built-in print with the new version
     builtins.print = print

-    os.environ["MASTER_PORT"] = str(master_port)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["RANK"] = str(rank)
-    os.environ["WORLD_SIZE"] = str(args.num_gpus)
+    # os.environ["MASTER_PORT"] = str(master_port)
+    # os.environ["MASTER_ADDR"] = "127.0.0.1"
+    # os.environ["RANK"] = str(rank)
+    # os.environ["WORLD_SIZE"] = str(args.num_gpus)


     train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
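The four environment variables disabled here are the standard single-node rendezvous settings that torch.distributed reads when init_process_group is called with the default env:// method; with everything now running on a single ZeroGPU device there is no process group to set up. A minimal, illustrative sketch of what they normally feed (not part of this commit; the port value and backend choice are arbitrary):

import os
import torch.distributed as dist

# The default "env://" rendezvous reads these four variables to find the
# master process and learn this process's rank and the world size.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

dist.init_process_group(backend="gloo")  # would use the settings above
dist.destroy_process_group()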
@@ -159,8 +157,12 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
     ckpt = torch.load(ckpt_path, map_location="cuda")
     model.load_state_dict(ckpt, strict=True)
     print('load model finish')
-
+
+    return text_encoder, tokenizer, vae, model
+

+@torch.no_grad()
+def inference(args, infer_args, text_encoder, tokenizer, vae, model):
     with torch.autocast("cuda", dtype):
         while True:
             (
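The old single function that both loaded the checkpoints and then blocked on a request queue is split in two: model_main now only loads and returns (text_encoder, tokenizer, vae, model), and a separate @torch.no_grad() inference function does the per-request work on whatever arguments it is handed. A self-contained toy sketch of that load-once / call-per-request shape (a tiny linear layer stands in for the real components):

import torch


def model_main():
    # Load the heavy components once at startup and hand them back to the caller.
    model = torch.nn.Linear(4, 4)
    return model


@torch.no_grad()
def inference(model, x):
    # Per-request work: plain arguments in, plain result out -- no queues.
    return model(x)


model = model_main()
print(inference(model, torch.randn(1, 4)))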
@@ -178,7 +180,7 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
                 scaling_method,
                 scaling_watershed,
                 proportional_attn,
-            ) = request_queue.get()
+            ) = infer_args


             system_prompt = system_type
@@ -243,13 +245,13 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
                 torch.random.manual_seed(int(seed))
                 z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
                 z = z.repeat(2, 1, 1, 1)
-
+
                 with torch.no_grad():
                     if neg_cap != "":
                         cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
                     else:
                         cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
-
+
                     cap_mask = cap_mask.to(cap_feats.device)

                 model_kwargs = dict(
@@ -297,12 +299,13 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
                 img = to_pil_image(samples[0, :].float())
                 print("> generated image, done.")

-                if response_queue is not None:
-                    response_queue.put((img, metadata))
-
+                # if response_queue is not None:
+                #     response_queue.put((img, metadata))
+                return img, metadata
             except Exception:
                 print(traceback.format_exc())
-                response_queue.put(ModelFailure())
+                return ModelFailure()
+                # response_queue.put(ModelFailure())


 def none_or_str(value):
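With the response queue gone, results come back as ordinary return values, and errors are signalled by returning a ModelFailure instance that the caller tests with isinstance instead of pushing the failure onto a queue. A self-contained sketch of that sentinel pattern (the ModelFailure name matches the app; the generate function is made up for illustration):

import traceback


class ModelFailure:
    pass


def generate(prompt):
    try:
        if not prompt:
            raise ValueError("empty prompt")
        return f"<image for {prompt!r}>", {"prompt": prompt}
    except Exception:
        print(traceback.format_exc())
        return ModelFailure()


result = generate("a cat")
if isinstance(result, ModelFailure):
    raise RuntimeError("Model failed to generate the image.")
img, metadata = result
print(metadata)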
@@ -389,25 +392,27 @@ def main():

     master_port = find_free_port()
     #mp.set_start_method("fork")
-    processes = []
-    request_queues = []
-    response_queue = mp.Queue()
-    mp_barrier = mp.Barrier(args.num_gpus + 1)
-    for i in range(args.num_gpus):
-        request_queues.append(mp.Queue())
-        p = mp.Process(
-            target=model_main,
-            args=(
-                args,
-                master_port,
-                i,
-                request_queues[i],
-                response_queue if i == 0 else None,
-                mp_barrier,
-            ),
-        )
-        p.start()
-        processes.append(p)
+    # processes = []
+    # request_queues = []
+    # response_queue = mp.Queue()
+    # mp_barrier = mp.Barrier(args.num_gpus + 1)
+    # for i in range(args.num_gpus):
+    #     request_queues.append(mp.Queue())
+    #     p = mp.Process(
+    #         target=model_main,
+    #         args=(
+    #             args,
+    #             master_port,
+    #             i,
+    #             request_queues[i],
+    #             response_queue if i == 0 else None,
+    #             mp_barrier,
+    #         ),
+    #     )
+    #     p.start()
+    #     processes.append(p)
+
+    model_main(args, master_port, 0)

     description = args.ckpt.split('/')[-1]
     #"""
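The commented-out block is the old one-worker-per-GPU launcher: a request queue per process, a single response queue owned by rank 0, and a barrier so main() could wait for every worker to finish loading. On ZeroGPU there is only one device, so the app now calls model_main directly in-process. For reference, a minimal self-contained version of that queue-based pattern (a trivial worker replaces the real model code):

import multiprocessing as mp


def worker(rank, request_queue, response_queue):
    # Each worker blocks on its own queue; only rank 0 reports results back.
    while True:
        item = request_queue.get()
        if item is None:
            break
        if response_queue is not None:
            response_queue.put((rank, item * item))


if __name__ == "__main__":
    num_workers = 2
    request_queues = [mp.Queue() for _ in range(num_workers)]
    response_queue = mp.Queue()
    processes = [
        mp.Process(target=worker, args=(i, request_queues[i], response_queue if i == 0 else None))
        for i in range(num_workers)
    ]
    for p in processes:
        p.start()
    for q in request_queues:
        q.put(3)                  # broadcast one request to every worker
    print(response_queue.get())   # (0, 9) comes back from rank 0 only
    for q in request_queues:
        q.put(None)               # shut the workers down
    for p in processes:
        p.join()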
@@ -552,15 +557,18 @@ def main():
         ) # noqa

         @spaces.GPU(duration=200)
-        def on_submit(*args):
-            for q in request_queues:
-                q.put(args)
-            result = response_queue.get()
+        def on_submit(*infer_args):
+            # for q in request_queues:
+            #     q.put(args)
+            # result = response_queue.get()
+            # if isinstance(result, ModelFailure):
+            #     raise RuntimeError
+            # img, metadata = result
+            result = inference(args, infer_args, text_encoder, tokenizer, vae, model)
             if isinstance(result, ModelFailure):
-                raise RuntimeError
-            img, metadata = result
+                raise RuntimeError("Model failed to generate the image.")

-            return img, metadata
+            return result

         submit_btn.click(
             on_submit,
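on_submit keeps the @spaces.GPU(duration=200) decorator, which is what tells a ZeroGPU Space to attach a GPU for up to that many seconds while the call runs; the only change is that it now calls inference in-process instead of round-tripping through the queues. A minimal sketch of that decorator on a Gradio click handler (assumes the gradio and spaces packages of a ZeroGPU Space; the handler body is illustrative):

import gradio as gr
import spaces
import torch


@spaces.GPU(duration=200)
def on_submit(prompt):
    # A GPU is only guaranteed to be attached while this decorated call runs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"ran on {device}: {prompt}"


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Result")
    submit_btn = gr.Button("Submit")
    submit_btn.click(on_submit, inputs=prompt, outputs=output)

demo.launch()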