from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.utils.data import DataLoader
from transformers import (
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
    PreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput

from .config import ClipSegMultiClassConfig


def flatten_outputs(preds, targets, num_classes):
    """Flatten predictions and targets to 1-D arrays, dropping invalid labels.

    Args:
        preds: tensor of predicted class indices (any shape).
        targets: tensor of ground-truth class indices (same shape as preds).
        num_classes: number of valid classes; targets outside [0, num_classes)
            are filtered out of BOTH arrays.

    Returns:
        Tuple of 1-D numpy arrays (filtered_preds, filtered_targets).
    """
    preds = preds.cpu().numpy().reshape(-1)
    targets = targets.cpu().numpy().reshape(-1)
    # Keep only positions whose target is a valid class index.
    mask = (targets >= 0) & (targets < num_classes)
    return preds[mask], targets[mask]


def compute_metrics(all_preds, all_targets, num_classes, average="macro"):
    """Compute accuracy/precision/recall/F1 over concatenated predictions.

    Args:
        all_preds: list of 1-D arrays of predicted class indices.
        all_targets: list of 1-D arrays of ground-truth class indices.
        num_classes: kept for API symmetry with ``flatten_outputs``;
            not used by the sklearn metrics below.
        average: averaging mode passed to the sklearn metrics.

    Returns:
        dict with keys "accuracy", "precision", "recall", "f1".
    """
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_targets)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average=average, zero_division=0),
        "recall": recall_score(y_true, y_pred, average=average, zero_division=0),
        "f1": f1_score(y_true, y_pred, average=average, zero_division=0),
    }


@dataclass
class ClipSegMultiClassOutput(ModelOutput):
    """Output container for ClipSegMultiClassModel.forward."""

    loss: Optional[torch.FloatTensor] = None           # scalar CE loss (None without labels)
    logits: Optional[torch.FloatTensor] = None         # [B, C, H, W] per-class logits
    predictions: Optional[torch.LongTensor] = None     # [B, H, W] argmax class map


class ClipSegMultiClassModel(PreTrainedModel):
    """Multi-class semantic segmentation on top of CLIPSeg.

    CLIPSeg produces one binary mask per (image, text-prompt) pair; this
    wrapper runs one prompt per class and stacks the C per-class masks into
    a [B, C, H, W] logit volume, turning the task into standard multi-class
    segmentation trained with cross-entropy.
    """

    config_class = ClipSegMultiClassConfig
    base_model_prefix = "clipseg_multiclass"

    def __init__(self, config: ClipSegMultiClassConfig):
        super().__init__(config)
        self.config = config
        self.class_labels = config.class_labels
        self.num_classes = config.num_classes
        self.processor = CLIPSegProcessor.from_pretrained(config.model)
        self.clipseg = CLIPSegForImageSegmentation.from_pretrained(config.model)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> ClipSegMultiClassOutput:
        """Run CLIPSeg once per class and fold the results into class logits.

        Args:
            pixel_values: [B * C, 3, H, W] image tensor — each image repeated
                once per class prompt (C = num_classes).
            input_ids: [B * C, T] tokenized class prompts, aligned with
                ``pixel_values``.
            labels: optional [B, H, W] ground-truth class-index map.
                NOTE(review): the CE loss assumes labels share CLIPSeg's
                output resolution — confirm upstream resizing.

        Returns:
            ClipSegMultiClassOutput with loss (if labels given), [B, C, H, W]
            logits, and [B, H, W] argmax predictions.

        Raises:
            ValueError: if either required input is missing.
        """
        if pixel_values is None or input_ids is None:
            raise ValueError("Both `pixel_values` and `input_ids` must be provided.")

        pixel_values = pixel_values.to(self.device)
        input_ids = input_ids.to(self.device)

        outputs = self.clipseg(pixel_values=pixel_values, input_ids=input_ids)
        raw_logits = outputs.logits  # shape: [B * C, H, W]

        batch_size = raw_logits.shape[0] // self.num_classes
        height, width = raw_logits.shape[-2:]
        # reshape (not view): safe even if CLIPSeg hands back a
        # non-contiguous tensor.
        logits = raw_logits.reshape(batch_size, self.num_classes, height, width)
        pred = torch.argmax(logits, dim=1)  # [B, H, W]

        loss = self.loss_fct(logits, labels.long()) if labels is not None else None

        return ClipSegMultiClassOutput(loss=loss, logits=logits, predictions=pred)

    @torch.no_grad()
    def predict(self, images: Union[List, "PIL.Image.Image"]) -> torch.Tensor:
        """Predict a [B, H, W] class-index map for one image or a list of images.

        Each image is duplicated once per class label so the processor emits
        the [B * C, ...] layout that ``forward`` expects.
        """
        self.eval()
        if isinstance(images, Image.Image):
            images = [images]

        inputs = self.processor(
            # Repeat every image C times so image i lines up with the C
            # class prompts at positions [i*C, (i+1)*C).
            images=[img for img in images for _ in self.class_labels],
            text=self.class_labels * len(images),
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        output = self.forward(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
        )
        return output.predictions

    def evaluate(self, dataloader: torch.utils.data.DataLoader) -> dict:
        """Evaluate macro accuracy/precision/recall/F1 over a dataloader.

        Expects batches with "pixel_values" [B * C, 3, H, W], "input_ids"
        [B * C, T], and "labels" [B, H, W].

        NOTE(review): pixels whose label == 0 are excluded from the metrics
        (presumably class 0 is background), which differs from the
        ``flatten_outputs`` helper that keeps every valid class — confirm
        which masking is intended.
        """
        self.eval()
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for batch in dataloader:
                pixel_values = batch["pixel_values"].to(self.device)
                input_ids = batch["input_ids"].to(self.device)
                labels = batch["labels"].to(self.device)

                outputs = self.forward(pixel_values=pixel_values, input_ids=input_ids)
                preds = outputs.predictions  # [B, H, W]

                for pred, label in zip(preds, labels):
                    pred = pred.cpu().flatten()
                    label = label.cpu().flatten()
                    keep = label != 0
                    all_preds.append(pred[keep].numpy())
                    all_targets.append(label[keep].numpy())

        # Reuse the module-level helper instead of duplicating the sklearn
        # calls (same macro averaging and zero_division settings).
        return compute_metrics(all_preds, all_targets, self.num_classes)