File size: 9,132 Bytes
2900eb1
 
dfbca8a
573a89c
 
2900eb1
573a89c
2900eb1
 
573a89c
 
 
2900eb1
573a89c
2900eb1
573a89c
 
 
 
 
 
 
 
 
 
2900eb1
 
 
 
 
573a89c
 
 
 
 
 
a202ba5
 
573a89c
 
 
 
a202ba5
 
573a89c
 
 
 
 
2900eb1
573a89c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a202ba5
 
 
 
 
 
573a89c
 
 
 
 
 
 
 
 
 
 
a202ba5
 
 
 
 
 
573a89c
 
 
 
 
 
 
dfbca8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573a89c
dfbca8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573a89c
dfbca8a
573a89c
dfbca8a
573a89c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfbca8a
573a89c
2900eb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import os

import plotly.graph_objects as go
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from datasets import load_dataset
from pydantic import BaseModel
from rich.progress import track
from safetensors.torch import save_model
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class DatasetArgs(BaseModel):
    """Where to load the dataset from and how many examples of each split to use."""

    # Dataset identifier passed to ``datasets.load_dataset``; the loaded
    # dataset must provide "train" and "test" splits.
    dataset_address: str
    # Number of leading "train" examples to keep; a value <= 0 or larger than
    # the split size keeps the whole split.
    train_dataset_range: int
    # Same semantics as ``train_dataset_range``, applied to the "test" split.
    test_dataset_range: int


class LlamaGuardFineTuner:
    """Evaluate and fine-tune a sequence classifier such as Prompt-Guard.

    Typical flow: ``load_dataset`` -> ``load_model`` -> ``evaluate_model``
    and/or ``train``. Training loss and the resulting checkpoint are logged to
    Weights & Biases. When ``streamlit_mode`` is true, progress bars,
    dataframes and plots are rendered through Streamlit; otherwise plots are
    opened with ``fig.show()``.
    """

    def __init__(
        self, wandb_project: str, wandb_entity: str, streamlit_mode: bool = False
    ):
        """Store the W&B destination and the UI mode.

        Args:
            wandb_project: W&B project that ``train`` logs its run to.
            wandb_entity: W&B entity that owns ``wandb_project``.
            streamlit_mode: Render progress/plots via Streamlit when true.
        """
        self.wandb_project = wandb_project
        self.wandb_entity = wandb_entity
        self.streamlit_mode = streamlit_mode

    @staticmethod
    def _select_range(split, limit: int):
        """Return the first ``limit`` rows of ``split``, or all of it if ``limit`` is out of range."""
        if limit <= 0 or limit > len(split):
            return split
        return split.select(range(limit))

    def load_dataset(self, dataset_args: DatasetArgs):
        """Load the "train" and "test" splits described by ``dataset_args``.

        Each split is truncated to its configured range; a range ``<= 0`` or
        larger than the split keeps the whole split. Also records
        ``self.dataset_name`` (last path segment of the dataset address),
        which ``train`` uses to name the W&B run and the checkpoint file.
        """
        # Resolves to the module-level ``datasets.load_dataset``, not this method.
        dataset = load_dataset(dataset_args.dataset_address)
        # ``train`` depends on this attribute; it was previously never set.
        self.dataset_name = dataset_args.dataset_address.rstrip("/").split("/")[-1]
        self.train_dataset = self._select_range(
            dataset["train"], dataset_args.train_dataset_range
        )
        self.test_dataset = self._select_range(
            dataset["test"], dataset_args.test_dataset_range
        )

    def load_model(self, model_name: str = "meta-llama/Prompt-Guard-86M"):
        """Load tokenizer and classification model, placing the model on GPU if available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
            self.device
        )

    def show_dataset_sample(self):
        """Show the first rows of both splits (Streamlit mode only; no-op otherwise)."""
        if self.streamlit_mode:
            st.markdown("### Train Dataset Sample")
            st.dataframe(self.train_dataset.to_pandas().head())
            st.markdown("### Test Dataset Sample")
            st.dataframe(self.test_dataset.to_pandas().head())

    def evaluate_batch(
        self,
        texts,
        batch_size: int = 32,
        positive_label: int = 2,
        temperature: float = 1.0,
        truncation: bool = True,
        max_length: int = 512,
    ) -> list[float]:
        """Score ``texts`` with the model.

        Args:
            texts: Sequence of strings to classify.
            batch_size: Number of texts per forward pass.
            positive_label: Index of the class whose probability is returned.
                Must be a valid class index for the current head (e.g. after
                ``train(num_classes=2)`` the default of 2 is out of range).
            temperature: Logits are divided by this before the softmax.
            truncation: Passed to the tokenizer.
            max_length: Passed to the tokenizer.

        Returns:
            One probability for ``positive_label`` per input text, in order.
        """
        self.model.eval()
        # NOTE(review): the whole input is tokenized up front and padded to the
        # longest text overall; fine for modest test sets.
        encoded_texts = self.tokenizer(
            texts,
            padding=True,
            truncation=truncation,
            max_length=max_length,
            return_tensors="pt",
        )
        dataset = torch.utils.data.TensorDataset(
            encoded_texts["input_ids"], encoded_texts["attention_mask"]
        )
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
        num_batches = len(data_loader)

        scores: list[float] = []
        progress_bar = (
            st.progress(0, text="Evaluating") if self.streamlit_mode else None
        )
        for i, batch in track(
            enumerate(data_loader), description="Evaluating", total=num_batches
        ):
            input_ids, attention_mask = [b.to(self.device) for b in batch]
            with torch.no_grad():
                logits = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                ).logits
            probabilities = F.softmax(logits / temperature, dim=-1)
            # ``tolist`` yields plain Python floats, matching the annotation.
            scores.extend(probabilities[:, positive_label].cpu().tolist())
            if progress_bar is not None:
                progress_bar.progress(
                    (i + 1) * 100 // num_batches,
                    text=f"Evaluating batch {i + 1}/{num_batches}",
                )

        return scores

    def visualize_roc_curve(self, test_scores: list[float]):
        """Plot the ROC curve (with AUC in the legend) of ``test_scores`` vs. test labels."""
        test_labels = [int(elt) for elt in self.test_dataset["label"]]
        fpr, tpr, _ = roc_curve(test_labels, test_scores)
        roc_auc = roc_auc_score(test_labels, test_scores)
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=fpr,
                y=tpr,
                mode="lines",
                name=f"ROC curve (area = {roc_auc:.3f})",
                line=dict(color="darkorange", width=2),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[0, 1],
                mode="lines",
                name="Random Guess",
                line=dict(color="navy", width=2, dash="dash"),
            )
        )
        fig.update_layout(
            title="Receiver Operating Characteristic",
            xaxis_title="False Positive Rate",
            yaxis_title="True Positive Rate",
            xaxis=dict(range=[0.0, 1.0]),
            yaxis=dict(range=[0.0, 1.05]),
            legend=dict(x=0.8, y=0.2),
        )
        if self.streamlit_mode:
            st.plotly_chart(fig)
        else:
            fig.show()

    def visualize_score_distribution(self, scores: list[float]):
        """Plot overlaid score histograms for test examples labelled 1 vs. 0.

        ``scores`` must be aligned with ``self.test_dataset`` row-for-row.
        Every row is used (previously only the first 500, which also raised
        ``IndexError`` on smaller test sets).
        """
        test_labels = [int(elt) for elt in self.test_dataset["label"]]
        positive_scores = [s for s, y in zip(scores, test_labels) if y == 1]
        negative_scores = [s for s, y in zip(scores, test_labels) if y == 0]
        fig = go.Figure()
        fig.add_trace(
            go.Histogram(
                x=positive_scores,
                histnorm="probability density",
                name="Positive",
                marker_color="darkblue",
                opacity=0.75,
            )
        )
        fig.add_trace(
            go.Histogram(
                x=negative_scores,
                histnorm="probability density",
                name="Negative",
                marker_color="darkred",
                opacity=0.75,
            )
        )
        fig.update_layout(
            title="Score Distribution for Positive and Negative Examples",
            xaxis_title="Score",
            yaxis_title="Density",
            barmode="overlay",
            legend_title="Scores",
        )
        if self.streamlit_mode:
            st.plotly_chart(fig)
        else:
            fig.show()

    def evaluate_model(
        self,
        batch_size: int = 32,
        positive_label: int = 2,
        temperature: float = 3.0,
        truncation: bool = True,
        max_length: int = 512,
    ):
        """Score the test split, plot ROC and score distribution, and return the scores.

        See ``evaluate_batch`` for the meaning of the arguments.
        """
        test_scores = self.evaluate_batch(
            self.test_dataset["text"],
            batch_size=batch_size,
            positive_label=positive_label,
            temperature=temperature,
            truncation=truncation,
            max_length=max_length,
        )
        self.visualize_roc_curve(test_scores)
        self.visualize_score_distribution(test_scores)
        return test_scores

    def collate_fn(self, batch):
        """Tokenize a list of ``{"text", "label"}`` rows into (input_ids, attention_mask, labels)."""
        texts = [item["text"] for item in batch]
        labels = torch.tensor([int(item["label"]) for item in batch])
        encodings = self.tokenizer(
            texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
        )
        return encodings.input_ids, encodings.attention_mask, labels

    def _run_name(self) -> str:
        """Name of the W&B run: ``<model_name>-<dataset_name>``."""
        return f"{self.model_name}-{self.dataset_name}"

    def train(self, batch_size: int = 32, lr: float = 5e-6, num_classes: int = 2):
        """Replace the classifier head and fine-tune for one pass over the train split.

        The head is re-initialised as ``nn.Linear(in_features, num_classes)``.
        The loss of every step is logged to W&B. The final weights are saved to
        a local ``.safetensors`` file, uploaded with ``wandb.log_model``, and
        the local file is deleted afterwards (even if the upload fails).

        Requires ``load_dataset`` and ``load_model`` to have been called.
        """
        run_name = self._run_name()
        wandb.init(
            project=self.wandb_project,
            entity=self.wandb_entity,
            name=run_name,
        )
        # The new head must live on the same device as the backbone; otherwise
        # the first forward pass fails when training on CUDA.
        self.model.classifier = nn.Linear(
            self.model.classifier.in_features, num_classes
        ).to(self.device)
        # Keep both the model attribute and its config in sync with the head.
        self.model.num_labels = num_classes
        self.model.config.num_labels = num_classes
        self.model.train()
        optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        data_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=self.collate_fn,
        )
        num_batches = len(data_loader)
        progress_bar = st.progress(0, text="Training") if self.streamlit_mode else None
        for i, batch in track(
            enumerate(data_loader), description="Training", total=num_batches
        ):
            input_ids, attention_mask, labels = [x.to(self.device) for x in batch]
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            wandb.log({"loss": loss_value})
            if progress_bar is not None:
                progress_bar.progress(
                    (i + 1) * 100 // num_batches,
                    text=f"Training batch {i + 1}/{num_batches}, Loss: {loss_value}",
                )
        # Model ids such as "meta-llama/Prompt-Guard-86M" contain "/", which
        # would otherwise point save_model at a nonexistent sub-directory.
        checkpoint_path = f"{run_name.replace('/', '_')}.safetensors"
        save_model(self.model, checkpoint_path)
        try:
            wandb.log_model(checkpoint_path)
            wandb.finish()
        finally:
            os.remove(checkpoint_path)