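"""Gradio demo for CLIP-EBC crowd counting.

Loads a CLIP ViT-L/14 based EBC model with pretrained NWPU weights from the
Hugging Face Hub, predicts a crowd density map for an uploaded image, and
shows the map overlaid on the input together with the estimated total count.
"""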
import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from PIL import Image
import json, os, random
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file  # for loading .safetensors checkpoints
from matplotlib import cm
from huggingface_hub import hf_hub_download

from typing import Tuple

from models import get_model


def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
    # Resize a density map to `size` while preserving its total count.
    x_sum = torch.sum(x, dim=(-1, -2))
    x = F.interpolate(x, size=size, mode="bilinear")
    # Rescale so the resized map sums to the original count; nan_to_num guards
    # against division by zero when the map is all zeros.
    scale_factor = torch.nan_to_num(x_sum / torch.sum(x, dim=(-1, -2)), nan=0.0, posinf=0.0, neginf=0.0)
    return x * scale_factor


def init_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
init_seeds(42)

# -----------------------------
# Define the model architecture
# -----------------------------
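# `truncation` and `granularity` select which bin definition is read from
# configs/reduction_{reduction}.json below, `reduction` is the density-map
# downsampling factor passed to the backbone, and `anchor_points` chooses
# whether each bin is represented by its average or its middle count value.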
truncation = 4
reduction = 8
granularity = "fine"
anchor_points = "average"

model_name = "clip_vit_l_14"
input_size = 224

# Comment the lines below to test non-CLIP models.
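# These are CLIP-specific settings: `prompt_type` selects the form of the text
# prompts, while `num_vpt`, `vpt_drop` and `deep_vpt` configure visual prompt
# tuning (the number of learnable prompt tokens, their dropout rate, and
# whether prompts are injected at every transformer layer rather than only the
# first one).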
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.0
deep_vpt = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    point_type = "average" if anchor_points == "average" else "middle"
    anchor_points = config["anchor_points"][granularity][point_type]
    bins = [(float(b[0]), float(b[1])) for b in bins]
    anchor_points = [float(p) for p in anchor_points]
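# In the classification setting, `bins` is now a list of (low, high) count
# intervals and `anchor_points` holds the representative count predicted for
# each interval; in the regression setting both stay None.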


model = get_model(
    backbone=model_name,
    input_size=input_size, 
    reduction=reduction,
    bins=bins,
    anchor_points=anchor_points,
    # CLIP parameters
    prompt_type=prompt_type,
    num_vpt=num_vpt,
    vpt_drop=vpt_drop,
    deep_vpt=deep_vpt
)

repo_id = "Yiming-M/CLIP-EBC"
filename = "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors"
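# Download (and cache) the pretrained NWPU checkpoint from the Hugging Face Hub.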
weights_path = hf_hub_download(repo_id, filename)
# weights_path = os.path.join("CLIP_EBC_ViT_L_14", "model.safetensors")
state_dict = load_file(weights_path)
# Checkpoint keys are stored with a leading "model." prefix; strip it so they
# match this model's state dict.
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.to(device)
model.eval()


# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
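# The backbone expects ImageNet-normalized tensors whose sides are at least
# `input_size`; smaller images are upscaled while keeping their aspect ratio.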
def transform(image: Image.Image):
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)

    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Find the ratio to resize the image while maintaining the aspect ratio
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)

    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension



# -----------------------------
# Inference function
# -----------------------------
def predict(image: Image.Image):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    # Preprocess the image
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)
    
    with torch.no_grad():
        density_map = model(input_tensor)  # expected shape: (1, 1, H, W)
        total_count = density_map.sum().item()
        resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()
    
    # Normalize the density map for display purposes
    eps = 1e-8
    density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)
    
    # Apply a colormap (e.g., 'jet') to get an RGBA image
    colormap = cm.get_cmap("jet")
    # The colormap returns values in [0,1]. Scale to [0,255] and convert to uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")
    
    # Ensure the original image is in RGBA format.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)
    
    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
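
# Quick sanity check without the UI (hypothetical file name):
#   img = Image.open("crowd.jpg").convert("RGB")
#   _, overlay, count_text = predict(img)
#   print(count_text)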


# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting Demo")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
    
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                label="Input Image",
                sources=["upload", "clipboard"],
                type="pil",
            )
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")
    
    submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])
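    # predict() echoes the original image back to input_img, so the input stays
    # displayed next to the density-map overlay.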
    
    # Optional: add example images. Ensure these files are in your repo.
    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"]
        ],
        inputs=input_img,
        label="Try an example"
    )

# Launch the app
demo.launch()