#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from contextlib import ContextDecorator


class TimeBenchmark(ContextDecorator):
    """
    Measures execution time using a context manager or decorator.

    Timing state is kept in a ``threading.local`` object, so each thread that
    uses a shared instance records and reads its own measurement.

    Note:
        Entering the same instance again on the same thread (nested ``with``
        blocks, or a recursive decorated function) overwrites the stored start
        time, so only the innermost/latest measurement is reliable.

    Args:
        print: If True, prints the elapsed time upon exiting the context or
            completing the function. Defaults to False.

    Examples:
        Using as a context manager (the measured value varies slightly from
        run to run):

        >>> benchmark = TimeBenchmark()
        >>> with benchmark:
        ...     time.sleep(1)
        >>> print(f"Block took {benchmark.result:.4f} seconds")  # doctest: +SKIP
        Block took 1.0012 seconds

        Using as a decorator:

        ```python
        @TimeBenchmark(print=True)
        def work():
            time.sleep(0.01)

        work()  # prints e.g. "Elapsed time: 0.0101 seconds"
        ```

        Using with multithreading:

        ```python
        import threading

        benchmark = TimeBenchmark()


        def context_manager_example():
            with benchmark:
                time.sleep(0.01)
            print(f"Block took {benchmark.result_ms:.2f} milliseconds")


        threads = []
        for _ in range(3):
            t1 = threading.Thread(target=context_manager_example)
            threads.append(t1)

        for t in threads:
            t.start()

        for t in threads:
            t.join()
        ```

        Each thread prints its own measurement, roughly 10 milliseconds, e.g.:

            Block took 10.12 milliseconds
            Block took 10.09 milliseconds
            Block took 10.15 milliseconds
    """

    def __init__(self, print: bool = False):
        # Per-thread storage for start_time / end_time / elapsed_time.
        self.local = threading.local()
        # The parameter name shadows the builtin only inside this method; it is
        # kept for backward compatibility with existing keyword callers.
        self.print_time = print

    def __enter__(self):
        """Record the start time for the current thread and return ``self``."""
        self.local.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc) -> bool:
        """Record the end time, store the elapsed time and optionally print it.

        Returns:
            Always False, so an exception raised inside the block propagates.
        """
        self.local.end_time = time.perf_counter()
        self.local.elapsed_time = self.local.end_time - self.local.start_time
        if self.print_time:
            print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
        return False

    @property
    def result(self):
        """Elapsed time in seconds of the last measurement on the current thread.

        Returns:
            The elapsed time as a float, or None if the current thread has not
            completed a measurement yet.
        """
        return getattr(self.local, "elapsed_time", None)

    @property
    def result_ms(self):
        """Elapsed time in milliseconds of the last measurement on the current thread.

        Returns:
            The elapsed time as a float, or None if the current thread has not
            completed a measurement yet (mirrors ``result`` instead of raising
            ``TypeError`` on ``None * 1e3``).
        """
        elapsed = self.result
        if elapsed is None:
            return None
        return elapsed * 1e3