meg-huggingface
committed on
Commit
·
3d16b0d
1
Parent(s):
7d70d90
Inference endpoint figuring
Browse files
src/backend/inference_endpoint.py
CHANGED
@@ -6,83 +6,102 @@ from huggingface_hub import create_inference_endpoint, get_inference_endpoint
|
|
6 |
from src.backend.run_toxicity_eval import get_generation
|
7 |
from src.logging import setup_logger
|
8 |
import requests
|
|
|
9 |
logging.basicConfig(level=logging.DEBUG)
|
10 |
logger = setup_logger(__name__)
|
11 |
-
TIMEOUT=20
|
|
|
12 |
|
13 |
-
|
14 |
-
|
|
|
|
|
15 |
logger.info("Creating endpoint %s..." % endpoint_name)
|
16 |
# TODO(mm): Handle situation where it's paused
|
17 |
try:
|
18 |
-
endpoint = create_inference_endpoint(endpoint_name,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
except huggingface_hub.utils._errors.HfHubHTTPError as e:
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
except requests.exceptions.HTTPError as e:
|
22 |
-
|
|
|
|
|
|
|
23 |
except Exception as e:
|
24 |
-
logger.debug("Hit error")
|
25 |
logger.debug(e)
|
26 |
sys.exit()
|
27 |
endpoint.fetch()
|
28 |
-
logger.info("Endpoint status: %s." %
|
29 |
-
if endpoint.status ==
|
30 |
# Send a request to wake it up.
|
31 |
get_generation(endpoint.url, "Wake up")
|
32 |
sleep(TIMEOUT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
i = 0
|
34 |
-
while endpoint.status in [
|
|
|
35 |
if i >= 20:
|
36 |
logger.info("Model failed to respond. Exiting.")
|
37 |
sys.exit()
|
38 |
-
logger.debug(
|
|
|
39 |
sleep(TIMEOUT)
|
40 |
endpoint.fetch()
|
41 |
logger.debug("Endpoint status: %s." % (endpoint.status))
|
42 |
i += 1
|
43 |
-
logger.info("Endpoint created:")
|
44 |
-
logger.info(endpoint)
|
45 |
-
generation_url = endpoint.url
|
46 |
-
return generation_url
|
47 |
|
48 |
|
49 |
-
def update_endpoint_exception(
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
endpoint =
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
logger.debug("Attempting a new instance type.")
|
62 |
-
if instance_type == "nvidia-l4":
|
63 |
-
# Try a larger, different, more expensive GPU.
|
64 |
-
endpoint = create_inference_endpoint(endpoint_name,
|
65 |
-
repository=repository,
|
66 |
-
framework=framework, task=task,
|
67 |
-
accelerator=accelerator,
|
68 |
-
vendor=vendor, region=region,
|
69 |
-
type=type,
|
70 |
-
instance_size="x1",
|
71 |
-
instance_type="nvidia-a100")
|
72 |
-
elif instance_type == "a100" and instance_size == "x1":
|
73 |
-
endpoint = create_inference_endpoint(endpoint_name,
|
74 |
-
repository=repository,
|
75 |
-
framework=framework, task=task,
|
76 |
-
accelerator=accelerator,
|
77 |
-
vendor=vendor, region=region,
|
78 |
-
type=type,
|
79 |
-
instance_size="x4",
|
80 |
-
instance_type="nvidia-a10g")
|
81 |
-
else:
|
82 |
-
logger.info("Getting expensive to try to run this model without human oversight. Exiting.")
|
83 |
-
sys.exit()
|
84 |
return endpoint
|
85 |
|
86 |
|
87 |
if __name__ == '__main__':
|
88 |
-
generation_url = create_endpoint(
|
|
|
6 |
from src.backend.run_toxicity_eval import get_generation
|
7 |
from src.logging import setup_logger
|
8 |
import requests
|
9 |
+
|
10 |
logging.basicConfig(level=logging.DEBUG)
|
11 |
logger = setup_logger(__name__)
|
12 |
+
TIMEOUT = 20
|
13 |
+
|
14 |
|
15 |
+
def create_endpoint(endpoint_name, repository, framework='pytorch',
                    task='text-generation', accelerator='gpu', vendor='aws',
                    region='us-east-1', type='protected', instance_size='x4',
                    instance_type='nvidia-l4'):
    """Create (or reuse) a dedicated Inference Endpoint and return its URL.

    Args:
        endpoint_name: Name of the endpoint on the Hub.
        repository: Model repository to deploy (e.g. 'Qwen/Qwen2-7B').
        framework, task, accelerator, vendor, region, type, instance_size,
            instance_type: Forwarded to
            `huggingface_hub.create_inference_endpoint`. (`type` shadows the
            builtin, but the name is kept for keyword-caller compatibility.)

    Returns:
        The endpoint's generation URL.

    Exits the process (`sys.exit`) when the endpoint cannot be brought up.
    """
    logger.info("Creating endpoint %s..." % endpoint_name)
    # TODO(mm): Handle situation where it's paused
    try:
        endpoint = create_inference_endpoint(endpoint_name,
                                             repository=repository,
                                             framework=framework, task=task,
                                             accelerator=accelerator,
                                             vendor=vendor, region=region,
                                             type=type,
                                             instance_size=instance_size,
                                             instance_type=instance_type)
    except huggingface_hub.utils._errors.HfHubHTTPError as e:
        # Workload with the same name already exists error.
        # Use it again, just make sure it has the right settings.
        logger.debug("Hit error:")
        logger.debug(e)
        logger.debug("Attempting to update with the given parameters.")
        endpoint = get_inference_endpoint(endpoint_name)
        endpoint.update(repository=repository,
                        framework=framework, task=task,
                        accelerator=accelerator,
                        vendor=vendor, region=region,
                        type=type,
                        instance_size=instance_size,
                        instance_type=instance_type)
    except requests.exceptions.HTTPError as e:
        # Not enough compute, or wrong compute.
        logger.debug("Hit error:")
        logger.debug(e)
        # BUGFIX: when create_inference_endpoint raised, `endpoint` was never
        # bound, so passing it to update_endpoint_exception raised
        # UnboundLocalError. Look the workload up by name first.
        endpoint = update_endpoint_exception(
            get_inference_endpoint(endpoint_name))
    except Exception as e:
        logger.debug("Hit unaccounted-for error")
        logger.debug(e)
        sys.exit()
    endpoint.fetch()
    logger.info("Endpoint status: %s." % endpoint.status)
    if endpoint.status == 'scaledToZero':
        # Send a request to wake it up.
        get_generation(endpoint.url, "Wake up")
        sleep(TIMEOUT)
    elif endpoint.status == 'failed':
        logger.info("Endpoint failed, attempting to change compute.")
        endpoint = update_endpoint_exception(endpoint)
    # NOTE(review): scrape lost indentation — this call is reconstructed at
    # function level (always wait once), which matches the retry check below.
    wait_for_endpoint(endpoint)
    # One retry on bigger compute if provisioning itself failed.
    if endpoint.status == 'failed':
        logger.info("Endpoint failed, attempting to change compute.")
        endpoint = update_endpoint_exception(endpoint)
        wait_for_endpoint(endpoint)
    logger.info("Endpoint created:")
    logger.info(endpoint)
    generation_url = endpoint.url
    if generation_url is None:
        logger.debug("Failed to create an endpoint. Exiting.")
        sys.exit()
    return generation_url
|
74 |
+
|
75 |
+
|
76 |
+
def wait_for_endpoint(endpoint, max_attempts=20):
    """Block until `endpoint` leaves a transitional state.

    Polls while the endpoint reports 'pending' or 'initializing',
    re-fetching its state every TIMEOUT seconds. Terminal states
    ('failed', 'running', 'scaledToZero') end the wait.

    Args:
        endpoint: Endpoint handle exposing `.status` and `.fetch()`.
        max_attempts: Polls before giving up (default 20, matching the
            previously hard-coded limit).

    Exits the process (`sys.exit`) when the limit is reached.
    """
    attempts = 0
    while endpoint.status in ('pending', 'initializing'):
        if attempts >= max_attempts:
            logger.info("Model failed to respond. Exiting.")
            sys.exit()
        logger.debug(
            "Waiting %d seconds to check again if the endpoint is running." % TIMEOUT)
        sleep(TIMEOUT)
        endpoint.fetch()
        logger.debug("Endpoint status: %s." % (endpoint.status))
        attempts += 1
|
|
|
|
|
|
|
|
|
89 |
|
90 |
|
91 |
+
def update_endpoint_exception(endpoint):
    """Escalate a struggling endpoint to the next compute tier.

    Upgrade path: nvidia-l4/x4 -> nvidia-a100/x1 -> nvidia-a10g/x4.
    Past the known tiers the process exits rather than escalate spend
    without human oversight.

    Args:
        endpoint: Endpoint handle exposing `.raw` and `.update(...)`.

    Returns:
        The same endpoint, with an update request submitted.
    """
    raw_info = endpoint.raw
    cur_instance_size = raw_info['compute']['instanceSize']
    cur_instance_type = raw_info['compute']['instanceType']
    if (cur_instance_type, cur_instance_size) == ('nvidia-l4', 'x4'):
        # Try a larger, different, more expensive GPU.
        endpoint.update(instance_size='x1', instance_type='nvidia-a100')
    elif (cur_instance_type, cur_instance_size) == ('nvidia-a100', 'x1'):
        # BUGFIX: this branch previously compared against 'a100', which can
        # never match the 'nvidia-a100' instance type set by the branch
        # above, leaving the second escalation step unreachable.
        endpoint.update(instance_size='x4', instance_type='nvidia-a10g')
    else:
        logger.info(
            "Getting expensive to try to run this model without human oversight. Exiting.")
        sys.exit()
    return endpoint
|
104 |
|
105 |
|
106 |
if __name__ == '__main__':
    # Manual smoke test: provision a throwaway endpoint for a small model.
    # NOTE(review): this creates (billable) cloud compute when run directly.
    generation_url = create_endpoint('this-is-a-test', 'Qwen/Qwen2-7B')
|