Spaces:
Running
Running
from contextlib import contextmanager | |
from typing import cast | |
import logging | |
from . import api | |
from . import TensorPipeAgent | |
logger = logging.getLogger(__name__) | |
@contextmanager
def _group_membership_management(store, name, is_join):
    """Serialize a rank's join/leave operation through a token kept in ``store``.

    Intended for use in a ``with`` statement. The body of the ``with`` runs
    only after this caller has swapped its own token into the shared
    ``"RpcGroupManagementToken"`` key. On exit, whether the body finished
    normally or raised, the shared key is reset to ``""`` and this caller's
    token key is set to ``"Done"`` so that callers blocked on it can retry.

    Args:
        store: Key-value store providing ``compare_set`` (returning bytes),
            ``set`` and ``wait``. Presumably a ``torch.distributed`` store;
            confirm against the caller.
        name: Identifier embedded in this caller's token string.
        is_join: ``True`` for a join, ``False`` for a leave; only selects the
            ``"join"`` / ``"leave"`` suffix of the token.

    Raises:
        RuntimeError: Re-raised from ``store.wait`` after being logged. The
            log message treats this as a timeout while waiting for the
            current holder to release the token.
    """
    token_key = "RpcGroupManagementToken"
    join_or_leave = "join" if is_join else "leave"
    my_token = f"Token_for_{name}_{join_or_leave}"

    # Acquire: keep trying to swap our token into the shared key. The swap
    # only succeeds while the shared key is empty (i.e. nobody holds it).
    while True:
        returned = store.compare_set(token_key, "", my_token).decode()
        if returned == my_token:
            break
        # Someone else holds the token; `returned` is the holder's token.
        # Block until the holder publishes that key, then retry the swap.
        try:
            store.wait([returned])
        except RuntimeError:
            logger.error("Group membership token %s timed out waiting for %s to be released.", my_token, returned)
            raise

    try:
        # Critical section: hand control to the body of the `with` statement.
        yield
    finally:
        # Release even when the body raised; otherwise every other caller
        # would stay blocked in `store.wait` until it errors out.
        store.set(token_key, "")
        # Waiters block on this key being set before they retry.
        store.set(my_token, "Done")
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group-membership update to the current RPC agent.

    Fetches the agent via ``api._get_current_rpc_agent()``, treats it as a
    ``TensorPipeAgent`` for type-checking purposes (``cast`` performs no
    runtime check), and passes all four arguments through unchanged.

    Returns:
        Whatever the agent's ``_update_group_membership`` returns.
    """
    current_agent = api._get_current_rpc_agent()
    tensorpipe_agent = cast(TensorPipeAgent, current_agent)
    return tensorpipe_agent._update_group_membership(
        worker_info,
        my_devices,
        reverse_device_map,
        is_join,
    )