BioMike committed
Commit 4875d48 · verified · 1 Parent(s): 4964139

Upload 3 files

Files changed (3):
  1. src/config.py (+33, -0)
  2. src/data_processing.py (+68, -0)
  3. src/model.py (+154, -0)
src/config.py ADDED
@@ -0,0 +1,33 @@
from transformers import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ClipSegMultiClassConfig(PretrainedConfig):
    model_type = "clipseg-multiclass"
    is_composition = False

    def __init__(
        self,
        class_labels=None,
        label2color=None,
        model="CIDAS/clipseg-rd64-refined",
        image_size=352,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Class names in index order; index 0 is treated as background downstream.
        self.class_labels = class_labels or []
        self.num_classes = len(self.class_labels)

        # Default per-class RGB colors: a blue-to-red gradient over class indices.
        self.label2color = label2color or {
            i: [
                int(255 * (i / max(1, self.num_classes - 1))),
                0,
                255 - int(255 * (i / max(1, self.num_classes - 1)))
            ]
            for i in range(self.num_classes)
        }

        self.model = model
        self.image_size = image_size
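For reference, a minimal usage sketch of this config; the label names below are illustrative assumptions, not part of the repo:

# Hypothetical example -- class labels are made up for illustration.
from src.config import ClipSegMultiClassConfig

config = ClipSegMultiClassConfig(
    class_labels=["background", "cat", "dog"],  # assumed labels; index 0 acts as background downstream
)
print(config.num_classes)   # 3
print(config.label2color)   # {0: [0, 0, 255], 1: [127, 0, 128], 2: [255, 0, 0]}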
src/data_processing.py ADDED
@@ -0,0 +1,68 @@
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import numpy as np


class SingleClassSegmentationDataset(Dataset):
    """Wraps items of the form {"img_path", "mask_path", "label"}, where each mask
    contains a single foreground class named by "label"."""

    def __init__(self, dataset, class_labels, image_size=352, transform=None):
        self.items = dataset
        self.class_labels = class_labels
        self.image_size = image_size
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]

        image = Image.open(item["img_path"]).convert("RGB")
        mask = Image.open(item["mask_path"]).convert("L")
        class_name = item["label"]

        class_index = self.class_labels.index(class_name)
        background_index = 0

        # Binarize the mask, then write the class index into the foreground pixels.
        mask_np = np.array(mask) > 0
        final_mask = np.full(mask_np.shape, background_index, dtype=np.uint8)
        final_mask[mask_np] = class_index

        image = image.resize((self.image_size, self.image_size), Image.BILINEAR)
        final_mask = Image.fromarray(final_mask).resize((self.image_size, self.image_size), Image.NEAREST)

        if self.transform:
            image, final_mask = self.transform(image, final_mask)

        return {
            "image": image,
            "labels": torch.from_numpy(np.array(final_mask)).long()
        }


class SegmentationCollator:
    def __init__(self, processor, class_labels):
        self.processor = processor
        self.class_labels = class_labels

    def __call__(self, batch):
        images = [item["image"] for item in batch]
        labels = [item["labels"] for item in batch]

        # Pair every image with every class prompt (image-major, class-minor order),
        # so pixel_values/input_ids hold batch_size * num_classes entries.
        prompts = self.class_labels * len(images)
        expanded_images = [img for img in images for _ in self.class_labels]

        inputs = self.processor(
            images=expanded_images,
            text=prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        )

        return {
            "pixel_values": inputs["pixel_values"],
            "input_ids": inputs["input_ids"],
            "labels": torch.stack(labels)
        }
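A wiring sketch for the dataset and collator; the item paths and label names below are invented for illustration:

# Hypothetical example -- paths and labels are placeholders.
from torch.utils.data import DataLoader
from transformers import CLIPSegProcessor
from src.data_processing import SingleClassSegmentationDataset, SegmentationCollator

class_labels = ["background", "cat", "dog"]  # assumed label set; must contain every item["label"]
items = [
    {"img_path": "data/img_0.png", "mask_path": "data/mask_0.png", "label": "cat"},
    {"img_path": "data/img_1.png", "mask_path": "data/mask_1.png", "label": "dog"},
]

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
dataset = SingleClassSegmentationDataset(items, class_labels, image_size=352)
loader = DataLoader(dataset, batch_size=2, collate_fn=SegmentationCollator(processor, class_labels))

# Each batch pairs every image with every class prompt, so pixel_values and input_ids
# have batch_size * num_classes entries, while labels stays [batch_size, H, W].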
src/model.py ADDED
@@ -0,0 +1,154 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List
from PIL import Image
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedModel,
    CLIPSegProcessor,
    CLIPSegForImageSegmentation,
)
from transformers.modeling_outputs import ModelOutput

from .config import ClipSegMultiClassConfig
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import numpy as np
from torch.utils.data import DataLoader
from collections import defaultdict


def flatten_outputs(preds, targets, num_classes):
    """Flatten predictions and targets to 1D arrays and drop out-of-range labels."""
    preds = preds.cpu().numpy().reshape(-1)
    targets = targets.cpu().numpy().reshape(-1)

    mask = (targets >= 0) & (targets < num_classes)
    return preds[mask], targets[mask]


def compute_metrics(all_preds, all_targets, num_classes, average="macro"):
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_targets)

    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average=average, zero_division=0),
        "recall": recall_score(y_true, y_pred, average=average, zero_division=0),
        "f1": f1_score(y_true, y_pred, average=average, zero_division=0),
    }

    return metrics


@dataclass
class ClipSegMultiClassOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    predictions: Optional[torch.LongTensor] = None


class ClipSegMultiClassModel(PreTrainedModel):
    config_class = ClipSegMultiClassConfig
    base_model_prefix = "clipseg_multiclass"

    def __init__(self, config: ClipSegMultiClassConfig):
        super().__init__(config)

        self.config = config
        self.class_labels = config.class_labels
        self.num_classes = config.num_classes
        self.processor = CLIPSegProcessor.from_pretrained(config.model)
        self.clipseg = CLIPSegForImageSegmentation.from_pretrained(config.model)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ) -> ClipSegMultiClassOutput:

        if pixel_values is None or input_ids is None:
            raise ValueError("Both `pixel_values` and `input_ids` must be provided.")

        pixel_values = pixel_values.to(self.device)
        input_ids = input_ids.to(self.device)

        # CLIPSeg returns one binary segmentation map per (image, prompt) pair.
        outputs = self.clipseg(pixel_values=pixel_values, input_ids=input_ids)
        raw_logits = outputs.logits  # shape: [B * C, H, W]

        B = raw_logits.shape[0] // self.num_classes
        C = self.num_classes
        H, W = raw_logits.shape[-2:]

        # Regroup per-prompt maps into per-image class scores; assumes the collator's
        # image-major, class-minor ordering.
        logits = raw_logits.view(B, C, H, W)  # [B, C, H, W]
        pred = torch.argmax(logits, dim=1)    # [B, H, W]

        loss = self.loss_fct(logits, labels.long()) if labels is not None else None

        return ClipSegMultiClassOutput(
            loss=loss,
            logits=logits,
            predictions=pred
        )

    @torch.no_grad()
    def predict(self, images: Union[List, "PIL.Image.Image"]) -> torch.Tensor:
        self.eval()
        if isinstance(images, Image.Image):
            images = [images]

        inputs = self.processor(
            images=[img for img in images for _ in self.class_labels],
            text=self.class_labels * len(images),
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)

        output = self.forward(
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"]
        )
        return output.predictions

    def evaluate(self, dataloader: DataLoader) -> dict:
        self.eval()

        all_preds = []
        all_targets = []

        with torch.no_grad():
            for batch in dataloader:
                pixel_values = batch["pixel_values"].to(self.device)  # [B * C, 3, H, W]
                input_ids = batch["input_ids"].to(self.device)        # [B * C, T]
                labels = batch["labels"].to(self.device)              # [B, H, W]

                outputs = self.forward(pixel_values=pixel_values, input_ids=input_ids)
                preds = outputs.predictions  # [B, H, W]

                for pred, label in zip(preds, labels):
                    pred = pred.cpu().flatten()
                    label = label.cpu().flatten()

                    # Ignore background pixels (class 0) when scoring.
                    mask = label != 0
                    pred = pred[mask]
                    label = label[mask]

                    all_preds.append(pred)
                    all_targets.append(label)

        y_pred = torch.cat(all_preds).numpy()
        y_true = torch.cat(all_targets).numpy()

        return {
            "accuracy": accuracy_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred, average="macro", zero_division=0),
            "recall": recall_score(y_true, y_pred, average="macro", zero_division=0),
            "f1": f1_score(y_true, y_pred, average="macro", zero_division=0),
        }
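An end-to-end sketch tying the three files together; the image path and label names are placeholders, and constructing the model downloads the underlying CLIPSeg weights:

# Hypothetical example -- paths and labels are placeholders.
from PIL import Image
from src.config import ClipSegMultiClassConfig
from src.model import ClipSegMultiClassModel

config = ClipSegMultiClassConfig(class_labels=["background", "cat", "dog"])  # assumed labels
model = ClipSegMultiClassModel(config)

image = Image.open("example.jpg").convert("RGB")  # placeholder path
pred = model.predict(image)                       # LongTensor [1, H, W] of class indices
print(pred.shape, pred.unique())

# With the DataLoader from the data_processing sketch above:
# metrics = model.evaluate(loader)  # accuracy / precision / recall / f1, background pixels ignored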