Spaces:
Running
Running
Commit
·
351c0ef
1
Parent(s):
f94b561
add: docs for train classifier
Browse files- docs/train_classifier.md +3 -0
- guardrails_genie/train_classifier.py +56 -3
- mkdocs.yml +1 -0
docs/train_classifier.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Train Classifier
|
2 |
+
|
3 |
+
::: guardrails_genie.train_classifier
|
guardrails_genie/train_classifier.py
CHANGED
@@ -16,6 +16,22 @@ import wandb
|
|
16 |
|
17 |
|
18 |
class StreamlitProgressbarCallback(TrainerCallback):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def __init__(self, *args, **kwargs):
|
21 |
super().__init__(*args, **kwargs)
|
@@ -42,6 +58,8 @@ def train_binary_classifier(
|
|
42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
44 |
prompt_column_name: str = "prompt",
|
|
|
|
|
45 |
learning_rate: float = 1e-5,
|
46 |
batch_size: int = 16,
|
47 |
num_epochs: int = 2,
|
@@ -49,6 +67,44 @@ def train_binary_classifier(
|
|
49 |
save_steps: int = 1000,
|
50 |
streamlit_mode: bool = False,
|
51 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
53 |
if streamlit_mode:
|
54 |
st.markdown(
|
@@ -69,9 +125,6 @@ def train_binary_classifier(
|
|
69 |
predictions = np.argmax(predictions, axis=1)
|
70 |
return accuracy.compute(predictions=predictions, references=labels)
|
71 |
|
72 |
-
id2label = {0: "SAFE", 1: "INJECTION"}
|
73 |
-
label2id = {"SAFE": 0, "INJECTION": 1}
|
74 |
-
|
75 |
model = AutoModelForSequenceClassification.from_pretrained(
|
76 |
model_name,
|
77 |
num_labels=2,
|
|
|
16 |
|
17 |
|
18 |
class StreamlitProgressbarCallback(TrainerCallback):
|
19 |
+
"""
|
20 |
+
StreamlitProgressbarCallback is a custom callback for the Hugging Face Trainer
|
21 |
+
that integrates a progress bar into a Streamlit application. This class updates
|
22 |
+
the progress bar at each training step, providing real-time feedback on the
|
23 |
+
training process within the Streamlit interface.
|
24 |
+
|
25 |
+
Attributes:
|
26 |
+
progress_bar (streamlit.delta_generator.DeltaGenerator): A Streamlit progress
|
27 |
+
bar object initialized to 0 with the text "Training".
|
28 |
+
|
29 |
+
Methods:
|
30 |
+
on_step_begin(args, state, control, **kwargs):
|
31 |
+
Updates the progress bar at the beginning of each training step. The progress
|
32 |
+
is calculated as the percentage of completed steps out of the total steps.
|
33 |
+
The progress bar text is updated to show the current step and the total steps.
|
34 |
+
"""
|
35 |
|
36 |
def __init__(self, *args, **kwargs):
|
37 |
super().__init__(*args, **kwargs)
|
|
|
58 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
59 |
model_name: str = "distilbert/distilbert-base-uncased",
|
60 |
prompt_column_name: str = "prompt",
|
61 |
+
id2label: dict[int, str] = {0: "SAFE", 1: "INJECTION"},
|
62 |
+
label2id: dict[str, int] = {"SAFE": 0, "INJECTION": 1},
|
63 |
learning_rate: float = 1e-5,
|
64 |
batch_size: int = 16,
|
65 |
num_epochs: int = 2,
|
|
|
67 |
save_steps: int = 1000,
|
68 |
streamlit_mode: bool = False,
|
69 |
):
|
70 |
+
"""
|
71 |
+
Trains a binary classifier using a specified dataset and model architecture.
|
72 |
+
|
73 |
+
This function sets up and trains a binary sequence classification model using
|
74 |
+
the Hugging Face Transformers library. It integrates with Weights & Biases for
|
75 |
+
experiment tracking and optionally displays a progress bar in a Streamlit app.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
project_name (str): The name of the Weights & Biases project.
|
79 |
+
entity_name (str): The Weights & Biases entity (user or team).
|
80 |
+
run_name (str): The name of the Weights & Biases run.
|
81 |
+
dataset_repo (str, optional): The Hugging Face dataset repository to load.
|
82 |
+
Defaults to "geekyrakshit/prompt-injection-dataset".
|
83 |
+
model_name (str, optional): The pre-trained model to use. Defaults to
|
84 |
+
"distilbert/distilbert-base-uncased".
|
85 |
+
prompt_column_name (str, optional): The column name in the dataset containing
|
86 |
+
the text prompts. Defaults to "prompt".
|
87 |
+
id2label (dict[int, str], optional): Mapping from label IDs to label names.
|
88 |
+
Defaults to {0: "SAFE", 1: "INJECTION"}.
|
89 |
+
label2id (dict[str, int], optional): Mapping from label names to label IDs.
|
90 |
+
Defaults to {"SAFE": 0, "INJECTION": 1}.
|
91 |
+
learning_rate (float, optional): The learning rate for training. Defaults to 1e-5.
|
92 |
+
batch_size (int, optional): The batch size for training and evaluation.
|
93 |
+
Defaults to 16.
|
94 |
+
num_epochs (int, optional): The number of training epochs. Defaults to 2.
|
95 |
+
weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.01.
|
96 |
+
save_steps (int, optional): The number of steps between model checkpoints.
|
97 |
+
Defaults to 1000.
|
98 |
+
streamlit_mode (bool, optional): If True, integrates with Streamlit to display
|
99 |
+
a progress bar. Defaults to False.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
dict: The output of the training process, including metrics and model state.
|
103 |
+
|
104 |
+
Raises:
|
105 |
+
Exception: If an error occurs during training, the exception is raised after
|
106 |
+
ensuring Weights & Biases run is finished.
|
107 |
+
"""
|
108 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
109 |
if streamlit_mode:
|
110 |
st.markdown(
|
|
|
125 |
predictions = np.argmax(predictions, axis=1)
|
126 |
return accuracy.compute(predictions=predictions, references=labels)
|
127 |
|
|
|
|
|
|
|
128 |
model = AutoModelForSequenceClassification.from_pretrained(
|
129 |
model_name,
|
130 |
num_labels=2,
|
mkdocs.yml
CHANGED
@@ -68,6 +68,7 @@ nav:
|
|
68 |
- LLM: 'llm.md'
|
69 |
- Metrics: 'metrics.md'
|
70 |
- RegexModel: 'regex_model.md'
|
|
|
71 |
- Utils: 'utils.md'
|
72 |
|
73 |
repo_url: https://github.com/soumik12345/guardrails-genie
|
|
|
68 |
- LLM: 'llm.md'
|
69 |
- Metrics: 'metrics.md'
|
70 |
- RegexModel: 'regex_model.md'
|
71 |
+
- Train Classifier: 'train_classifier.md'
|
72 |
- Utils: 'utils.md'
|
73 |
|
74 |
repo_url: https://github.com/soumik12345/guardrails-genie
|