tathagataraha commited on
Commit
671e1a6
·
1 Parent(s): 61cd814

[ADD] Auto Precision for loading directly from model

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. src/display/utils.py +3 -0
  3. src/submission/submit.py +2 -2
app.py CHANGED
@@ -422,7 +422,7 @@ with demo:
422
  choices=[i.value.name for i in Precision if i != Precision.Unknown],
423
  label="Precision",
424
  multiselect=False,
425
- value="float16",
426
  interactive=True,
427
  )
428
  weight_type = gr.Dropdown(
@@ -430,7 +430,7 @@ with demo:
430
  label="Weights type",
431
  multiselect=False,
432
  value=WeightType.Original.value.name,
433
- interactive=True,
434
  )
435
  base_model_name_textbox = gr.Textbox(label="Base model (for delta or adapter weights)", interactive=False)
436
  with gr.Row():
 
422
  choices=[i.value.name for i in Precision if i != Precision.Unknown],
423
  label="Precision",
424
  multiselect=False,
425
+ value="auto",
426
  interactive=True,
427
  )
428
  weight_type = gr.Dropdown(
 
430
  label="Weights type",
431
  multiselect=False,
432
  value=WeightType.Original.value.name,
433
+ interactive=False,
434
  )
435
  base_model_name_textbox = gr.Textbox(label="Base model (for delta or adapter weights)", interactive=False)
436
  with gr.Row():
src/display/utils.py CHANGED
@@ -132,6 +132,7 @@ class WeightType(Enum):
132
 
133
 
134
  class Precision(Enum):
 
135
  float16 = ModelDetails("float16")
136
  bfloat16 = ModelDetails("bfloat16")
137
  float32 = ModelDetails("float32")
@@ -141,6 +142,8 @@ class Precision(Enum):
141
  Unknown = ModelDetails("?")
142
 
143
  def from_str(precision):
 
 
144
  if precision in ["torch.float16", "float16"]:
145
  return Precision.float16
146
  if precision in ["torch.bfloat16", "bfloat16"]:
 
132
 
133
 
134
  class Precision(Enum):
135
+ auto = ModelDetails("auto")
136
  float16 = ModelDetails("float16")
137
  bfloat16 = ModelDetails("bfloat16")
138
  float32 = ModelDetails("float32")
 
142
  Unknown = ModelDetails("?")
143
 
144
  def from_str(precision):
145
+ if precision in ["auto"]:
146
+ return Precision.auto
147
  if precision in ["torch.float16", "float16"]:
148
  return Precision.float16
149
  if precision in ["torch.bfloat16", "bfloat16"]:
src/submission/submit.py CHANGED
@@ -133,8 +133,8 @@ def add_new_eval(
133
  "revision": revision,
134
  "precision": precision,
135
  "weight_type": weight_type,
136
- # "is_domain_specific": domain_specific,
137
- # "use_chat_template": chat_template,
138
  "status": "PENDING",
139
  "submitted_time": current_time,
140
  "model_type": model_type,
 
133
  "revision": revision,
134
  "precision": precision,
135
  "weight_type": weight_type,
136
+ "is_domain_specific": domain_specific,
137
+ "use_chat_template": chat_template,
138
  "status": "PENDING",
139
  "submitted_time": current_time,
140
  "model_type": model_type,