import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from PIL import Image
import json, os, random
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file # Import the load_file function from safetensors
from matplotlib import cm
from huggingface_hub import hf_hub_download
from typing import Tuple
from models import get_model
def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
    """Resize a density map to `size` while preserving its total count."""
    x_sum = torch.sum(x, dim=(-1, -2), keepdim=True)
    x = F.interpolate(x, size=size, mode="bilinear")
    # Rescale so that the sum of the resized map matches the original sum.
    scale_factor = torch.nan_to_num(
        x_sum / torch.sum(x, dim=(-1, -2), keepdim=True),
        nan=0.0, posinf=0.0, neginf=0.0,
    )
    return x * scale_factor
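# Illustrative check (commented out, uses a dummy map that is not part of the demo):
# resizing a density map should leave its total count unchanged, e.g.
#     dm = torch.rand(1, 1, 28, 28)
#     assert torch.allclose(dm.sum(), resize_density_map(dm, (224, 224)).sum())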
def init_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
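# Note: this only seeds the Python, NumPy and CPU PyTorch generators. Fully
# deterministic GPU runs would additionally need e.g. torch.cuda.manual_seed_all
# and deterministic algorithm settings, which this demo does not require.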
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
init_seeds(42)
# -----------------------------
# Define the model architecture
# -----------------------------
truncation = 4
reduction = 8
granularity = "fine"
anchor_points = "average"
model_name = "clip_vit_l_14"
input_size = 224
# Comment out the lines below to test non-CLIP models.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.
deep_vpt = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    anchor_points = (
        config["anchor_points"][granularity]["average"]
        if anchor_points == "average"
        else config["anchor_points"][granularity]["middle"]
    )
    bins = [(float(b[0]), float(b[1])) for b in bins]
    anchor_points = [float(p) for p in anchor_points]
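# In EBC-style models, `bins` are the (low, high) count intervals that turn
# block-wise counting into classification, and `anchor_points` are the
# representative count values of those intervals (their averages or midpoints,
# depending on the setting above).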
model = get_model(
    backbone=model_name,
    input_size=input_size,
    reduction=reduction,
    bins=bins,
    anchor_points=anchor_points,
    # CLIP parameters
    prompt_type=prompt_type,
    num_vpt=num_vpt,
    vpt_drop=vpt_drop,
    deep_vpt=deep_vpt,
)
repo_id = "Yiming-M/CLIP-EBC"
filename = "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors"
weights_path = hf_hub_download(repo_id, filename)
# weights_path = os.path.join("CLIP_EBC_ViT_L_14", "model.safetensors")
state_dict = load_file(weights_path)
new_state_dict = {}
for k, v in state_dict.items():
    # The checkpoint prefixes every key with "model."; strip it so the keys
    # match those expected by `model.load_state_dict`.
    new_state_dict[k.replace("model.", "")] = v
model.load_state_dict(new_state_dict)
model.to(device)
model.eval()
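# Optional sanity check (illustrative, commented out): a dummy forward pass
# should succeed and return a density map at a reduced spatial resolution
# (roughly input_size / reduction per side).
# with torch.no_grad():
#     print(model(torch.zeros(1, 3, input_size, input_size, device=device)).shape)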
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image):
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Find the ratio that upscales the shorter side to input_size while
        # maintaining the aspect ratio.
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension
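# Example usage (illustrative): a 640x480 upload yields a tensor of shape
# (1, 3, 480, 640); smaller images are first upscaled so both sides reach input_size.
# x = transform(Image.open("example1.jpg").convert("RGB"))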
# -----------------------------
# Inference function
# -----------------------------
def predict(image: Image.Image):
"""
Given an input image, preprocess it, run the model to obtain a density map,
compute the total crowd count, and prepare the density map for display.
"""
# Preprocess the image
input_width, input_height = image.size
input_tensor = transform(image).to(device) # shape: (1, 3, H, W)
with torch.no_grad():
density_map = model(input_tensor) # expected shape: (1, 1, H, W)
total_count = density_map.sum().item()
resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()
# Normalize the density map for display purposes
eps = 1e-8
density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)
# Apply a colormap (e.g., 'jet') to get an RGBA image
colormap = cm.get_cmap("jet")
# The colormap returns values in [0,1]. Scale to [0,255] and convert to uint8.
density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")
# Ensure the original image is in RGBA format.
image_rgba = image.convert("RGBA")
overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)
return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
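# Standalone usage sketch (illustrative; the Gradio interface below calls `predict` for us):
# img = Image.open("example1.jpg").convert("RGB")
# original, overlay, count_text = predict(img)
# print(count_text)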
# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
gr.Markdown("# Crowd Counting Demo")
gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
with gr.Row():
with gr.Column():
input_img = gr.Image(
label="Input Image",
sources=["upload", "clipboard"],
type="pil",
)
submit_btn = gr.Button("Predict")
with gr.Column():
output_img = gr.Image(label="Predicted Density Map", type="pil")
output_text = gr.Textbox(label="Total Count")
submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])
# Optional: add example images. Ensure these files are in your repo.
gr.Examples(
examples=[
["example1.jpg"],
["example2.jpg"]
],
inputs=input_img,
label="Try an example"
)
# Launch the app
demo.launch()