Add GSM8K (#27)

Commits:
- Add GSM8K (900a631539af9b1ea4ccf861e170bcdeae8d46fe)
- Merge branch 'main' into pr/27 (e82a7af1535e766b16314ac4eefb0d2fa1fbaee4)
- Delete gsm8k yamls (f38163c1a9a66242aa5baf0ba91bd786b8d802d3)
- Fix some bugs (9ffef81821400a94b5d4c08eddb5268944b26e7f)
- Fix bugs on wrappers and add quantization requirement (28b60907c7a9ed112ee151c6eadb22d1e7074116)
- Fix bugs in gsm8k (4045483a84607da8b1c2505dc7f1ba2bdd407f47)
Files changed:
- backend-cli.py                             +68 -42
- requirements.txt                           +2  -1
- src/backend/envs.py                        +1  -0
- src/backend/hflm_with_measurement.py       +50 -21
- src/backend/tasks/gsm8k/gsm8k-custom.yaml  +44 -0
- src/display/utils.py                       +1  -0
- src/submission/check_validity.py           +2  -1
- src/utils.py                               +104 -3
backend-cli.py  CHANGED

@@ -17,7 +17,7 @@ from src.backend.manage_requests import EvalRequest
 from src.leaderboard.read_evals import EvalResult
 
 from src.envs import QUEUE_REPO, RESULTS_REPO, API, DEBUG_QUEUE_REPO, DEBUG_RESULTS_REPO
-from src.utils import my_snapshot_download, analyze_gpu_stats, parse_nvidia_smi, monitor_gpus
+from src.utils import my_snapshot_download, analyze_gpu_stats, parse_nvidia_smi, monitor_gpus, get_gpu_details
 
 from src.leaderboard.read_evals import get_raw_eval_results
 
@@ -28,6 +28,8 @@ import time
 import pprint
 import logging
 
+from lm_eval.filters.extraction import RegexFilter
+
 
 # Configure the root logger
 logging.basicConfig(
@@ -42,6 +44,20 @@ eval_logger = logging.getLogger("lm-eval")
 # Explicitly set the level for 'lm-eval' logger to WARNING
 eval_logger.setLevel(logging.WARNING)
 
+def tuple_input_decorator(func):
+    def wrapper(self, resps, docs):
+        stripped_resps = [[resp_data[0] for resp_data in group] for group in resps]
+
+        filtered_resps = func(self, stripped_resps, docs)
+
+        combined_resps = []
+        for original_group, new_group in zip(resps, filtered_resps):
+            combined_group = [(new_resp,) + rest_of_data[1:] for new_resp, rest_of_data in zip(new_group, original_group)]
+            combined_resps.append(combined_group)
+
+        return combined_resps
+    return wrapper
+
 
 def my_set_eval_request(api, eval_request, set_to_status, hf_repo, local_dir):
     for i in range(10):
@@ -126,9 +142,6 @@ def request_to_result_name(request: EvalRequest) -> str:
 def process_evaluation(task: Task, eval_request: EvalRequest, limit: Optional[int] = None) -> dict:
     batch_size = 1
     batch_size = eval_request.batch_size
-
-    if args.debug:
-        RESULTS_REPO = DEBUG_RESULTS_REPO
 
     init_gpu_info = analyze_gpu_stats(parse_nvidia_smi())
     # if init_gpu_info['Mem(M)'] > 500:
@@ -137,6 +150,12 @@ def process_evaluation(task: Task, eval_request: EvalRequest, limit: Optional[int] = None) -> dict:
     stop_event = threading.Event()
     monitor_thread = threading.Thread(target=monitor_gpus, args=(stop_event, 5, gpu_stats_list))
     monitor_thread.start()
+
+    original_apply = RegexFilter.apply
+    if task.benchmark in ["gsm8k", "gsm8k_cot", "gsm8k_cot_self_consistency", "gsm8k_custom"]:
+        RegexFilter.apply = tuple_input_decorator(RegexFilter.apply)
+    else:
+        RegexFilter.apply = original_apply
 
     try:
         results = run_evaluation(
@@ -198,6 +217,8 @@ def process_evaluation(task: Task, eval_request: EvalRequest, limit: Optional[int] = None) -> dict:
         repo_id=RESULTS_REPO,
         repo_type="dataset",
     )
+
+    RegexFilter.apply = original_apply
     return results
 
 
@@ -366,21 +387,7 @@ def maybe_refresh_results(thr: int, hard_task_lst: Optional[list[str]] = None) -> bool:
 
     return False
 
-
-def get_gpu_details():
-    gpus = GPUtil.getGPUs()
-    gpu = gpus[0]
-    name = gpu.name.replace(" ", "-")
-    # Convert memory from MB to GB and round to nearest whole number
-    memory_gb = round(gpu.memoryTotal / 1024)
-    memory = f"{memory_gb}GB"
-    formatted_name = f"{name}-{memory}"
-    return formatted_name
-
 def process_pending_requests() -> bool:
-    if args.debug:
-        QUEUE_REPO = DEBUG_QUEUE_REPO
-
     sanity_checks()
     print("Processing pending requests")
     current_pending_status = [PENDING_STATUS]
@@ -443,13 +450,14 @@ def get_args():
     parser = argparse.ArgumentParser(description="Run the backend")
     parser.add_argument("--debug", action="store_true", help="Run in debug mode")
     # debug parameters
-    parser.add_argument("--task", type=str, default="selfcheckgpt,mmlu", help="Task to debug")
+    parser.add_argument("--task", type=str, default="selfcheckgpt,mmlu, gsm8k", help="Task to debug")
     parser.add_argument("--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1,mistralai/Mixtral-8x7B-v0.1", help="Model to debug")
     parser.add_argument("--precision", type=str, default="float32,float16,8bit,4bit", help="Precision to debug")
     parser.add_argument("--inference-framework", type=str, default="hf-chat", help="Inference framework to debug")
     parser.add_argument("--limit", type=int, default=None, help="Limit for the number of samples")
     parser.add_argument("--gpu-type", type=str, default="NVIDIA-A100-PCIe-80GB", 
                         help="GPU type. NVIDIA-A100-PCIe-80GB; NVIDIA-RTX-A5000-24GB; NVIDIA-H100-PCIe-80GB")
+    parser.add_argument("--debug_repo", action="store_true", help="Use debug repo")
     return parser.parse_args()
 
 
@@ -457,7 +465,7 @@ if __name__ == "__main__":
     args = get_args()
     local_debug = args.debug
     # debug specific task by ping
-    if local_debug:
+    if local_debug and not args.debug_repo:
        # debug_model_names = [args.model]  # Use model from arguments
        # debug_task_name = [args.task]  # Use task from arguments
        debug_model_names = args.model.split(",")
@@ -471,42 +479,60 @@ if __name__ == "__main__":
                 task_name = task.benchmark
                 if task_name not in debug_task_name:
                     continue
-                try:
-                    …
-                except Exception as e:
-                    …
-
+                # try:
+                eval_request = EvalRequest(
+                    model=debug_model_name,
+                    private=False,
+                    status="",
+                    json_filepath="",
+                    precision=precision,  # Use precision from arguments
+                    inference_framework=args.inference_framework,  # Use inference framework from arguments
+                    gpu_type=args.gpu_type
+                )
+                curr_gpu_type = get_gpu_details()
+                if eval_request.gpu_type != curr_gpu_type:
+                    print(f"GPU type mismatch: {eval_request.gpu_type} vs {curr_gpu_type}")
+                    raise Exception("GPU type mismatch")
+                results = process_evaluation(task, eval_request, limit=args.limit)
+                # except Exception as e:
+                #     print(f"debug running error: {e}")
+    elif local_debug and args.debug_repo:
+        QUEUE_REPO = DEBUG_QUEUE_REPO
+        RESULTS_REPO = DEBUG_RESULTS_REPO
         while True:
             res = False
-
             # if random.randint(0, 10) == 0:
             res = process_pending_requests()
             print(f"waiting for 60 seconds")
             time.sleep(60)
-
             # if res is False:
             #     if random.randint(0, 5) == 0:
             #         res = maybe_refresh_results(100)
             #     else:
             #         res = process_finished_requests(100)
-
             # time.sleep(60)
-
             # if res is False:
             #     if random.randint(0, 5) == 0:
             #         res = maybe_refresh_results(0)
             #     else:
             #         res = process_finished_requests(0)
+    elif not local_debug and not args.debug_repo:
+        while True:
+           res = False
+           # if random.randint(0, 10) == 0:
+           res = process_pending_requests()
+           print(f"waiting for 60 seconds")
+           time.sleep(60)
+           # if res is False:
+           #     if random.randint(0, 5) == 0:
+           #         res = maybe_refresh_results(100)
+           #     else:
+           #         res = process_finished_requests(100)
+           # time.sleep(60)
+           # if res is False:
+           #     if random.randint(0, 5) == 0:
+           #         res = maybe_refresh_results(0)
+           #     else:
+           #         res = process_finished_requests(0)
+    else:
+        raise Exception("Cannot use debug_repo without local debug flag")
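Note: `HFLMWithMeasurement` returns each response as a `(text, end_to_end_time, prefilling_time, token_per_sec)` tuple, while lm-eval's `RegexFilter.apply` expects bare strings. `tuple_input_decorator` unwraps the tuples around the filter call and splices the filtered strings back in front of the timing fields, and `process_evaluation` monkey-patches `RegexFilter.apply` only for the GSM8K task family, restoring the original afterwards. A minimal standalone check of that reshaping logic; `fake_apply` is a hypothetical stand-in for `RegexFilter.apply`:

def tuple_input_decorator(func):  # copied from the diff above
    def wrapper(self, resps, docs):
        stripped_resps = [[resp_data[0] for resp_data in group] for group in resps]
        filtered_resps = func(self, stripped_resps, docs)
        combined_resps = []
        for original_group, new_group in zip(resps, filtered_resps):
            combined_group = [(new_resp,) + rest_of_data[1:]
                              for new_resp, rest_of_data in zip(new_group, original_group)]
            combined_resps.append(combined_group)
        return combined_resps
    return wrapper

def fake_apply(self, resps, docs):
    # stand-in filter: operates on plain strings, like RegexFilter.apply
    return [[r.split("#### ")[-1] for r in group] for group in resps]

resps = [[("reasoning ... #### 42", 1.23, 0.04, 97.0)]]  # (text, e2e_s, prefill_s, tok/s)
out = tuple_input_decorator(fake_apply)(None, resps, docs=[{}])
assert out == [[("42", 1.23, 0.04, 97.0)]]  # filtered text, timings preserved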
    	
requirements.txt  CHANGED

@@ -30,4 +30,5 @@ evaluate
 spacy==3.7.4
 selfcheckgpt
 immutabledict
-gputil
+gputil
+bitsandbytes
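Note: `bitsandbytes` backs the 8bit/4bit entries in the `--precision` option (the commit message calls it the quantization requirement). A minimal sketch of how transformers picks it up, reusing a model name from the `--model` default above:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit quantized load; requires the bitsandbytes package added above
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)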
    	
src/backend/envs.py  CHANGED

@@ -57,6 +57,7 @@ class Tasks(Enum):
 
     # task20 = Task("race", "acc", "RACE", 0)
     task21 = Task("mmlu", "acc", "MMLU", 5)
+    task22 = Task("gsm8k_custom", "em", "GSM8K", 5)
 
 
 EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
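Note: following the pattern of the surrounding entries, the `Task` fields read as (benchmark, metric, col_name, num_fewshot), so `task22` registers a 5-shot GSM8K run scored by exact match ("em"). A self-contained restatement of that reading; the field names are inferred, not taken from the repo:

from collections import namedtuple

# Assumed field names, inferred from how task21/task22 are constructed
Task = namedtuple("Task", ["benchmark", "metric", "col_name", "num_fewshot"])
task22 = Task("gsm8k_custom", "em", "GSM8K", 5)  # 5-shot, exact-match metric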
    	
src/backend/hflm_with_measurement.py  CHANGED

@@ -295,6 +295,8 @@ class HFLMWithMeasurement(HFLM):
         # and we don't want a warning from HF
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
         do_sample = generation_kwargs.get("do_sample", None)
+
+        is_gsm8k = generation_kwargs.get("is_gsm8k", False)
 
         # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
         if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
@@ -302,22 +304,40 @@ class HFLMWithMeasurement(HFLM):
 
         if do_sample is False and generation_kwargs.get("temperature") == 0.0:
             generation_kwargs.pop("temperature")
+
+        generation_kwargs.pop("is_gsm8k")
+
+        if not is_gsm8k:
         # build stopping criteria
-        stopping_criteria = stop_sequences_criteria(
-            self.tokenizer, stop, context.shape[1], context.shape[0]
-        )
-        stop_watch = StopWatch(self.tokenizer)
-        start = time()
-        res = self.model.generate(
-            input_ids=context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=True,
-            streamer=stop_watch,
-            **generation_kwargs,
-        )
-        end = time()
+            stopping_criteria = stop_sequences_criteria(
+                self.tokenizer, stop, context.shape[1], context.shape[0]
+            )
+            stop_watch = StopWatch(self.tokenizer)
+            start = time()
+            res = self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                streamer=stop_watch,
+                **generation_kwargs,
+            )
+            end = time()
+        else:
+            # print("Using GSM8K")
+            stop_watch = StopWatch(self.tokenizer)
+            start = time()
+            res = self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                eos_token_id=stop,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                streamer=stop_watch,
+                **generation_kwargs,
+            )
+            end = time()
 
         batch_size = context.shape[0]
         output_length = stop_watch.decoding_iterations
@@ -408,6 +428,11 @@ class HFLMWithMeasurement(HFLM):
                 until = [eos]
             else:
                 until.append(eos)
+
+            is_gsm8k = kwargs.get("is_gsm8k", False)
+            if is_gsm8k:
+                until = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
@@ -427,6 +452,8 @@ class HFLMWithMeasurement(HFLM):
                 left_truncate_len=max_ctx_len,
                 truncation=self.truncation,
             )
+
+            # print("context: ", self.tok_decode(context_enc[0]))
             context_enc = context_enc.to(self.device)
             attn_masks = attn_masks.to(self.device)
 
@@ -445,16 +472,18 @@ class HFLMWithMeasurement(HFLM):
             for cont_toks, context in zip(cont_toks_list, contexts):
                 # discard context + left-padding toks if using causal decoder-only LM
                 if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    # print("After Generation: ", self.tok_decode(cont_toks))
                     cont_toks = cont_toks[context_enc.shape[1] :]
-
+
                 s = self.tok_decode(cont_toks)
 
                 # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-                for term in until:
-                    if len(term) > 0:
-                        # ignore '' separator,
-                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
-                        s = s.split(term)[0]
+                if not is_gsm8k:
+                    for term in until:
+                        if len(term) > 0:
+                            # ignore '' separator,
+                            # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
+                            s = s.split(term)[0]
 
                 res.append((s, end_to_end_time, prefilling_time, token_per_sec))
 
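Note: the non-GSM8K path stops generation on decoded string sequences via `stop_sequences_criteria`, while the GSM8K path bypasses the string-based criteria and tells `generate` to halt on token ids directly (transformers accepts a list for `eos_token_id`). A minimal sketch of the token-id style with a placeholder model:

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")                 # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Question: What is 2 + 2?\nAnswer:", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=32,
    eos_token_id=[tok.eos_token_id],   # stop when any listed token id is produced
    pad_token_id=tok.eos_token_id,
)
print(tok.decode(out[0][inputs["input_ids"].shape[1]:]))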
    	
src/backend/tasks/gsm8k/gsm8k-custom.yaml  ADDED

@@ -0,0 +1,44 @@
+group:
+  - math_word_problems
+task: gsm8k_custom
+dataset_path: gsm8k
+dataset_name: main
+output_type: generate_until
+training_split: train
+fewshot_split: train
+test_split: test
+doc_to_text: "Question: {{question}}\nAnswer:"
+doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: false
+    regexes_to_ignore:
+      - ","
+      - "\\$"
+      - "(?s).*#### "
+      - "\\.$"
+generation_kwargs:
+  until:
+    - "<|eot_id|>"
+  do_sample: false
+  temperature: 0.0
+  is_gsm8k: true
+repeats: 1
+num_fewshot: 5
+filter_list:
+  # - name: "strict-match"
+  #   filter:
+  #     - function: "regex"
+  #       regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
+  #     - function: "take_first"
+  - name: "flexible-extract"
+    filter:
+      - function: "regex"
+        group_select: -1
+        regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
+      - function: "take_first"
+metadata:
+  version: 3.0
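Note: only the `flexible-extract` filter is active; `group_select: -1` takes the last regex match in the generation and `take_first` keeps its first non-empty capture, so the final number in a chain-of-thought answer is what `exact_match` scores (after the `,`, `$`, `#### ` prefix and trailing `.` patterns are stripped). A standalone illustration of that selection logic; the sample text is made up:

import re

pattern = re.compile(r"(-?[$0-9.,]{2,})|(-?[0-9]+)")  # regex_pattern from the YAML

text = "She bakes 4 + 5 = 9 trays, so 9 * 12 = 108 cookies. The answer is 108"
matches = pattern.findall(text)              # one (group1, group2) tuple per match
last = next(g for g in matches[-1] if g)     # group_select: -1, then the non-empty group
print(last)                                  # -> 108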
    	
src/display/utils.py  CHANGED

@@ -75,6 +75,7 @@ class Tasks(Enum):
     # # XXX include me back at some point
     selfcheck = Task("selfcheckgpt", "max-selfcheckgpt", "SelfCheckGPT")
     mmlu = Task("mmlu", "acc", "MMLU") #MMLU/Acc (5-shot)
+    gsm8k = Task("gsm8k_custom", "em", "GSM8K") #GSM8K/EM (8-shot)
 
 
 # These classes are for user facing column names,
    	
src/submission/check_validity.py  CHANGED

@@ -130,7 +130,8 @@ def already_submitted_models(requested_models_dir: str) -> set[str]:
                     continue
                 with open(os.path.join(root, file), "r") as f:
                     info = json.load(f)
-
+                    if not info["status"] == "FINISHED" and not info["status"] == "RUNNING":
+                        file_names.append(f"{info['model']}_{info['revision']}_{info['precision']}_{info['inference_framework']}_{info['gpu_type']}")
 
                 # Select organisation
                 if info["model"].count("/") == 0 or "submitted_time" not in info:
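Note: the added condition records a request as already submitted only when its status is neither FINISHED nor RUNNING. An equivalent, arguably clearer spelling of the same test (same intent assumed):

# De Morgan: (not A) and (not B)  <=>  not (A or B)
if info["status"] not in ("FINISHED", "RUNNING"):
    file_names.append(
        f"{info['model']}_{info['revision']}_{info['precision']}"
        f"_{info['inference_framework']}_{info['gpu_type']}"
    )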
    	
        src/utils.py
    CHANGED
    
    | 
         @@ -3,12 +3,48 @@ from huggingface_hub import snapshot_download 
     | 
|
| 3 | 
         
             
            import subprocess
         
     | 
| 4 | 
         
             
            import re
         
     | 
| 5 | 
         
             
            import os
         
     | 
| 
         | 
|
| 6 | 
         | 
| 7 | 
         
             
            try:
         
     | 
| 8 | 
         
             
                from src.display.utils import GPU_TEMP, GPU_Mem, GPU_Power, GPU_Util, GPU_Name
         
     | 
| 9 | 
         
             
 import subprocess
 import re
 import os
+import GPUtil
 
 try:
     from src.display.utils import GPU_TEMP, GPU_Mem, GPU_Power, GPU_Util, GPU_Name
 except:
     print("local debug: from display.utils")
     from display.utils import GPU_TEMP, GPU_Mem, GPU_Power, GPU_Util, GPU_Name
+
+MEM_BW_DICT = {
+    "NVIDIA-A100-PCIe-80GB": 1935,
+    "NVIDIA-A100-SXM-80GB": 2039,
+    "NVIDIA-H100-PCIe-80GB": 2039,
+    "NVIDIA-RTX-A5000-24GB": 768
+}
+
+PEAK_FLOPS_DICT = {
+    "float32": {
+        "NVIDIA-A100-PCIe-80GB": 312e12,
+        "NVIDIA-A100-SXM-80GB": 312e12,
+        "NVIDIA-H100-PCIe-80GB": 756e12,
+        "NVIDIA-RTX-A5000-24GB": 222.2e12
+    },
+    "float16": {
+        "NVIDIA-A100-PCIe-80GB": 624e12,
+        "NVIDIA-A100-SXM-80GB": 624e12,
+        "NVIDIA-H100-PCIe-80GB": 1513e12,
+        "NVIDIA-RTX-A5000-24GB": 444.4e12
+    },
+    "8bit": {
+        "NVIDIA-A100-PCIe-80GB": 1248e12,
+        "NVIDIA-A100-SXM-80GB": 1248e12,
+        "NVIDIA-H100-PCIe-80GB": 3026e12,
+        "NVIDIA-RTX-A5000-24GB": 889e12
+    },
+    "4bit": {
+        "NVIDIA-A100-PCIe-80GB": 2496e12,
+        "NVIDIA-A100-SXM-80GB": 2496e12,
+        "NVIDIA-H100-PCIe-80GB": 6052e12,
+        "NVIDIA-RTX-A5000-24GB": 1778e12
+    }
+}
 
 def my_snapshot_download(repo_id, revision, local_dir, repo_type, max_workers):
     for i in range(10):
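Aside (not part of the diff): both tables key on the hyphenated device string that get_gpu_details(), added further down in this patch, returns (e.g. "NVIDIA-A100-PCIe-80GB"); the bandwidth entries read as GB/s and the FLOPS entries as FLOP/s, consistent with NVIDIA's published A100 figures. A minimal lookup sketch under those assumptions; peak_specs is an illustrative name, not a helper in this PR:

def peak_specs(gpu_name, precision):
    # Assumed units: MEM_BW_DICT in GB/s, PEAK_FLOPS_DICT in FLOP/s.
    mem_bw = MEM_BW_DICT[gpu_name]                # e.g. 1935 for "NVIDIA-A100-PCIe-80GB"
    flops = PEAK_FLOPS_DICT[precision][gpu_name]  # e.g. 624e12 at "float16"
    return mem_bw, flops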
@@ -52,11 +88,11 @@ def parse_nvidia_smi():
             print("Failed to query GPU indices.")
             return []
         gpu_indices = result.stdout.strip().split('\n')
-    print(f"gpu_indices: {gpu_indices}")
+    # print(f"gpu_indices: {gpu_indices}")
     gpu_stats = []
 
     gpu_info_pattern = re.compile(r'(\d+)C\s+P\d+\s+(\d+)W / \d+W\s+\|\s+(\d+)MiB / \d+MiB\s+\|\s+(\d+)%')
-    gpu_name_pattern = re.compile(r'NVIDIA\s+([\w\s]
+    gpu_name_pattern = re.compile(r'NVIDIA\s+([\w\s]+\d+(?:\s*GB)?)')
 
     gpu_name = ""
     for index in gpu_indices:

@@ -80,7 +116,7 @@ def parse_nvidia_smi():
 
             if len(gpu_info) >= 4:
                 gpu_stats.append(gpu_info)
-    print(f"gpu_stats: {gpu_stats}")
+    # print(f"gpu_stats: {gpu_stats}")
     gpu_name = f"{len(gpu_stats)}x{gpu_name}"
     gpu_stats_total = {
                         GPU_TEMP: 0,
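Aside (not part of the diff): gpu_info_pattern above extracts temperature, power draw, memory use, and utilization from one nvidia-smi status row. A self-contained check against an illustrative sample line (the line itself is made up; the pattern is verbatim from this patch):

import re

gpu_info_pattern = re.compile(r'(\d+)C\s+P\d+\s+(\d+)W / \d+W\s+\|\s+(\d+)MiB / \d+MiB\s+\|\s+(\d+)%')
sample = "| N/A   31C    P0    59W / 300W |  12000MiB / 81920MiB |     35%      Default |"
print(gpu_info_pattern.search(sample).groups())
# -> ('31', '59', '12000', '35'): temperature C, power W, memory MiB, utilization %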
@@ -131,5 +167,70 @@ def analyze_gpu_stats(stats_list):
 
     return avg_stats
 
+def get_gpu_number():
+    visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+    if visible_devices is not None:
+        gpu_indices = visible_devices.split(',')
+    else:
+        # Query all GPU indices if CUDA_VISIBLE_DEVICES is not set
+        result = subprocess.run(['nvidia-smi', '--query-gpu=index', '--format=csv,noheader'], capture_output=True, text=True)
+        if result.returncode != 0:
+            print("Failed to query GPU indices.")
+            return []
+        gpu_indices = result.stdout.strip().split('\n')
+    # print(f"gpu_indices: {gpu_indices}")
+    gpu_stats = []
+
+    gpu_info_pattern = re.compile(r'(\d+)C\s+P\d+\s+(\d+)W / \d+W\s+\|\s+(\d+)MiB / \d+MiB\s+\|\s+(\d+)%')
+
+    for index in gpu_indices:
+        result = subprocess.run(['nvidia-smi', '-i', index], capture_output=True, text=True)
+        output = result.stdout.strip()
+        lines = output.split("\n")
+        for line in lines:
+            match = gpu_info_pattern.search(line)
+            gpu_info = {}
+            if match:
+                temp, power_usage, mem_usage, gpu_util = map(int, match.groups())
+                gpu_info.update({
+                    GPU_TEMP: temp,
+                    GPU_Power: power_usage,
+                    GPU_Mem: round(mem_usage / 1024, 2),
+                    GPU_Util: gpu_util
+                })
+
+            if len(gpu_info) >= 4:
+                gpu_stats.append(gpu_info)
+
+    return len(gpu_stats)
+
+def get_gpu_details():
+    gpus = GPUtil.getGPUs()
+    gpu = gpus[0]
+    name = gpu.name.replace(" ", "-")
+    # Convert memory from MB to GB and round to nearest whole number
+    memory_gb = round(gpu.memoryTotal / 1024)
+    memory = f"{memory_gb}GB"
+    formatted_name = f"{name}-{memory}"
+    return formatted_name
+
+def get_peak_bw(gpu_name):
+    return MEM_BW_DICT[gpu_name]
+
+def get_peak_flops(gpu_name, precision):
+    return PEAK_FLOPS_DICT[precision][gpu_name]
+
+def transfer_precision2bytes(precision):
+    if precision == "float32":
+        return 4
+    elif precision == "float16":
+        return 2
+    elif precision == "8bit":
+        return 1
+    elif precision == "4bit":
+        return 0.5
+    else:
+        raise ValueError(f"Unsupported precision: {precision}")
+
 if __name__ == "__main__":
     print(analyze_gpu_stats(parse_nvidia_smi()))
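Aside (not part of the diff): taken together, these helpers supply the hardware side of a roofline-style throughput bound. Single-stream decoding is usually memory-bandwidth-bound, so tokens/s cannot exceed peak bandwidth divided by the bytes streamed per token, roughly parameter count times transfer_precision2bytes(precision). A sketch under those assumptions; estimate_decode_tokens_per_s and n_params are illustrative names, not part of this PR:

from src.utils import get_gpu_details, get_peak_bw, transfer_precision2bytes

def estimate_decode_tokens_per_s(n_params, precision):
    gpu_name = get_gpu_details()                 # e.g. "NVIDIA-A100-PCIe-80GB"
    peak_bw_bytes = get_peak_bw(gpu_name) * 1e9  # GB/s -> bytes/s (assumed units)
    bytes_per_param = transfer_precision2bytes(precision)
    # Each generated token streams the full weight set once.
    return peak_bw_bytes / (n_params * bytes_per_param)

For a 7B-parameter float16 model on an A100 PCIe 80GB this gives 1935e9 / (7e9 * 2), about 138 tokens/s as a ceiling; measured throughput will sit below it.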