meg-huggingface committed on
Commit
3d16b0d
·
1 Parent(s): 7d70d90

Inference endpoint figuring

Browse files
Files changed (1) hide show
  1. src/backend/inference_endpoint.py +70 -51
src/backend/inference_endpoint.py CHANGED
@@ -6,83 +6,102 @@ from huggingface_hub import create_inference_endpoint, get_inference_endpoint
6
  from src.backend.run_toxicity_eval import get_generation
7
  from src.logging import setup_logger
8
  import requests
 
9
# Module-level logging setup and the polling interval used when waiting
# on endpoint state transitions.
logging.basicConfig(level=logging.DEBUG)
logger = setup_logger(__name__)
TIMEOUT = 20  # seconds between endpoint status polls (used with sleep())
 
12
 
13
# TODO: Handle case where endpoint returns an error (for example because of
# flash attention or not fitting into memory)
def create_endpoint(endpoint_name, repository, framework="pytorch", task="text-generation", accelerator="gpu", vendor="aws", region="us-east-1", type="protected", instance_size="x4", instance_type="nvidia-l4"):
    """Create an inference endpoint for `repository`, wait for it to come up,
    and return its generation URL.

    On name-collision / HTTP errors the existing endpoint is updated or
    re-created via update_endpoint_exception(); any other error exits the
    process. Exits as well if the endpoint never leaves its startup states.
    """
    logger.info("Creating endpoint %s..." % endpoint_name)
    # TODO(mm): Handle situation where it's paused
    try:
        endpoint = create_inference_endpoint(
            endpoint_name, repository=repository, framework=framework,
            task=task, accelerator=accelerator, vendor=vendor, region=region,
            type=type, instance_size=instance_size,
            instance_type=instance_type)
    except huggingface_hub.utils._errors.HfHubHTTPError as e:
        # Most likely: an endpoint with this name already exists.
        endpoint = update_endpoint_exception(
            e, endpoint_name=endpoint_name, repository=repository,
            framework=framework, task=task, accelerator=accelerator,
            vendor=vendor, region=region, type=type,
            instance_size=instance_size, instance_type=instance_type)
    except requests.exceptions.HTTPError as e:
        endpoint = update_endpoint_exception(
            e, endpoint_name, repository=repository, framework=framework,
            task=task, accelerator=accelerator, vendor=vendor, region=region,
            type=type, instance_size=instance_size,
            instance_type=instance_type)
    except Exception as e:
        logger.debug("Hit error")
        logger.debug(e)
        sys.exit()
    endpoint.fetch()
    logger.info("Endpoint status: %s." % (endpoint.status))
    if endpoint.status == "scaledToZero":
        # Send a request to wake it up.
        get_generation(endpoint.url, "Wake up")
        sleep(TIMEOUT)
    # Poll until the endpoint leaves its startup states, i.e. until it is
    # "failed" or "running"; give up after 20 polls.
    attempts = 0
    while endpoint.status in ["pending", "initializing"]:
        if attempts >= 20:
            logger.info("Model failed to respond. Exiting.")
            sys.exit()
        logger.debug("Waiting %d seconds to check again if the endpoint is running." % TIMEOUT)
        sleep(TIMEOUT)
        endpoint.fetch()
        logger.debug("Endpoint status: %s." % (endpoint.status))
        attempts += 1
    logger.info("Endpoint created:")
    logger.info(endpoint)
    return endpoint.url
47
 
48
 
49
def update_endpoint_exception(e, endpoint_name, repository, framework, task,
                              accelerator, vendor, region, type, instance_size,
                              instance_type):
    """Recover from an endpoint-creation exception `e`.

    First tries to reuse the existing endpoint named `endpoint_name` by
    updating it with the requested settings; if that is rejected
    (BadRequestError), escalates to the next compute configuration
    (nvidia-l4 -> nvidia-a100 x1 -> nvidia-a10g x4) or exits rather than
    keep spending money unsupervised.

    Returns the (possibly re-created) endpoint.
    """
    logger.debug("Hit the following exception:")
    logger.debug(e)
    logger.debug("Attempting to continue.")
    try:
        endpoint = get_inference_endpoint(endpoint_name)
        endpoint.update(repository=repository, framework=framework, task=task,
                        accelerator=accelerator, instance_size=instance_size,
                        instance_type=instance_type)
    except huggingface_hub.utils._errors.BadRequestError as err:
        # `err` (not `e`) so the original exception argument is not shadowed.
        logger.debug("Hit the following exception:")
        logger.debug(err)
        logger.debug("Attempting a new instance type.")
        if instance_type == "nvidia-l4":
            # Try a larger, different, more expensive GPU.
            endpoint = create_inference_endpoint(endpoint_name,
                                                 repository=repository,
                                                 framework=framework, task=task,
                                                 accelerator=accelerator,
                                                 vendor=vendor, region=region,
                                                 type=type,
                                                 instance_size="x1",
                                                 instance_type="nvidia-a100")
        elif instance_type == "nvidia-a100" and instance_size == "x1":
            # BUG FIX: this compared against "a100", but the instance type
            # used everywhere else in this file is "nvidia-a100", so the
            # second escalation step was unreachable.
            endpoint = create_inference_endpoint(endpoint_name,
                                                 repository=repository,
                                                 framework=framework, task=task,
                                                 accelerator=accelerator,
                                                 vendor=vendor, region=region,
                                                 type=type,
                                                 instance_size="x4",
                                                 instance_type="nvidia-a10g")
        else:
            logger.info("Getting expensive to try to run this model without human oversight. Exiting.")
            sys.exit()
    return endpoint
85
 
86
 
87
if __name__ == '__main__':
    # Manual smoke test: bring up a throwaway endpoint for a small model.
    generation_url = create_endpoint("this-is-a-test", "Qwen/Qwen2-7B")
 
6
  from src.backend.run_toxicity_eval import get_generation
7
  from src.logging import setup_logger
8
  import requests
9
+
10
# Module-level logging setup and the polling interval used when waiting
# on endpoint state transitions.
logging.basicConfig(level=logging.DEBUG)
logger = setup_logger(__name__)
TIMEOUT = 20  # seconds between endpoint status polls (used with sleep())
13
+
14
 
15
def create_endpoint(endpoint_name, repository, framework='pytorch',
                    task='text-generation', accelerator='gpu', vendor='aws',
                    region='us-east-1', type='protected', instance_size='x4',
                    instance_type='nvidia-l4'):
    """Create (or reuse and repair) an inference endpoint for `repository`.

    Waits for the endpoint to come up, escalating to different compute via
    update_endpoint_exception() when it reports 'failed', and returns the
    endpoint's generation URL. Exits the process when no usable endpoint
    can be obtained.
    """
    logger.info("Creating endpoint %s..." % endpoint_name)
    # TODO(mm): Handle situation where it's paused
    try:
        endpoint = create_inference_endpoint(endpoint_name,
                                             repository=repository,
                                             framework=framework, task=task,
                                             accelerator=accelerator,
                                             vendor=vendor, region=region,
                                             type=type,
                                             instance_size=instance_size,
                                             instance_type=instance_type)
    except huggingface_hub.utils._errors.HfHubHTTPError as e:
        # Workload with the same name already exists error.
        # Use it again, just make sure it has the right settings.
        logger.debug("Hit error:")
        logger.debug(e)
        logger.debug("Attempting to update with the given parameters.")
        endpoint = get_inference_endpoint(endpoint_name)
        endpoint.update(repository=repository,
                        framework=framework, task=task,
                        accelerator=accelerator,
                        vendor=vendor, region=region,
                        type=type,
                        instance_size=instance_size,
                        instance_type=instance_type)
    except requests.exceptions.HTTPError as e:
        # Not enough compute, or wrong compute
        logger.debug("Hit error:")
        logger.debug(e)
        # BUG FIX: `endpoint` was referenced here while still unbound — the
        # create_inference_endpoint() call above is what raised, so no
        # assignment happened. Look the endpoint up by name before escalating.
        endpoint = get_inference_endpoint(endpoint_name)
        endpoint = update_endpoint_exception(endpoint)
    except Exception as e:
        logger.debug("Hit unaccounted-for error")
        logger.debug(e)
        sys.exit()
    endpoint.fetch()
    logger.info("Endpoint status: %s." % endpoint.status)
    if endpoint.status == 'scaledToZero':
        # Send a request to wake it up.
        get_generation(endpoint.url, "Wake up")
        sleep(TIMEOUT)
    elif endpoint.status == 'failed':
        logger.info("Endpoint failed, attempting to change compute.")
        endpoint = update_endpoint_exception(endpoint)
    # NOTE(review): wait, then retry once with different compute if the
    # endpoint still reports 'failed'. The exact nesting of these lines is
    # ambiguous in the diff this was reconstructed from — confirm against
    # the deployed file.
    wait_for_endpoint(endpoint)
    if endpoint.status == 'failed':
        logger.info("Endpoint failed, attempting to change compute.")
        endpoint = update_endpoint_exception(endpoint)
        wait_for_endpoint(endpoint)
    logger.info("Endpoint created:")
    logger.info(endpoint)
    generation_url = endpoint.url
    if generation_url is None:
        logger.debug("Failed to create an endpoint. Exiting.")
        sys.exit()
    return generation_url
74
+
75
+
76
def wait_for_endpoint(endpoint):
    """Poll `endpoint` until it is no longer 'pending' or 'initializing'.

    Mutates `endpoint` in place via fetch(). Any other status ('failed',
    'running', 'scaledToZero') ends the wait; exits the process after 20
    unsuccessful polls.
    """
    attempts = 0
    while endpoint.status in ('pending', 'initializing'):
        if attempts >= 20:
            logger.info("Model failed to respond. Exiting.")
            sys.exit()
        logger.debug(
            "Waiting %d seconds to check again if the endpoint is running." % TIMEOUT)
        sleep(TIMEOUT)
        endpoint.fetch()
        logger.debug("Endpoint status: %s." % (endpoint.status))
        attempts += 1
 
 
 
 
89
 
90
 
91
def update_endpoint_exception(endpoint):
    """Escalate `endpoint` to the next compute configuration after a failure.

    Ladder: nvidia-l4 x4 -> nvidia-a100 x1 -> nvidia-a10g x4. Any other
    current configuration exits the process rather than keep spending money
    without human oversight.

    Returns the (updated) endpoint.
    """
    raw_info = endpoint.raw
    cur_instance_size = raw_info['compute']['instanceSize']
    cur_instance_type = raw_info['compute']['instanceType']
    if (cur_instance_type, cur_instance_size) == ('nvidia-l4', 'x4'):
        # Try a larger, different, more expensive GPU.
        endpoint.update(instance_size='x1', instance_type='nvidia-a100')
    elif (cur_instance_type, cur_instance_size) == ('nvidia-a100', 'x1'):
        # BUG FIX: this compared against 'a100', inconsistent with the
        # 'nvidia-a100' value set by the branch above, so the second
        # escalation step could never fire.
        endpoint.update(instance_size='x4', instance_type='nvidia-a10g')
    else:
        logger.info(
            "Getting expensive to try to run this model without human oversight. Exiting.")
        sys.exit()
    return endpoint
104
 
105
 
106
if __name__ == '__main__':
    # Manual smoke test: bring up a throwaway endpoint for a small model.
    generation_url = create_endpoint('this-is-a-test', 'Qwen/Qwen2-7B')