Spaces:
Running
on
L4
Running
on
L4
Better init event waiting
Browse files- app.py +0 -3
- tools/llama/generate.py +4 -4
app.py
CHANGED
|
@@ -306,7 +306,6 @@ if __name__ == "__main__":
|
|
| 306 |
args.vqgan_config_name = "vqgan_pretrain"
|
| 307 |
|
| 308 |
logger.info("Loading Llama model...")
|
| 309 |
-
init_event = threading.Event()
|
| 310 |
llama_queue = launch_thread_safe_queue(
|
| 311 |
config_name=args.llama_config_name,
|
| 312 |
checkpoint_path=args.llama_checkpoint_path,
|
|
@@ -314,10 +313,8 @@ if __name__ == "__main__":
|
|
| 314 |
precision=args.precision,
|
| 315 |
max_length=args.max_length,
|
| 316 |
compile=args.compile,
|
| 317 |
-
init_event=init_event,
|
| 318 |
)
|
| 319 |
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
| 320 |
-
init_event.wait()
|
| 321 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
| 322 |
|
| 323 |
vqgan_model = load_vqgan_model(
|
|
|
|
| 306 |
args.vqgan_config_name = "vqgan_pretrain"
|
| 307 |
|
| 308 |
logger.info("Loading Llama model...")
|
|
|
|
| 309 |
llama_queue = launch_thread_safe_queue(
|
| 310 |
config_name=args.llama_config_name,
|
| 311 |
checkpoint_path=args.llama_checkpoint_path,
|
|
|
|
| 313 |
precision=args.precision,
|
| 314 |
max_length=args.max_length,
|
| 315 |
compile=args.compile,
|
|
|
|
| 316 |
)
|
| 317 |
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
|
|
|
| 318 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
| 319 |
|
| 320 |
vqgan_model = load_vqgan_model(
|
tools/llama/generate.py
CHANGED
|
@@ -600,6 +600,7 @@ def generate_long(
|
|
| 600 |
yield all_codes
|
| 601 |
|
| 602 |
|
|
|
|
| 603 |
def launch_thread_safe_queue(
|
| 604 |
config_name,
|
| 605 |
checkpoint_path,
|
|
@@ -607,17 +608,15 @@ def launch_thread_safe_queue(
|
|
| 607 |
precision,
|
| 608 |
max_length,
|
| 609 |
compile=False,
|
| 610 |
-
init_event=None,
|
| 611 |
):
|
| 612 |
input_queue = queue.Queue()
|
|
|
|
| 613 |
|
| 614 |
def worker():
|
| 615 |
model, decode_one_token = load_model(
|
| 616 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
| 617 |
)
|
| 618 |
-
|
| 619 |
-
if init_event is not None:
|
| 620 |
-
init_event.set()
|
| 621 |
|
| 622 |
while True:
|
| 623 |
item = input_queue.get()
|
|
@@ -641,6 +640,7 @@ def launch_thread_safe_queue(
|
|
| 641 |
event.set()
|
| 642 |
|
| 643 |
threading.Thread(target=worker, daemon=True).start()
|
|
|
|
| 644 |
|
| 645 |
return input_queue
|
| 646 |
|
|
|
|
| 600 |
yield all_codes
|
| 601 |
|
| 602 |
|
| 603 |
+
|
| 604 |
def launch_thread_safe_queue(
|
| 605 |
config_name,
|
| 606 |
checkpoint_path,
|
|
|
|
| 608 |
precision,
|
| 609 |
max_length,
|
| 610 |
compile=False,
|
|
|
|
| 611 |
):
|
| 612 |
input_queue = queue.Queue()
|
| 613 |
+
init_event = threading.Event()
|
| 614 |
|
| 615 |
def worker():
|
| 616 |
model, decode_one_token = load_model(
|
| 617 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
| 618 |
)
|
| 619 |
+
init_event.set()
|
|
|
|
|
|
|
| 620 |
|
| 621 |
while True:
|
| 622 |
item = input_queue.get()
|
|
|
|
| 640 |
event.set()
|
| 641 |
|
| 642 |
threading.Thread(target=worker, daemon=True).start()
|
| 643 |
+
init_event.wait()
|
| 644 |
|
| 645 |
return input_queue
|
| 646 |
|