---
license: mit
datasets:
- vector-institute/newsmediabias-plus
language:
- en
metrics:
- f1(0.698616087436676)
- precision(0.6369158625602722)
- recall(0.7735527753829956)
- accuracy(0.6247606873512268)
library_name: transformers
co2_eq_emissions:
  emissions: 8
  source: Code Carbon
  training_type: fine-tuning
  geographical_location: Albany, New York
  hardware_used: T4
base_model:
- google-bert/bert-base-uncased
- microsoft/resnet-34
pipeline_tag: custom
tags:
- Social Bias
- Multimodal
---

# Multimodal Bias Classifier

This model is a multimodal classifier that combines text and image inputs to detect potential bias in content. It uses a BERT-based text encoder and a ResNet-34 image encoder whose outputs are fused for classification. A contrastive learning approach was used during training, leveraging CLIP embeddings as guidance to align the text and image representations.

## Model Details

- **Text Encoder**: BERT (`bert-base-uncased`)
- **Image Encoder**: ResNet-34 (`microsoft/resnet-34`)
- **Projection Dimensionality**: 768
- **Fusion Method**: Concatenation (default), Alignment, or Cosine Similarity
- **Loss Functions**: Binary Cross-Entropy for classification, Cosine Embedding Loss for contrastive learning
- **Purpose**: Detecting bias in multimodal content (text + image)

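For orientation, the sketch below traces the tensor shapes through the architecture under the default concatenation fusion. It uses dummy tensors and freshly initialized linear layers in place of the real encoders; the dimensions follow the reference implementation shown later in this card.

```python
# Illustrative shape walk-through only; dummy tensors stand in for encoder outputs.
import torch
from torch import nn

B = 2                                        # arbitrary batch size
text_cls = torch.randn(B, 768)               # BERT [CLS] hidden state
image_map = torch.randn(B, 512, 7, 7)        # ResNet-34 final feature map (224x224 input)

image_vec = image_map.mean(dim=[-2, -1])     # global average pool -> (B, 512)
text_proj = nn.Linear(768, 768)(text_cls)    # text projection     -> (B, 768)
image_proj = nn.Linear(512, 768)(image_vec)  # image projection    -> (B, 768)

fused = torch.cat([text_proj, image_proj], dim=-1)  # concat fusion -> (B, 1536)
fused = nn.Linear(2 * 768, 768)(fused)              # fusion layer  -> (B, 768)
logit = nn.Linear(768, 1)(fused)                    # classifier    -> (B, 1)
print(image_vec.shape, text_proj.shape, fused.shape, logit.shape)
```
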
## Training

The model was trained on a multimodal dataset with labeled instances of biased and unbiased content. The training process incorporated both a classification loss and a contrastive loss, the latter helping to align the text and image representations in a shared latent space.

### Training Losses

- **Classification Loss**: Binary Cross-Entropy (`BCEWithLogitsLoss`) to classify content as biased or unbiased.
- **Contrastive Loss**: `CosineEmbeddingLoss`, which uses CLIP text and image embeddings as ground-truth guidance to align the text and image features (see the sketch below).

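This card does not spell out how the two objectives are weighted or exactly which feature pairs the contrastive term compares, so the following is only a plausible sketch: each projected feature is pulled toward its corresponding frozen CLIP embedding, and the terms are summed with an assumed weight of 0.5. Dummy tensors stand in for real batches.

```python
# Hedged sketch of the combined training objective. The pairing of model
# projections with CLIP embeddings and the 0.5 weighting are assumptions,
# not taken from the repository.
import torch
from torch import nn

bce = nn.BCEWithLogitsLoss()        # classification term
cosine = nn.CosineEmbeddingLoss()   # contrastive/alignment term
contrastive_weight = 0.5            # hypothetical weighting

B, D = 8, 768
logits = torch.randn(B, 1, requires_grad=True)       # classifier output
labels = torch.randint(0, 2, (B, 1)).float()         # biased / unbiased targets

text_proj = torch.randn(B, D, requires_grad=True)    # model text projection
image_proj = torch.randn(B, D, requires_grad=True)   # model image projection
clip_text = torch.randn(B, D)                        # frozen CLIP text embeddings (guidance)
clip_image = torch.randn(B, D)                       # frozen CLIP image embeddings (guidance)

target = torch.ones(B)  # +1: pull each projection toward its CLIP counterpart
loss = (
    bce(logits, labels)
    + contrastive_weight * cosine(text_proj, clip_text, target)
    + contrastive_weight * cosine(image_proj, clip_image, target)
)
loss.backward()  # a real loop would follow with optimizer.step()
```
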
### Excluding CLIP

While the CLIP model was used during training to guide the alignment of the image and text embeddings, the final model does **not** retain CLIP weights, as it is designed to function independently once training is complete.

## How to Load the Model

You can load this model for bias classification with the code below. The model accepts a text input and an image input, processes them through the BERT and ResNet-34 encoders respectively, and outputs a single logit indicating whether the content is likely biased or unbiased.

```python
import json
from typing import Literal

import torch
from torch import nn
from transformers import AutoModel
from huggingface_hub import hf_hub_download


class MultimodalClassifier(nn.Module):
    def __init__(
        self,
        text_encoder_id_or_path: str,
        image_encoder_id_or_path: str,
        projection_dim: int,
        fusion_method: Literal["concat", "align", "cosine_similarity"] = "concat",
        proj_dropout: float = 0.1,
        fusion_dropout: float = 0.1,
        num_classes: int = 1,
    ) -> None:
        super().__init__()

        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.num_classes = num_classes

        ##### Text Encoder #####
        self.text_encoder = AutoModel.from_pretrained(text_encoder_id_or_path)
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        ##### Image Encoder (ResNet-34 via AutoModel) #####
        self.image_encoder = AutoModel.from_pretrained(image_encoder_id_or_path, trust_remote_code=True)
        self.image_encoder.classifier = nn.Identity()  # neutralize any classification head
        self.image_projection = nn.Sequential(
            nn.Linear(512, self.projection_dim),  # ResNet-34 final feature map has 512 channels
            nn.Dropout(proj_dropout),
        )

        ##### Fusion Layer #####
        fusion_input_dim = self.projection_dim * 2 if fusion_method == "concat" else self.projection_dim
        self.fusion_layer = nn.Sequential(
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_input_dim, self.projection_dim),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
        )

        ##### Classification Layer #####
        self.classifier = nn.Linear(self.projection_dim, self.num_classes)

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        ##### Text Encoder Projection #####
        full_text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
        full_text_features = full_text_features[:, 0, :]  # use the [CLS] token representation
        full_text_features = self.text_projection(full_text_features)

        ##### Image Encoder Projection #####
        resnet_image_features = self.image_encoder(pixel_values=pixel_values).last_hidden_state

        # global average pooling: (batch, 512, H, W) -> (batch, 512)
        resnet_image_features = resnet_image_features.mean(dim=[-2, -1])
        resnet_image_features = self.image_projection(resnet_image_features)

        ##### Fusion and Classification #####
        if self.fusion_method == "concat":
            fused_features = torch.cat([full_text_features, resnet_image_features], dim=-1)
        else:
            # element-wise product; the align/cosine_similarity paths are not fully supported yet
            fused_features = full_text_features * resnet_image_features

        fused_features = self.fusion_layer(fused_features)
        classification_output = self.classifier(fused_features)

        return classification_output


def load_model():
    config_path = hf_hub_download(repo_id="maximuspowers/multimodal-bias-classifier", filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    model = MultimodalClassifier(
        text_encoder_id_or_path=config["text_encoder_id_or_path"],
        image_encoder_id_or_path="microsoft/resnet-34",
        projection_dim=config["projection_dim"],
        fusion_method=config["fusion_method"],
        proj_dropout=config["proj_dropout"],
        fusion_dropout=config["fusion_dropout"],
        num_classes=config["num_classes"],
    )

    model_weights_path = hf_hub_download(repo_id="maximuspowers/multimodal-bias-classifier", filename="model_weights.pth")
    checkpoint = torch.load(model_weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint, strict=False)

    return model
```
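`load_model()` reads its constructor arguments from the repository's `config.json`. The keys below are the ones the function accesses; the values shown are illustrative assumptions based on the Model Details section, not a copy of the actual file.

```python
# Keys that load_model() expects in config.json. Values here are assumptions
# for illustration; the repository's config.json is the source of truth.
expected_config = {
    "text_encoder_id_or_path": "google-bert/bert-base-uncased",
    "projection_dim": 768,
    "fusion_method": "concat",
    "proj_dropout": 0.1,
    "fusion_dropout": 0.1,
    "num_classes": 1,
}
```
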
## Example Inference

The example below runs a single prediction: it tokenizes a sentence with the BERT tokenizer, preprocesses an image with standard ImageNet normalization, and thresholds the sigmoid of the model's output logit.

```python
import torch
from transformers import AutoTokenizer
from PIL import Image
from torchvision import transforms

model = load_model()
model.eval()

# text input
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
sample_text = "This is a sample sentence for bias classification."
text_inputs = text_tokenizer(
    sample_text,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
)

# image input: resize to 224x224 and apply ImageNet normalization
image = Image.open("./random_image.jpg").convert("RGB")
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image_input = image_transform(image).unsqueeze(0)  # add batch dimension

# run a forward pass
with torch.no_grad():
    classification_output = model(
        pixel_values=image_input,
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
    )

predicted_class = torch.sigmoid(classification_output).round().item()
print("Predicted class:", "Biased" if predicted_class == 1 else "Unbiased")
```
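
If a confidence score is more useful than a hard label, the sigmoid of the logit can be reported directly; this is a small addition to the example above, not part of the original script.

```python
# Report the probability of the "Biased" class instead of a hard 0/1 label.
bias_probability = torch.sigmoid(classification_output).item()
print(f"P(biased) = {bias_probability:.3f}")
```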