maximuspowers committed (verified)
Commit 1f9027d · 1 Parent(s): 1422a23

basic setup details and stuffffff

Files changed (1)
  1. readme.md +196 -0
readme.md ADDED
@@ -0,0 +1,196 @@
---
license: mit
datasets:
- vector-institute/newsmediabias-plus
language:
- en
metrics:
- f1(0.698616087436676)
- precision(0.6369158625602722)
- recall(0.7735527753829956)
- accuracy(0.6247606873512268)
library_name: transformers
co2_eq_emissions:
  emissions: 8
  source: Code Carbon
  training_type: fine-tuning
  geographical_location: Albany, New York
  hardware_used: T4
base_model:
- google-bert/bert-base-uncased
- microsoft/resnet-34
pipeline_tag: custom
tags:
- Social Bias
- Multimodal
---

# Multimodal Bias Classifier

This model is a multimodal classifier that combines text and image inputs to detect potential bias in content. It uses a BERT-based text encoder and a ResNet-34 image encoder, which are fused for classification. A contrastive learning approach was used during training, leveraging CLIP embeddings as guidance to align the text and image representations.

## Model Details

- **Text Encoder**: BERT (`bert-base-uncased`)
- **Image Encoder**: ResNet-34 (`microsoft/resnet-34`)
- **Projection Dimensionality**: 768
- **Fusion Method**: Concatenation (default), Alignment, or Cosine Similarity
- **Loss Functions**: Binary Cross-Entropy for classification, Cosine Embedding Loss for contrastive learning
- **Purpose**: Detecting bias in multimodal content (text + image)

## Training

The model was trained on a multimodal dataset with labeled instances of biased and unbiased content. Training combined a classification loss with a contrastive loss that helps align the text and image representations in a shared latent space.

### Training Losses
- **Classification Loss**: Binary Cross-Entropy (`BCEWithLogitsLoss`) to classify content as biased or unbiased.
- **Contrastive Loss**: `CosineEmbeddingLoss`, which uses CLIP text and image embeddings as ground-truth guidance to align the text and image features (see the sketch below).

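For reference, here is a minimal sketch of how the two objectives can be combined. It is not the repository's original training code: the loss weighting `contrastive_weight` and the exact pairing of features are assumptions, `text_proj` and `image_proj` stand for the outputs of the model's text and image projection layers, and `clip_text_emb` / `clip_image_emb` for CLIP embeddings of the same batch.

```python
import torch
from torch import nn

# Hypothetical loss combination; the weighting and feature pairing are
# assumptions, not taken from the original training code.
bce_loss = nn.BCEWithLogitsLoss()       # biased vs. unbiased classification
cosine_loss = nn.CosineEmbeddingLoss()  # aligns features with CLIP guidance
contrastive_weight = 0.5                # hypothetical weighting

def compute_losses(logits, labels, text_proj, image_proj, clip_text_emb, clip_image_emb):
    # Classification term: logits and labels have shape (batch, 1)
    cls_loss = bce_loss(logits, labels.float())

    # Contrastive term: pull each projected feature toward the matching
    # CLIP embedding (target = +1 means "should be similar")
    target = torch.ones(text_proj.size(0), device=text_proj.device)
    align_loss = cosine_loss(text_proj, clip_text_emb, target) + \
                 cosine_loss(image_proj, clip_image_emb, target)

    return cls_loss + contrastive_weight * align_loss
```
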
### Excluding CLIP
While the CLIP model was used during training to guide the alignment of the image and text embeddings, the final model does **not** retain CLIP weights, as it is designed to function independently once training is complete.

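Since CLIP only supplies training-time guidance, it can be loaded separately and discarded once training finishes. The sketch below is an assumption about how those guidance embeddings might be produced; the model card does not state which CLIP checkpoint was used, and whichever is chosen must output embeddings matching the 768-dimensional projection space (for example `openai/clip-vit-large-patch14`).

```python
import torch
from transformers import CLIPModel, CLIPProcessor

# Training-time only: these weights are NOT part of the released checkpoint.
# The checkpoint name is an assumption; it must produce embeddings with the
# same dimensionality as the classifier's projection space (768).
clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

def clip_targets(texts, images):
    inputs = clip_processor(text=texts, images=images, return_tensors="pt",
                            padding=True, truncation=True)
    with torch.no_grad():
        clip_text_emb = clip.get_text_features(input_ids=inputs["input_ids"],
                                               attention_mask=inputs["attention_mask"])
        clip_image_emb = clip.get_image_features(pixel_values=inputs["pixel_values"])
    return clip_text_emb, clip_image_emb
```
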
## How to Load the Model

You can load this model for bias classification by following the code below. The model accepts text input and an image input, processing them through BERT and ResNet-34 encoders, respectively. The final prediction indicates whether the content is likely biased or unbiased.

```python
import torch
from torch import nn
from transformers import AutoModel
from huggingface_hub import hf_hub_download
from typing import Literal
import json

class MultimodalClassifier(nn.Module):
    def __init__(
        self,
        text_encoder_id_or_path: str,
        image_encoder_id_or_path: str,
        projection_dim: int,
        fusion_method: Literal["concat", "align", "cosine_similarity"] = "concat",
        proj_dropout: float = 0.1,
        fusion_dropout: float = 0.1,
        num_classes: int = 1,
    ) -> None:
        super().__init__()

        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.num_classes = num_classes

        ##### Text Encoder
        self.text_encoder = AutoModel.from_pretrained(text_encoder_id_or_path)
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        ##### Image Encoder (ResNet-34 via AutoModel)
        self.image_encoder = AutoModel.from_pretrained(image_encoder_id_or_path, trust_remote_code=True)
        self.image_encoder.classifier = nn.Identity()  # remove any classification head
        self.image_projection = nn.Sequential(
            nn.Linear(512, self.projection_dim),  # ResNet-34's final stage outputs 512 channels
            nn.Dropout(proj_dropout),
        )

        ##### Fusion Layer
        fusion_input_dim = self.projection_dim * 2 if fusion_method == "concat" else self.projection_dim
        self.fusion_layer = nn.Sequential(
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_input_dim, self.projection_dim),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
        )

        ##### Classification Layer
        self.classifier = nn.Linear(self.projection_dim, self.num_classes)

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        ##### Text Encoder Projection #####
        full_text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
        full_text_features = full_text_features[:, 0, :]  # use the [CLS] token
        full_text_features = self.text_projection(full_text_features)

        ##### Image Encoder Projection #####
        resnet_image_features = self.image_encoder(pixel_values=pixel_values).last_hidden_state

        # global average pooling over the ResNet feature map: (batch, 512, H, W) -> (batch, 512)
        resnet_image_features = resnet_image_features.mean(dim=[-2, -1])
        resnet_image_features = self.image_projection(resnet_image_features)

        ##### Fusion and Classification #####
        if self.fusion_method == "concat":
            fused_features = torch.cat([full_text_features, resnet_image_features], dim=-1)
        else:
            # element-wise product; the non-concat fusion paths are not fully supported yet
            fused_features = full_text_features * resnet_image_features

        # fusion and classifier layers
        fused_features = self.fusion_layer(fused_features)
        classification_output = self.classifier(fused_features)

        return classification_output

def load_model():
    config_path = hf_hub_download(repo_id="maximuspowers/multimodal-bias-classifier", filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)

    model = MultimodalClassifier(
        text_encoder_id_or_path=config["text_encoder_id_or_path"],
        image_encoder_id_or_path="microsoft/resnet-34",
        projection_dim=config["projection_dim"],
        fusion_method=config["fusion_method"],
        proj_dropout=config["proj_dropout"],
        fusion_dropout=config["fusion_dropout"],
        num_classes=config["num_classes"]
    )

    model_weights_path = hf_hub_download(repo_id="maximuspowers/multimodal-bias-classifier", filename="model_weights.pth")
    checkpoint = torch.load(model_weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint, strict=False)

    return model
```

Once the model is loaded, you can run inference on a text and image pair as follows:

```python
import torch
from transformers import AutoTokenizer
from PIL import Image
from torchvision import transforms

model = load_model()
model.eval()

# text input
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
sample_text = "This is a sample sentence for bias classification."
text_inputs = text_tokenizer(
    sample_text,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512
)

# image input
image = Image.open("./random_image.jpg").convert("RGB")
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image_input = image_transform(image).unsqueeze(0)  # add batch dim

# run
with torch.no_grad():
    classification_output = model(
        pixel_values=image_input,
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"]
    )
predicted_class = torch.sigmoid(classification_output).round().item()
print("Predicted class:", "Biased" if predicted_class == 1 else "Unbiased")
```
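
Continuing from the example above, if you want the underlying probability rather than a hard label (for example, to tune the decision threshold on a validation set), you can skip the rounding step. The threshold value here is illustrative and not part of the released model:

```python
# Continues the inference example above; `classification_output` is the raw logit.
probability = torch.sigmoid(classification_output).item()
threshold = 0.5  # illustrative default; tune on validation data if needed
print(f"Bias probability: {probability:.3f} -> {'Biased' if probability >= threshold else 'Unbiased'}")
```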