geekyrakshit commited on
Commit
351c0ef
·
1 Parent(s): f94b561

add: docs for train classifier

Browse files
docs/train_classifier.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Train Classifier
2
+
3
+ ::: guardrails_genie.train_classifier
guardrails_genie/train_classifier.py CHANGED
@@ -16,6 +16,22 @@ import wandb
16
 
17
 
18
  class StreamlitProgressbarCallback(TrainerCallback):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def __init__(self, *args, **kwargs):
21
  super().__init__(*args, **kwargs)
@@ -42,6 +58,8 @@ def train_binary_classifier(
42
  dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
43
  model_name: str = "distilbert/distilbert-base-uncased",
44
  prompt_column_name: str = "prompt",
 
 
45
  learning_rate: float = 1e-5,
46
  batch_size: int = 16,
47
  num_epochs: int = 2,
@@ -49,6 +67,44 @@ def train_binary_classifier(
49
  save_steps: int = 1000,
50
  streamlit_mode: bool = False,
51
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  wandb.init(project=project_name, entity=entity_name, name=run_name)
53
  if streamlit_mode:
54
  st.markdown(
@@ -69,9 +125,6 @@ def train_binary_classifier(
69
  predictions = np.argmax(predictions, axis=1)
70
  return accuracy.compute(predictions=predictions, references=labels)
71
 
72
- id2label = {0: "SAFE", 1: "INJECTION"}
73
- label2id = {"SAFE": 0, "INJECTION": 1}
74
-
75
  model = AutoModelForSequenceClassification.from_pretrained(
76
  model_name,
77
  num_labels=2,
 
16
 
17
 
18
  class StreamlitProgressbarCallback(TrainerCallback):
19
+ """
20
+ StreamlitProgressbarCallback is a custom callback for the Hugging Face Trainer
21
+ that integrates a progress bar into a Streamlit application. This class updates
22
+ the progress bar at each training step, providing real-time feedback on the
23
+ training process within the Streamlit interface.
24
+
25
+ Attributes:
26
+ progress_bar (streamlit.delta_generator.DeltaGenerator): A Streamlit progress
27
+ bar object initialized to 0 with the text "Training".
28
+
29
+ Methods:
30
+ on_step_begin(args, state, control, **kwargs):
31
+ Updates the progress bar at the beginning of each training step. The progress
32
+ is calculated as the percentage of completed steps out of the total steps.
33
+ The progress bar text is updated to show the current step and the total steps.
34
+ """
35
 
36
  def __init__(self, *args, **kwargs):
37
  super().__init__(*args, **kwargs)
 
58
  dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
59
  model_name: str = "distilbert/distilbert-base-uncased",
60
  prompt_column_name: str = "prompt",
61
+ id2label: dict[int, str] = {0: "SAFE", 1: "INJECTION"},
62
+ label2id: dict[str, int] = {"SAFE": 0, "INJECTION": 1},
63
  learning_rate: float = 1e-5,
64
  batch_size: int = 16,
65
  num_epochs: int = 2,
 
67
  save_steps: int = 1000,
68
  streamlit_mode: bool = False,
69
  ):
70
+ """
71
+ Trains a binary classifier using a specified dataset and model architecture.
72
+
73
+ This function sets up and trains a binary sequence classification model using
74
+ the Hugging Face Transformers library. It integrates with Weights & Biases for
75
+ experiment tracking and optionally displays a progress bar in a Streamlit app.
76
+
77
+ Args:
78
+ project_name (str): The name of the Weights & Biases project.
79
+ entity_name (str): The Weights & Biases entity (user or team).
80
+ run_name (str): The name of the Weights & Biases run.
81
+ dataset_repo (str, optional): The Hugging Face dataset repository to load.
82
+ Defaults to "geekyrakshit/prompt-injection-dataset".
83
+ model_name (str, optional): The pre-trained model to use. Defaults to
84
+ "distilbert/distilbert-base-uncased".
85
+ prompt_column_name (str, optional): The column name in the dataset containing
86
+ the text prompts. Defaults to "prompt".
87
+ id2label (dict[int, str], optional): Mapping from label IDs to label names.
88
+ Defaults to {0: "SAFE", 1: "INJECTION"}.
89
+ label2id (dict[str, int], optional): Mapping from label names to label IDs.
90
+ Defaults to {"SAFE": 0, "INJECTION": 1}.
91
+ learning_rate (float, optional): The learning rate for training. Defaults to 1e-5.
92
+ batch_size (int, optional): The batch size for training and evaluation.
93
+ Defaults to 16.
94
+ num_epochs (int, optional): The number of training epochs. Defaults to 2.
95
+ weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.01.
96
+ save_steps (int, optional): The number of steps between model checkpoints.
97
+ Defaults to 1000.
98
+ streamlit_mode (bool, optional): If True, integrates with Streamlit to display
99
+ a progress bar. Defaults to False.
100
+
101
+ Returns:
102
+ dict: The output of the training process, including metrics and model state.
103
+
104
+ Raises:
105
+ Exception: If an error occurs during training, the exception is raised after
106
+ ensuring Weights & Biases run is finished.
107
+ """
108
  wandb.init(project=project_name, entity=entity_name, name=run_name)
109
  if streamlit_mode:
110
  st.markdown(
 
125
  predictions = np.argmax(predictions, axis=1)
126
  return accuracy.compute(predictions=predictions, references=labels)
127
 
 
 
 
128
  model = AutoModelForSequenceClassification.from_pretrained(
129
  model_name,
130
  num_labels=2,
mkdocs.yml CHANGED
@@ -68,6 +68,7 @@ nav:
68
  - LLM: 'llm.md'
69
  - Metrics: 'metrics.md'
70
  - RegexModel: 'regex_model.md'
 
71
  - Utils: 'utils.md'
72
 
73
  repo_url: https://github.com/soumik12345/guardrails-genie
 
68
  - LLM: 'llm.md'
69
  - Metrics: 'metrics.md'
70
  - RegexModel: 'regex_model.md'
71
+ - Train Classifier: 'train_classifier.md'
72
  - Utils: 'utils.md'
73
 
74
  repo_url: https://github.com/soumik12345/guardrails-genie