import os

import plotly.graph_objects as go
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from datasets import load_dataset
from pydantic import BaseModel
from rich.progress import track
from safetensors.torch import save_model
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class DatasetArgs(BaseModel):
    dataset_address: str
    train_dataset_range: int
    test_dataset_range: int


class LlamaGuardFineTuner:
    def __init__(
        self, wandb_project: str, wandb_entity: str, streamlit_mode: bool = False
    ):
        self.wandb_project = wandb_project
        self.wandb_entity = wandb_entity
        self.streamlit_mode = streamlit_mode

    def load_dataset(self, dataset_args: DatasetArgs):
        self.dataset_args = dataset_args
        # Keep a short dataset name; `train` uses it for the W&B run name
        # and the checkpoint filename.
        self.dataset_name = dataset_args.dataset_address.split("/")[-1]
        dataset = load_dataset(dataset_args.dataset_address)
        # A non-positive or out-of-range value means "use the full split".
        self.train_dataset = (
            dataset["train"]
            if dataset_args.train_dataset_range <= 0
            or dataset_args.train_dataset_range > len(dataset["train"])
            else dataset["train"].select(range(dataset_args.train_dataset_range))
        )
        self.test_dataset = (
            dataset["test"]
            if dataset_args.test_dataset_range <= 0
            or dataset_args.test_dataset_range > len(dataset["test"])
            else dataset["test"].select(range(dataset_args.test_dataset_range))
        )

    def load_model(self, model_name: str = "meta-llama/Prompt-Guard-86M"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
            self.device
        )

    def show_dataset_sample(self):
        if self.streamlit_mode:
            st.markdown("### Train Dataset Sample")
            st.dataframe(self.train_dataset.to_pandas().head())
            st.markdown("### Test Dataset Sample")
            st.dataframe(self.test_dataset.to_pandas().head())

    def evaluate_batch(
        self,
        texts,
        batch_size: int = 32,
        positive_label: int = 2,
        temperature: float = 1.0,
        truncation: bool = True,
        max_length: int = 512,
    ) -> list[float]:
        self.model.eval()
        encoded_texts = self.tokenizer(
            texts,
            padding=True,
            truncation=truncation,
            max_length=max_length,
            return_tensors="pt",
        )
        dataset = TensorDataset(
            encoded_texts["input_ids"], encoded_texts["attention_mask"]
        )
        data_loader = DataLoader(dataset, batch_size=batch_size)
        scores = []
        progress_bar = (
            st.progress(0, text="Evaluating") if self.streamlit_mode else None
        )
        for i, batch in track(
            enumerate(data_loader), description="Evaluating", total=len(data_loader)
        ):
            input_ids, attention_mask = [b.to(self.device) for b in batch]
            with torch.no_grad():
                logits = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                ).logits
            # Temperature-scale the logits before the softmax, then keep the
            # probability of the positive class as the score.
            scaled_logits = logits / temperature
            probabilities = F.softmax(scaled_logits, dim=-1)
            positive_class_probabilities = (
                probabilities[:, positive_label].cpu().numpy()
            )
            scores.extend(positive_class_probabilities.tolist())
            if progress_bar:
                progress_percentage = (i + 1) * 100 // len(data_loader)
                progress_bar.progress(
                    progress_percentage,
                    text=f"Evaluating batch {i + 1}/{len(data_loader)}",
                )
        return scores

    def visualize_roc_curve(self, test_scores: list[float]):
        test_labels = [int(elt) for elt in self.test_dataset["label"]]
        fpr, tpr, _ = roc_curve(test_labels, test_scores)
        roc_auc = roc_auc_score(test_labels, test_scores)
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=fpr,
                y=tpr,
                mode="lines",
                name=f"ROC curve (area = {roc_auc:.3f})",
                line=dict(color="darkorange", width=2),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[0, 1],
                mode="lines",
                name="Random Guess",
                line=dict(color="navy", width=2, dash="dash"),
            )
        )
        fig.update_layout(
            title="Receiver Operating Characteristic",
            xaxis_title="False Positive Rate",
            yaxis_title="True Positive Rate",
            xaxis=dict(range=[0.0, 1.0]),
            yaxis=dict(range=[0.0, 1.05]),
            legend=dict(x=0.8, y=0.2),
        )
        if self.streamlit_mode:
            st.plotly_chart(fig)
        else:
            fig.show()

    def visualize_score_distribution(self, scores: list[float]):
        test_labels = [int(elt) for elt in self.test_dataset["label"]]
        # Split the scores by ground-truth label across the whole test set
        # rather than a hard-coded number of examples.
        positive_scores = [
            scores[i] for i in range(len(scores)) if test_labels[i] == 1
        ]
        negative_scores = [
            scores[i] for i in range(len(scores)) if test_labels[i] == 0
        ]
        fig = go.Figure()
        fig.add_trace(
            go.Histogram(
                x=positive_scores,
                histnorm="probability density",
                name="Positive",
                marker_color="darkblue",
                opacity=0.75,
            )
        )
        fig.add_trace(
            go.Histogram(
                x=negative_scores,
                histnorm="probability density",
                name="Negative",
                marker_color="darkred",
                opacity=0.75,
            )
        )
        fig.update_layout(
            title="Score Distribution for Positive and Negative Examples",
            xaxis_title="Score",
            yaxis_title="Density",
            barmode="overlay",
            legend_title="Scores",
        )
        if self.streamlit_mode:
            st.plotly_chart(fig)
        else:
            fig.show()

    def evaluate_model(
        self,
        batch_size: int = 32,
        positive_label: int = 2,
        temperature: float = 3.0,
        truncation: bool = True,
        max_length: int = 512,
    ):
        test_scores = self.evaluate_batch(
            self.test_dataset["text"],
            batch_size=batch_size,
            positive_label=positive_label,
            temperature=temperature,
            truncation=truncation,
            max_length=max_length,
        )
        self.visualize_roc_curve(test_scores)
        self.visualize_score_distribution(test_scores)
        return test_scores

    def collate_fn(self, batch):
        texts = [item["text"] for item in batch]
        labels = torch.tensor([int(item["label"]) for item in batch])
        encodings = self.tokenizer(
            texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
        )
        return encodings.input_ids, encodings.attention_mask, labels

    def train(self, batch_size: int = 32, lr: float = 5e-6, num_classes: int = 2):
        wandb.init(
            project=self.wandb_project,
            entity=self.wandb_entity,
            name=f"{self.model_name}-{self.dataset_name}",
        )
        # Swap in a fresh classification head sized for the target task, and
        # move it to the same device as the rest of the model.
        self.model.classifier = nn.Linear(
            self.model.classifier.in_features, num_classes
        ).to(self.device)
        self.model.num_labels = num_classes
        self.model.train()
        optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        data_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=self.collate_fn,
        )
        progress_bar = st.progress(0, text="Training") if self.streamlit_mode else None
        for i, batch in track(
            enumerate(data_loader), description="Training", total=len(data_loader)
        ):
            input_ids, attention_mask, labels = [x.to(self.device) for x in batch]
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item()})
            if progress_bar:
                progress_percentage = (i + 1) * 100 // len(data_loader)
                progress_bar.progress(
                    progress_percentage,
                    text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
                )
        # Model names like "meta-llama/Prompt-Guard-86M" contain a "/" that
        # would be treated as a directory, so sanitize the checkpoint name.
        checkpoint_path = (
            f"{self.model_name.replace('/', '-')}-{self.dataset_name}.safetensors"
        )
        save_model(self.model, checkpoint_path)
        wandb.log_model(checkpoint_path)
        wandb.finish()
        os.remove(checkpoint_path)