Abs6187 committed on
Commit
0c2f64c
·
verified ·
1 Parent(s): 40898b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -4
app.py CHANGED
@@ -1,11 +1,101 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer
4
  from PIL import Image
5
  from torchvision import transforms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Load model and tokenizer
8
- model = load_model(model_weights.pth)
9
  model.eval()
10
  text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
11
 
@@ -26,7 +116,7 @@ def predict(image: Image.Image, text: str) -> str:
26
  truncation=True,
27
  max_length=512
28
  )
29
-
30
  # Process image input
31
  image_input = image_transform(image).unsqueeze(0) # Add batch dimension
32
 
@@ -38,7 +128,7 @@ def predict(image: Image.Image, text: str) -> str:
38
  attention_mask=text_inputs["attention_mask"]
39
  )
40
  predicted_class = torch.sigmoid(classification_output).round().item()
41
-
42
  return "Biased" if predicted_class == 1 else "Unbiased"
43
 
44
  # Gradio Interface
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModel
4
  from PIL import Image
5
  from torchvision import transforms
6
+ import json
7
+ from torch import nn
8
+ from typing import Literal
9
+
10
+ # Define Multimodal Classifier
11
class MultimodalClassifier(nn.Module):
    """Binary (or multi-class) classifier over a (text, image) pair.

    A pretrained text encoder (e.g. BERT) and a pretrained image encoder
    (e.g. ResNet) are each projected into a shared `projection_dim` space,
    fused, and passed through a small MLP head.

    Args:
        text_encoder_id_or_path: HF hub id or local path for the text encoder.
        image_encoder_id_or_path: HF hub id or local path for the image encoder.
        projection_dim: Dimension of the shared projection space.
        fusion_method: "concat" concatenates the two projected features;
            any other value multiplies them element-wise.
        proj_dropout: Dropout applied after each projection.
        fusion_dropout: Dropout inside the fusion MLP.
        num_classes: Output logits; 1 means a single sigmoid logit.
    """

    def __init__(
        self,
        text_encoder_id_or_path: str,
        image_encoder_id_or_path: str,
        projection_dim: int,
        fusion_method: Literal["concat", "align", "cosine_similarity"] = "concat",
        proj_dropout: float = 0.1,
        fusion_dropout: float = 0.1,
        num_classes: int = 1,
    ) -> None:
        super().__init__()

        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.num_classes = num_classes

        # Text encoder + projection to the shared space.
        self.text_encoder = AutoModel.from_pretrained(text_encoder_id_or_path)
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # Image encoder + projection. The classification head is replaced with
        # Identity so the encoder yields features, not class logits.
        self.image_encoder = AutoModel.from_pretrained(
            image_encoder_id_or_path, trust_remote_code=True
        )
        self.image_encoder.classifier = nn.Identity()  # remove classification head
        self.image_projection = nn.Sequential(
            # 512 is the final channel width of ResNet-18/34; NOTE(review):
            # this hardcodes the encoder's feature size — confirm it matches
            # the encoder actually configured.
            nn.Linear(512, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # Fusion MLP: "concat" doubles the input width; element-wise fusion
        # keeps it at projection_dim.
        fusion_input_dim = (
            self.projection_dim * 2 if fusion_method == "concat" else self.projection_dim
        )
        self.fusion_layer = nn.Sequential(
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_input_dim, self.projection_dim),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
        )

        # Final classification head.
        self.classifier = nn.Linear(self.projection_dim, self.num_classes)

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return raw classification logits of shape (batch, num_classes)."""
        # Text: take the [CLS] token embedding, then project.
        full_text_features = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        full_text_features = full_text_features[:, 0, :]  # CLS token
        full_text_features = self.text_projection(full_text_features)

        # Image: global-average-pool the (B, C, H, W) feature map, then project.
        resnet_image_features = self.image_encoder(pixel_values=pixel_values).last_hidden_state
        resnet_image_features = resnet_image_features.mean(dim=[-2, -1])  # global average pooling
        resnet_image_features = self.image_projection(resnet_image_features)

        # Fuse the two modalities.
        if self.fusion_method == "concat":
            fused_features = torch.cat([full_text_features, resnet_image_features], dim=-1)
        else:
            # "align" / "cosine_similarity" fall through to element-wise product.
            fused_features = full_text_features * resnet_image_features

        # Classify.
        fused_features = self.fusion_layer(fused_features)
        classification_output = self.classifier(fused_features)
        return classification_output
76
+
77
# Load the model
def load_model():
    """Build a MultimodalClassifier from config.json and load local weights.

    Reads hyperparameters from ``config.json`` in the working directory,
    constructs the model, and loads ``model_weights.pth`` onto the CPU.

    Returns:
        The weight-loaded MultimodalClassifier (still in train mode; the
        caller is expected to invoke ``.eval()``).

    Raises:
        FileNotFoundError / KeyError / json.JSONDecodeError on a missing or
        malformed config, and whatever ``torch.load`` raises on bad weights.
    """
    with open("config.json", "r") as f:
        config = json.load(f)

    model = MultimodalClassifier(
        text_encoder_id_or_path=config["text_encoder_id_or_path"],
        # Image encoder is pinned here and deliberately NOT read from config.
        image_encoder_id_or_path="microsoft/resnet-34",
        projection_dim=config["projection_dim"],
        fusion_method=config["fusion_method"],
        proj_dropout=config["proj_dropout"],
        fusion_dropout=config["fusion_dropout"],
        num_classes=config["num_classes"],
    )

    # map_location=CPU so loading works on machines without a GPU.
    checkpoint = torch.load("model_weights.pth", map_location=torch.device('cpu'))
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # convenient for partially-saved checkpoints, but it can mask a checkpoint
    # that does not actually match this architecture.
    model.load_state_dict(checkpoint, strict=False)

    return model
96
 
97
  # Load model and tokenizer
98
+ model = load_model()
99
  model.eval()
100
  text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
101
 
 
116
  truncation=True,
117
  max_length=512
118
  )
119
+
120
  # Process image input
121
  image_input = image_transform(image).unsqueeze(0) # Add batch dimension
122
 
 
128
  attention_mask=text_inputs["attention_mask"]
129
  )
130
  predicted_class = torch.sigmoid(classification_output).round().item()
131
+
132
  return "Biased" if predicted_class == 1 else "Unbiased"
133
 
134
  # Gradio Interface