Yiming-M committed
Commit c38041d · 1 Parent(s): e6bdb71
app.py CHANGED
@@ -40,8 +40,6 @@ truncation = 4
 reduction = 8
 granularity = "fine"
 anchor_points = "average"
-
-model_name = "clip_vit_l_14"
 input_size = 224
 
 # Comment the lines below to test non-CLIP models.
@@ -50,8 +48,19 @@ num_vpt = 32
 vpt_drop = 0.
 deep_vpt = True
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+repo_id = "Yiming-M/CLIP-EBC"
+model_configs = {
+    "CLIP_EBC_ViT_L_14": {
+        "model_name": "clip_vit_l_14",
+        "filename": "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors",
+    },
+    "CLIP_EBC_ViT_B_16": {
+        "model_name": "clip_vit_b_16",
+        "filename": "nwpu_weights/CLIP_EBC_ViT_B_16/model.safetensors",
+    },
+}
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 if truncation is None:  # regression, no truncation.
     bins, anchor_points = None, None
@@ -62,32 +71,48 @@ else:
     anchor_points = config["anchor_points"][granularity]["average"] if anchor_points == "average" else config["anchor_points"][granularity]["middle"]
     bins = [(float(b[0]), float(b[1])) for b in bins]
     anchor_points = [float(p) for p in anchor_points]
+# Use a global reference to store the model instance
+loaded_model = None
+
+def load_model(model_choice: str):
+    global loaded_model
+
+    config = model_configs[model_choice]
+    model_name = config["model_name"]
+    filename = config["filename"]
+
+    # Prepare bins and anchor_points if using classification
+    if truncation is None:
+        bins_, anchor_points_ = None, None
+    else:
+        with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
+            config_json = json.load(f)[str(truncation)]["nwpu"]
+        bins_ = config_json["bins"][granularity]
+        anchor_points_ = config_json["anchor_points"][granularity]["average"] if anchor_points == "average" else config_json["anchor_points"][granularity]["middle"]
+        bins_ = [(float(b[0]), float(b[1])) for b in bins_]
+        anchor_points_ = [float(p) for p in anchor_points_]
+
+    # Build model
+    model = get_model(
+        backbone=model_name,
+        input_size=input_size,
+        reduction=reduction,
+        bins=bins_,
+        anchor_points=anchor_points_,
+        prompt_type=prompt_type,
+        num_vpt=num_vpt,
+        vpt_drop=vpt_drop,
+        deep_vpt=deep_vpt,
+    )
 
-
-model = get_model(
-    backbone=model_name,
-    input_size=input_size,
-    reduction=reduction,
-    bins=bins,
-    anchor_points=anchor_points,
-    # CLIP parameters
-    prompt_type=prompt_type,
-    num_vpt=num_vpt,
-    vpt_drop=vpt_drop,
-    deep_vpt=deep_vpt
-)
-
-repo_id = "Yiming-M/CLIP-EBC"
-filename = "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors"
-weights_path = hf_hub_download(repo_id, filename)
-# weights_path = os.path.join("CLIP_EBC_ViT_L_14", "model.safetensors")
-state_dict = load_file(weights_path)
-new_state_dict = {}
-for k, v in state_dict.items():
-    new_state_dict[k.replace("model.", "")] = v
-model.load_state_dict(new_state_dict)
-model.to(device)
-model.eval()
+    weights_path = hf_hub_download(repo_id, filename)
+    state_dict = load_file(weights_path)
+    new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
+    model.load_state_dict(new_state_dict)
+    model.to(device)
+    model.eval()
+
+    loaded_model = model
 
 
 # -----------------------------
@@ -114,17 +139,22 @@ def transform(image: Image.Image):
 # -----------------------------
 # Inference function
 # -----------------------------
-def predict(image: Image.Image):
+def predict(image: Image.Image, model_choice: str = "CLIP_EBC_ViT_B_16"):
     """
     Given an input image, preprocess it, run the model to obtain a density map,
     compute the total crowd count, and prepare the density map for display.
     """
+    global loaded_model
+
+    if loaded_model is None or model_configs[model_choice]["model_name"] not in loaded_model.__class__.__name__:
+        load_model(model_choice)
+
     # Preprocess the image
     input_width, input_height = image.size
     input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)
 
     with torch.no_grad():
-        density_map = model(input_tensor)  # expected shape: (1, 1, H, W)
+        density_map = loaded_model(input_tensor)  # expected shape: (1, 1, H, W)
         total_count = density_map.sum().item()
         resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()
 
@@ -149,32 +179,34 @@ def predict(image: Image.Image):
 # Build Gradio Interface using Blocks for a two-column layout
 # -----------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("# Crowd Counting Demo")
+    gr.Markdown("# Crowd Counting by CLIP-EBC (Pre-trained on NWPU-Crowd)")
     gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
-
+
     with gr.Row():
         with gr.Column():
-            input_img = gr.Image(
-                label="Input Image",
-                sources=["upload", "clipboard"],
-                type="pil",
+            model_choice = gr.Dropdown(
+                choices=list(model_configs.keys()),
+                value="CLIP_EBC_ViT_B_16",
+                label="Select Model"
             )
+            input_img = gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil")
             submit_btn = gr.Button("Predict")
         with gr.Column():
             output_img = gr.Image(label="Predicted Density Map", type="pil")
             output_text = gr.Textbox(label="Total Count")
-
-    submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])
-
-    # Optional: add example images. Ensure these files are in your repo.
+
+    submit_btn.click(fn=predict, inputs=[input_img, model_choice], outputs=[input_img, output_img, output_text])
+
     gr.Examples(
         examples=[
             ["example1.jpg"],
-            ["example2.jpg"]
+            ["example2.jpg"],
+            ["example3.jpg"],
+            ["example4.jpg"],
+            ["example5.jpg"],
         ],
        inputs=input_img,
        label="Try an example"
    )
 
-# Launch the app
-demo.launch()
+demo.launch()
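Reviewer note: the heart of the app.py change is a lazy, switchable model cache. predict() only builds and downloads a checkpoint the first time it is needed, and rebuilds it when the dropdown selection changes. Below is a minimal, self-contained sketch of that caching pattern, with stated assumptions: build_model is a hypothetical stand-in for the repo's get_model + hf_hub_download + load_file sequence, and the cache is keyed by the selection string rather than by loaded_model.__class__.__name__ as in the diff.

from typing import Optional

# Stand-in registry mirroring model_configs in app.py (names only, no weights).
MODEL_CONFIGS = {
    "CLIP_EBC_ViT_B_16": {"model_name": "clip_vit_b_16"},
    "CLIP_EBC_ViT_L_14": {"model_name": "clip_vit_l_14"},
}

_loaded_model: Optional[object] = None
_loaded_choice: Optional[str] = None  # remembers which selection built the cache


def build_model(choice: str) -> object:
    # Hypothetical stand-in for get_model(...) + hf_hub_download(...) + load_file(...).
    print(f"building {MODEL_CONFIGS[choice]['model_name']}")
    return object()


def get_cached_model(choice: str) -> object:
    """Return the cached model, rebuilding only when the selection changes."""
    global _loaded_model, _loaded_choice
    if _loaded_model is None or _loaded_choice != choice:
        _loaded_model = build_model(choice)
        _loaded_choice = choice
    return _loaded_model


get_cached_model("CLIP_EBC_ViT_B_16")  # builds
get_cached_model("CLIP_EBC_ViT_B_16")  # cache hit, nothing rebuilt
get_cached_model("CLIP_EBC_ViT_L_14")  # selection changed, rebuilds

Keying the cache on the selection string avoids relying on the wrapper's class name, which may not distinguish the two ViT variants if both backbones share a class.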
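Within load_model, the checkpoint keys are renamed with k.replace("model.", "") before load_state_dict, so the safetensors keys line up with the bare module returned by get_model. A tiny sketch of that renaming on a dummy state dict (no download; the key names are illustrative only):

import torch

# Dummy tensors with the "model." prefix the checkpoints are assumed to carry.
state_dict = {
    "model.backbone.weight": torch.zeros(2, 2),
    "model.head.bias": torch.zeros(2),
}

# Same renaming as in load_model(); note that str.replace removes every
# occurrence of "model.", not only a leading prefix.
new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
print(list(new_state_dict))  # ['backbone.weight', 'head.bias']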
 
models/clip/_clip/__init__.py CHANGED
@@ -13,15 +13,8 @@ from .model import CLIP
 curr_dir = os.path.dirname(os.path.abspath(__file__))
 
 clip_model_names = [
-    "clip_resnet50",
-    "clip_resnet101",
-    "clip_resnet50x4",
-    "clip_resnet50x16",
-    "clip_resnet50x64",
-    "clip_vit_b_32",
     "clip_vit_b_16",
     "clip_vit_l_14",
-    "clip_vit_l_14_336px",
 ]
 
 clip_image_encoder_names = [f"clip_image_encoder_{name[5:]}" for name in clip_model_names]
@@ -240,34 +233,10 @@ __all__ = [
     # utils
     "tokenize",
     "transform",
-    # clip models
-    "resnet50_clip",
-    "resnet101_clip",
-    "resnet50x4_clip",
-    "resnet50x16_clip",
-    "resnet50x64_clip",
-    "vit_b_32_clip",
-    "vit_b_16_clip",
-    "vit_l_14_clip",
-    "vit_l_14_336px_clip",
     # clip image encoders
-    "resnet50_img",
-    "resnet101_img",
-    "resnet50x4_img",
-    "resnet50x16_img",
-    "resnet50x64_img",
-    "vit_b_32_img",
     "vit_b_16_img",
     "vit_l_14_img",
-    "vit_l_14_336px_img",
     # clip text encoders
-    "resnet50_txt",
-    "resnet101_txt",
-    "resnet50x4_txt",
-    "resnet50x16_txt",
-    "resnet50x64_txt",
-    "vit_b_32_txt",
     "vit_b_16_txt",
     "vit_l_14_txt",
-    "vit_l_14_336px_txt",
 ]
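Side effect worth noting in __init__.py: clip_image_encoder_names is derived from clip_model_names by stripping the "clip_" prefix (name[5:]), so trimming the model list automatically trims the derived encoder names as well. A quick standalone check reproducing the comprehension with the new list:

# Reproduces the comprehension from models/clip/_clip/__init__.py with the trimmed list.
clip_model_names = [
    "clip_vit_b_16",
    "clip_vit_l_14",
]

# name[5:] drops the 5-character "clip_" prefix.
clip_image_encoder_names = [f"clip_image_encoder_{name[5:]}" for name in clip_model_names]
print(clip_image_encoder_names)
# ['clip_image_encoder_vit_b_16', 'clip_image_encoder_vit_l_14']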
models/clip/_clip/prepare.py CHANGED
@@ -9,15 +9,8 @@ from .utils import load
 
 
 model_name_map = {
-    "RN50": "resnet50",
-    "RN101": "resnet101",
-    "RN50x4": "resnet50x4",
-    "RN50x16": "resnet50x16",
-    "RN50x64": "resnet50x64",
-    "ViT-B/32": "vit_b_32",
     "ViT-B/16": "vit_b_16",
     "ViT-L/14": "vit_l_14",
-    "ViT-L/14@336px": "vit_l_14_336px",
 }
 
 
@@ -49,7 +42,7 @@ def prepare() -> None:
     os.makedirs(config_dir, exist_ok=True)
     device = torch.device("cpu")
 
-    for model_name in tqdm(["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px"]):
+    for model_name in tqdm(["ViT-B/16", "ViT-L/14"]):
         model = load(model_name, device=device).to(device)
         image_encoder = model.visual.to(device)
         text_encoder = CLIPTextEncoderTemp(model).to(device)
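In prepare.py, the tqdm loop and model_name_map now have to stay in sync: every OpenAI CLIP name the loop loads needs a local alias in the map. A small standalone check (not part of the repo) that mirrors the two trimmed tables:

# Mirrors the trimmed tables in models/clip/_clip/prepare.py.
model_name_map = {
    "ViT-B/16": "vit_b_16",
    "ViT-L/14": "vit_l_14",
}

prepared_models = ["ViT-B/16", "ViT-L/14"]  # the new tqdm list in prepare()

# Fail loudly if a model would be prepared without a mapped local name.
missing = [name for name in prepared_models if name not in model_name_map]
assert not missing, f"no local alias for: {missing}"

for name in prepared_models:
    print(f"{name} -> {model_name_map[name]}")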