# fix: LlamaGuardFineTuner
# commit a202ba5 — geekyrakshit
import matplotlib.pyplot as plt
import streamlit as st
import torch
import torch.nn.functional as F
from datasets import load_dataset
from pydantic import BaseModel
from rich.progress import track
from sklearn.metrics import roc_auc_score, roc_curve
from transformers import AutoModelForSequenceClassification, AutoTokenizer
class DatasetArgs(BaseModel):
    """Arguments selecting a HuggingFace dataset and how much of each split to use."""

    # HuggingFace Hub address (e.g. "org/dataset-name") passed to `load_dataset`.
    dataset_address: str
    # Number of rows to keep from the "train" split; a value <= 0, or larger
    # than the split, keeps the whole split (see LlamaGuardFineTuner.load_dataset).
    train_dataset_range: int
    # Number of rows to keep from the "test" split; same sentinel semantics.
    test_dataset_range: int
class LlamaGuardFineTuner:
    """Evaluation harness for Prompt-Guard-style sequence classifiers.

    Loads train/test dataset splits, loads a HuggingFace
    ``AutoModelForSequenceClassification`` checkpoint, scores the test split
    in batches, and plots an ROC curve over the resulting scores.

    Args:
        streamlit_mode: When True, dataset samples, evaluation progress, and
            plots are rendered through Streamlit in addition to the console
            output; when False everything goes to the console / a matplotlib
            window.
    """

    def __init__(self, streamlit_mode: bool = False):
        self.streamlit_mode = streamlit_mode

    def load_dataset(self, dataset_args: DatasetArgs):
        """Load the train/test splits described by ``dataset_args``.

        A non-positive ``*_dataset_range``, or one larger than the split,
        selects the entire split; otherwise only the first ``*_dataset_range``
        rows are kept. Sets ``self.train_dataset`` and ``self.test_dataset``.
        """
        dataset = load_dataset(dataset_args.dataset_address)

        def _take(split, limit: int):
            # <= 0 (or an out-of-bounds limit) is the "use everything" sentinel.
            if limit <= 0 or limit > len(split):
                return split
            return split.select(range(limit))

        self.train_dataset = _take(dataset["train"], dataset_args.train_dataset_range)
        self.test_dataset = _take(dataset["test"], dataset_args.test_dataset_range)

    def load_model(self, model_name: str = "meta-llama/Prompt-Guard-86M"):
        """Load tokenizer and classification model, moving the model to GPU if available.

        Sets ``self.device``, ``self.tokenizer``, and ``self.model``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
            self.device
        )

    def show_dataset_sample(self):
        """Render the head of both splits as Streamlit dataframes (streamlit mode only)."""
        if self.streamlit_mode:
            st.markdown("### Train Dataset Sample")
            st.dataframe(self.train_dataset.to_pandas().head())
            st.markdown("### Test Dataset Sample")
            st.dataframe(self.test_dataset.to_pandas().head())

    def evaluate_batch(
        self,
        texts,
        batch_size: int = 32,
        positive_label: int = 2,
        temperature: float = 1.0,
        truncation: bool = True,
        max_length: int = 512,
    ) -> list[float]:
        """Score ``texts`` with the loaded model in batches.

        Args:
            texts: Sequence of strings to classify.
            batch_size: Number of texts per forward pass.
            positive_label: Index of the class whose probability is returned.
            temperature: Softmax temperature; logits are divided by this
                before the softmax, so values > 1 flatten the distribution.
            truncation: Passed through to the tokenizer.
            max_length: Maximum token length per text.

        Returns:
            One probability (Python float) per input text, in input order.
        """
        # Guard the empty case: the tokenizer rejects an empty list, and the
        # streamlit progress computation below would divide by zero.
        if len(texts) == 0:
            return []

        self.model.eval()
        encoded_texts = self.tokenizer(
            texts,
            padding=True,
            truncation=truncation,
            max_length=max_length,
            return_tensors="pt",
        )
        dataset = torch.utils.data.TensorDataset(
            encoded_texts["input_ids"], encoded_texts["attention_mask"]
        )
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
        scores = []
        progress_bar = (
            st.progress(0, text="Evaluating") if self.streamlit_mode else None
        )
        for i, batch in track(
            enumerate(data_loader), description="Evaluating", total=len(data_loader)
        ):
            input_ids, attention_mask = [b.to(self.device) for b in batch]
            with torch.no_grad():
                logits = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                ).logits
            scaled_logits = logits / temperature
            probabilities = F.softmax(scaled_logits, dim=-1)
            # .tolist() converts numpy scalars to plain Python floats so the
            # declared list[float] return type is honest.
            scores.extend(probabilities[:, positive_label].cpu().numpy().tolist())
            if progress_bar:
                progress_percentage = (i + 1) * 100 // len(data_loader)
                progress_bar.progress(
                    progress_percentage,
                    text=f"Evaluating batch {i + 1}/{len(data_loader)}",
                )
        return scores

    def visualize_roc_curve(self, test_scores: list[float]):
        """Plot an ROC curve of ``test_scores`` against the test split's labels.

        NOTE(review): ``roc_curve``/``roc_auc_score`` expect binary labels;
        this assumes the dataset's "label" column is already binary (or that
        its positive class matches sklearn's default ``pos_label=1``) —
        confirm against the dataset, since scoring uses class index 2.
        """
        plt.figure(figsize=(8, 6))
        test_labels = [int(elt) for elt in self.test_dataset["label"]]
        fpr, tpr, _ = roc_curve(test_labels, test_scores)
        roc_auc = roc_auc_score(test_labels, test_scores)
        plt.plot(
            fpr,
            tpr,
            color="darkorange",
            lw=2,
            label=f"ROC curve (area = {roc_auc:.3f})",
        )
        # Chance diagonal for reference.
        plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Receiver Operating Characteristic")
        plt.legend(loc="lower right")
        if self.streamlit_mode:
            st.pyplot(plt)
        else:
            plt.show()

    def evaluate_model(
        self,
        batch_size: int = 32,
        positive_label: int = 2,
        temperature: float = 3.0,
        truncation: bool = True,
        max_length: int = 512,
    ):
        """Score the test split, plot its ROC curve, and return the scores.

        Parameters mirror :meth:`evaluate_batch` (note the different default
        ``temperature`` of 3.0 used for evaluation).
        """
        test_scores = self.evaluate_batch(
            self.test_dataset["text"],
            batch_size=batch_size,
            positive_label=positive_label,
            temperature=temperature,
            truncation=truncation,
            max_length=max_length,
        )
        self.visualize_roc_curve(test_scores)
        return test_scores