geekyrakshit committed on
Commit 65321e4 · 1 Parent(s): 351c0ef

update: docstring

guardrails_genie/train_classifier.py CHANGED
@@ -79,24 +79,18 @@ def train_binary_classifier(
  entity_name (str): The Weights & Biases entity (user or team).
  run_name (str): The name of the Weights & Biases run.
  dataset_repo (str, optional): The Hugging Face dataset repository to load.
- Defaults to "geekyrakshit/prompt-injection-dataset".
- model_name (str, optional): The pre-trained model to use. Defaults to
- "distilbert/distilbert-base-uncased".
+ model_name (str, optional): The pre-trained model to use.
  prompt_column_name (str, optional): The column name in the dataset containing
- the text prompts. Defaults to "prompt".
+ the text prompts.
  id2label (dict[int, str], optional): Mapping from label IDs to label names.
- Defaults to {0: "SAFE", 1: "INJECTION"}.
  label2id (dict[str, int], optional): Mapping from label names to label IDs.
- Defaults to {"SAFE": 0, "INJECTION": 1}.
- learning_rate (float, optional): The learning rate for training. Defaults to 1e-5.
+ learning_rate (float, optional): The learning rate for training.
  batch_size (int, optional): The batch size for training and evaluation.
- Defaults to 16.
- num_epochs (int, optional): The number of training epochs. Defaults to 2.
- weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.01.
+ num_epochs (int, optional): The number of training epochs.
+ weight_decay (float, optional): The weight decay for the optimizer.
  save_steps (int, optional): The number of steps between model checkpoints.
- Defaults to 1000.
  streamlit_mode (bool, optional): If True, integrates with Streamlit to display
- a progress bar. Defaults to False.
 
  Returns:
  dict: The output of the training process, including metrics and model state.
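
Since this commit drops the "Defaults to ..." notes from the docstring, a reader may want to recover default values from the function signature instead. The following is a minimal sketch of that idea, not the project's code: `train_binary_classifier_stub` and `signature_defaults` are hypothetical names, the stub only mirrors the scalar parameters listed in the docstring, and its default values are copied from the removed lines on the assumption that they matched the real signature at the time.

```python
import inspect


# Hypothetical stand-in for the documented function; it does no training.
# Parameter names come from the docstring in the diff. Default values are
# taken from the removed "Defaults to ..." lines (an assumption about the
# real signature). The dict-valued parameters are omitted for brevity.
def train_binary_classifier_stub(
    entity_name: str,
    run_name: str,
    dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
    model_name: str = "distilbert/distilbert-base-uncased",
    prompt_column_name: str = "prompt",
    learning_rate: float = 1e-5,
    batch_size: int = 16,
    num_epochs: int = 2,
    weight_decay: float = 0.01,
    save_steps: int = 1000,
    streamlit_mode: bool = False,
) -> dict:
    return {}


def signature_defaults(func) -> dict:
    """Return {parameter name: default} for parameters that declare a default."""
    return {
        name: param.default
        for name, param in inspect.signature(func).parameters.items()
        if param.default is not inspect.Parameter.empty
    }


defaults = signature_defaults(train_binary_classifier_stub)
for name, value in defaults.items():
    print(f"{name} = {value!r}")
```

Reading defaults from the signature keeps a single source of truth, which is one plausible motivation for removing them from the docstring; the commit message itself does not state the reason.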