Spaces:
Running
Running
Commit
·
65321e4
1
Parent(s):
351c0ef
update: docstring
Browse files
guardrails_genie/train_classifier.py
CHANGED
@@ -79,24 +79,18 @@ def train_binary_classifier(
|
|
79 |
entity_name (str): The Weights & Biases entity (user or team).
|
80 |
run_name (str): The name of the Weights & Biases run.
|
81 |
dataset_repo (str, optional): The Hugging Face dataset repository to load.
|
82 |
-
|
83 |
-
model_name (str, optional): The pre-trained model to use. Defaults to
|
84 |
-
"distilbert/distilbert-base-uncased".
|
85 |
prompt_column_name (str, optional): The column name in the dataset containing
|
86 |
-
the text prompts.
|
87 |
id2label (dict[int, str], optional): Mapping from label IDs to label names.
|
88 |
-
Defaults to {0: "SAFE", 1: "INJECTION"}.
|
89 |
label2id (dict[str, int], optional): Mapping from label names to label IDs.
|
90 |
-
|
91 |
-
learning_rate (float, optional): The learning rate for training. Defaults to 1e-5.
|
92 |
batch_size (int, optional): The batch size for training and evaluation.
|
93 |
-
|
94 |
-
|
95 |
-
weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.01.
|
96 |
save_steps (int, optional): The number of steps between model checkpoints.
|
97 |
-
Defaults to 1000.
|
98 |
streamlit_mode (bool, optional): If True, integrates with Streamlit to display
|
99 |
-
a progress bar.
|
100 |
|
101 |
Returns:
|
102 |
dict: The output of the training process, including metrics and model state.
|
|
|
79 |
entity_name (str): The Weights & Biases entity (user or team).
|
80 |
run_name (str): The name of the Weights & Biases run.
|
81 |
dataset_repo (str, optional): The Hugging Face dataset repository to load.
|
82 |
+
model_name (str, optional): The pre-trained model to use.
|
|
|
|
|
83 |
prompt_column_name (str, optional): The column name in the dataset containing
|
84 |
+
the text prompts.
|
85 |
id2label (dict[int, str], optional): Mapping from label IDs to label names.
|
|
|
86 |
label2id (dict[str, int], optional): Mapping from label names to label IDs.
|
87 |
+
learning_rate (float, optional): The learning rate for training.
|
|
|
88 |
batch_size (int, optional): The batch size for training and evaluation.
|
89 |
+
num_epochs (int, optional): The number of training epochs.
|
90 |
+
weight_decay (float, optional): The weight decay for the optimizer.
|
|
|
91 |
save_steps (int, optional): The number of steps between model checkpoints.
|
|
|
92 |
streamlit_mode (bool, optional): If True, integrates with Streamlit to display
|
93 |
+
a progress bar.
|
94 |
|
95 |
Returns:
|
96 |
dict: The output of the training process, including metrics and model state.
|