update utils
Browse files- src/display/utils.py +99 -0
src/display/utils.py
CHANGED
|
@@ -101,3 +101,102 @@ COLS = [col.name for col in COLUMNS]
|
|
| 101 |
BENCHMARK_COLS = [col.name for col in COLUMNS if col.name not in [
|
| 102 |
"model", "average", "model_type", "weight_type", "precision", "license", "likes", "still_on_hub"
|
| 103 |
]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
BENCHMARK_COLS = [col.name for col in COLUMNS if col.name not in [
|
| 102 |
"model", "average", "model_type", "weight_type", "precision", "license", "likes", "still_on_hub"
|
| 103 |
]]
|
| 104 |
+
|
| 105 |
+
# For the queue columns in the submission tab
|
| 106 |
+
@dataclass(frozen=True)
|
| 107 |
+
class EvalQueueColumn:
|
| 108 |
+
name: str
|
| 109 |
+
type: Any
|
| 110 |
+
label: str
|
| 111 |
+
description: str
|
| 112 |
+
|
| 113 |
+
# Define the queue columns
|
| 114 |
+
EVAL_QUEUE_COLUMNS: List[EvalQueueColumn] = [
|
| 115 |
+
EvalQueueColumn(
|
| 116 |
+
name="model",
|
| 117 |
+
type=str,
|
| 118 |
+
label="Model",
|
| 119 |
+
description="Model name",
|
| 120 |
+
),
|
| 121 |
+
EvalQueueColumn(
|
| 122 |
+
name="revision",
|
| 123 |
+
type=str,
|
| 124 |
+
label="Revision",
|
| 125 |
+
description="Model revision or commit hash",
|
| 126 |
+
),
|
| 127 |
+
EvalQueueColumn(
|
| 128 |
+
name="private",
|
| 129 |
+
type=bool,
|
| 130 |
+
label="Private",
|
| 131 |
+
description="Is the model private?",
|
| 132 |
+
),
|
| 133 |
+
EvalQueueColumn(
|
| 134 |
+
name="precision",
|
| 135 |
+
type=str,
|
| 136 |
+
label="Precision",
|
| 137 |
+
description="Precision of the model weights",
|
| 138 |
+
),
|
| 139 |
+
EvalQueueColumn(
|
| 140 |
+
name="weight_type",
|
| 141 |
+
type=str,
|
| 142 |
+
label="Weight Type",
|
| 143 |
+
description="Type of model weights",
|
| 144 |
+
),
|
| 145 |
+
EvalQueueColumn(
|
| 146 |
+
name="status",
|
| 147 |
+
type=str,
|
| 148 |
+
label="Status",
|
| 149 |
+
description="Evaluation status",
|
| 150 |
+
),
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
# Create lists for evaluation columns and types
|
| 154 |
+
EVAL_COLS = [col.name for col in EVAL_QUEUE_COLUMNS]
|
| 155 |
+
EVAL_TYPES = [col.type for col in EVAL_QUEUE_COLUMNS]
|
| 156 |
+
|
| 157 |
+
# Model information
|
| 158 |
+
@dataclass
|
| 159 |
+
class ModelDetails:
|
| 160 |
+
name: str
|
| 161 |
+
display_name: str = ""
|
| 162 |
+
symbol: str = "" # emoji
|
| 163 |
+
|
| 164 |
+
class ModelType(Enum):
|
| 165 |
+
PT = ModelDetails(name="pretrained", symbol="🟢")
|
| 166 |
+
FT = ModelDetails(name="fine-tuned", symbol="🔶")
|
| 167 |
+
IFT = ModelDetails(name="instruction-tuned", symbol="⭕")
|
| 168 |
+
RL = ModelDetails(name="RL-tuned", symbol="🟦")
|
| 169 |
+
Unknown = ModelDetails(name="", symbol="?")
|
| 170 |
+
|
| 171 |
+
def to_str(self, separator=" "):
|
| 172 |
+
return f"{self.value.symbol}{separator}{self.value.name}"
|
| 173 |
+
|
| 174 |
+
@staticmethod
|
| 175 |
+
def from_str(type_str):
|
| 176 |
+
if "fine-tuned" in type_str or "🔶" in type_str:
|
| 177 |
+
return ModelType.FT
|
| 178 |
+
if "pretrained" in type_str or "🟢" in type_str:
|
| 179 |
+
return ModelType.PT
|
| 180 |
+
if "RL-tuned" in type_str or "🟦" in type_str:
|
| 181 |
+
return ModelType.RL
|
| 182 |
+
if "instruction-tuned" in type_str or "⭕" in type_str:
|
| 183 |
+
return ModelType.IFT
|
| 184 |
+
return ModelType.Unknown
|
| 185 |
+
|
| 186 |
+
class WeightType(Enum):
|
| 187 |
+
Adapter = "Adapter"
|
| 188 |
+
Original = "Original"
|
| 189 |
+
Delta = "Delta"
|
| 190 |
+
|
| 191 |
+
class Precision(Enum):
|
| 192 |
+
float16 = "float16"
|
| 193 |
+
bfloat16 = "bfloat16"
|
| 194 |
+
Unknown = "Unknown"
|
| 195 |
+
|
| 196 |
+
@staticmethod
|
| 197 |
+
def from_str(precision_str):
|
| 198 |
+
if precision_str in ["torch.float16", "float16"]:
|
| 199 |
+
return Precision.float16
|
| 200 |
+
if precision_str in ["torch.bfloat16", "bfloat16"]:
|
| 201 |
+
return Precision.bfloat16
|
| 202 |
+
return Precision.Unknown
|