meg-huggingface commited on
Commit
58956f6
·
1 Parent(s): 66621a9

Adding more endpoint options

Browse files
Files changed (1) hide show
  1. src/backend/inference_endpoint.py +27 -2
src/backend/inference_endpoint.py CHANGED
@@ -20,8 +20,33 @@ def create_endpoint(endpoint_name, repository, framework="pytorch", task="text-g
20
  logger.debug("Hit the following exception:")
21
  logger.debug(e)
22
  logger.debug("Attempting to continue.")
23
- endpoint = get_inference_endpoint(endpoint_name)
24
- endpoint.update(repository=repository, framework=framework, task=task, accelerator=accelerator, instance_size=instance_size, instance_type=instance_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  endpoint.fetch()
26
  logger.info("Endpoint status: %s." % (endpoint.status))
27
  if endpoint.status == "scaledToZero":
 
20
  logger.debug("Hit the following exception:")
21
  logger.debug(e)
22
  logger.debug("Attempting to continue.")
23
+ try:
24
+ endpoint = get_inference_endpoint(endpoint_name)
25
+ endpoint.update(repository=repository, framework=framework, task=task, accelerator=accelerator, instance_size=instance_size, instance_type=instance_type)
26
+ except:
27
+ if instance_type == "nvidia-l4":
28
+ # Try a larger, different, more expensive GPU.
29
+ endpoint = create_inference_endpoint(endpoint_name,
30
+ repository=repository,
31
+ framework=framework, task=task,
32
+ accelerator=accelerator,
33
+ vendor=vendor, region=region,
34
+ type=type,
35
+ instance_size="x1",
36
+ instance_type="nvidia-a100")
37
+ elif instance_type == "a100" and instance_size == "x1":
38
+ endpoint = create_inference_endpoint(endpoint_name,
39
+ repository=repository,
40
+ framework=framework, task=task,
41
+ accelerator=accelerator,
42
+ vendor=vendor, region=region,
43
+ type=type,
44
+ instance_size="x4",
45
+ instance_type="nvidia-a10g")
46
+ else:
47
+ logger.info("Getting expensive to run this model without human oversight. Exiting.")
48
+ sys.exit()
49
+
50
  endpoint.fetch()
51
  logger.info("Endpoint status: %s." % (endpoint.status))
52
  if endpoint.status == "scaledToZero":