# CLIP-EBC / app.py
import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from PIL import Image
import json, os, random
import gradio as gr
import torchvision.transforms.functional as TF
from safetensors.torch import load_file # Import the load_file function from safetensors
from matplotlib import cm
from huggingface_hub import hf_hub_download
from typing import Tuple
from models import get_model
def resize_density_map(x: Tensor, size: Tuple[int, int]) -> Tensor:
    """Resize a density map to `size` while preserving its total count."""
    x_sum = torch.sum(x, dim=(-1, -2), keepdim=True)
    x = F.interpolate(x, size=size, mode="bilinear")
    # Rescale so the interpolated map still sums to the original count.
    scale_factor = torch.nan_to_num(x_sum / torch.sum(x, dim=(-1, -2), keepdim=True), nan=0.0, posinf=0.0, neginf=0.0)
    return x * scale_factor
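# Count-preservation sanity check (illustrative values, not from the repo):
#   d = torch.rand(1, 1, 28, 28)
#   resize_density_map(d, (224, 224)).sum()  # ~= d.sum()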
def init_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
alpha = 0.8
init_seeds(42)
# -----------------------------
# Define the model architecture
# -----------------------------
truncation = 4
reduction = 8
granularity = "fine"
anchor_points = "average"
model_name = "clip_vit_l_14"
input_size = 224
# Comment the lines below to test non-CLIP models.
prompt_type = "word"
num_vpt = 32
vpt_drop = 0.
deep_vpt = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if truncation is None:  # regression, no truncation.
    bins, anchor_points = None, None
else:
    with open(os.path.join("configs", f"reduction_{reduction}.json"), "r") as f:
        config = json.load(f)[str(truncation)]["nwpu"]
    bins = config["bins"][granularity]
    anchor_points = config["anchor_points"][granularity]["average"] if anchor_points == "average" else config["anchor_points"][granularity]["middle"]
    bins = [(float(b[0]), float(b[1])) for b in bins]
    anchor_points = [float(p) for p in anchor_points]
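# In the EBC (Enhanced Blockwise Classification) formulation, each bin is a count
# interval for an image block and each anchor point is the representative count
# for that interval, so the head predicts a distribution over bins instead of
# regressing a raw count. A loaded config entry therefore looks roughly like
# (illustrative values, not copied from configs/reduction_8.json):
#   bins          = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, float("inf"))]
#   anchor_points = [0.0, 1.0, 2.0, 3.5]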
model = get_model(
    backbone=model_name,
    input_size=input_size,
    reduction=reduction,
    bins=bins,
    anchor_points=anchor_points,
    # CLIP parameters
    prompt_type=prompt_type,
    num_vpt=num_vpt,
    vpt_drop=vpt_drop,
    deep_vpt=deep_vpt,
)
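# The model maps a normalized (1, 3, H, W) tensor to a density map whose sum is the
# predicted crowd count; with reduction = 8 that map is expected to be roughly
# (1, 1, H // 8, W // 8) before being resized back to the image resolution
# (shape assumption based on the `reduction` setting, not verified here).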
repo_id = "Yiming-M/CLIP-EBC"
filename = "nwpu_weights/CLIP_EBC_ViT_L_14/model.safetensors"
weights_path = hf_hub_download(repo_id, filename)
# weights_path = os.path.join("CLIP_EBC_ViT_L_14", "model.safetensors")
state_dict = load_file(weights_path)
# Checkpoint keys are prefixed with "model."; strip the prefix so they match the
# parameter names expected by this model instance.
new_state_dict = {}
for k, v in state_dict.items():
    new_state_dict[k.replace("model.", "")] = v
model.load_state_dict(new_state_dict)
model.to(device)
model.eval()
# -----------------------------
# Preprocessing function
# -----------------------------
# Adjust the image transforms to match what your model expects.
def transform(image: Image.Image):
    assert isinstance(image, Image.Image), "Input must be a PIL Image"
    image_tensor = TF.to_tensor(image)
    image_height, image_width = image_tensor.shape[-2:]
    if image_height < input_size or image_width < input_size:
        # Find the ratio to resize the image while maintaining the aspect ratio
        ratio = max(input_size / image_height, input_size / image_width)
        new_height = int(image_height * ratio) + 1
        new_width = int(image_width * ratio) + 1
        image_tensor = TF.resize(image_tensor, (new_height, new_width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
    image_tensor = TF.normalize(image_tensor, mean=mean, std=std)
    return image_tensor.unsqueeze(0)  # Add batch dimension
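# e.g. a 640x480 input passes through unchanged (both sides >= 224), while a
# 300x180 input is upscaled to about 374x225 so its shorter side reaches the
# 224-pixel input size (example sizes are illustrative).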
# -----------------------------
# Inference function
# -----------------------------
def predict(image: Image.Image):
    """
    Given an input image, preprocess it, run the model to obtain a density map,
    compute the total crowd count, and prepare the density map for display.
    """
    # Preprocess the image
    input_width, input_height = image.size
    input_tensor = transform(image).to(device)  # shape: (1, 3, H, W)
    with torch.no_grad():
        density_map = model(input_tensor)  # expected shape: (1, 1, H, W)
        total_count = density_map.sum().item()
        resized_density_map = resize_density_map(density_map, (input_height, input_width)).cpu().squeeze().numpy()
    # Normalize the density map for display purposes
    eps = 1e-8
    density_map_norm = (resized_density_map - resized_density_map.min()) / (resized_density_map.max() - resized_density_map.min() + eps)
    # Apply a colormap (e.g., 'jet') to get an RGBA image
    colormap = cm.get_cmap("jet")
    # The colormap returns values in [0, 1]. Scale to [0, 255] and convert to uint8.
    density_map_color = (colormap(density_map_norm) * 255).astype(np.uint8)
    density_map_color_img = Image.fromarray(density_map_color).convert("RGBA")
    # Ensure the original image is in RGBA format.
    image_rgba = image.convert("RGBA")
    overlayed_image = Image.blend(image_rgba, density_map_color_img, alpha=alpha)
    return image, overlayed_image, f"Predicted Count: {total_count:.2f}"
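# Quick local sanity check (illustrative; assumes an "example1.jpg" next to this
# script, as used by gr.Examples below):
#   img = Image.open("example1.jpg").convert("RGB")
#   _, overlay, count_text = predict(img)
#   print(count_text)            # e.g. "Predicted Count: 123.45"
#   overlay.save("overlay.png")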
# -----------------------------
# Build Gradio Interface using Blocks for a two-column layout
# -----------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Crowd Counting Demo")
    gr.Markdown("Upload an image or select an example below to see the predicted crowd density map and total count.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                label="Input Image",
                sources=["upload", "clipboard"],
                type="pil",
            )
            submit_btn = gr.Button("Predict")
        with gr.Column():
            output_img = gr.Image(label="Predicted Density Map", type="pil")
            output_text = gr.Textbox(label="Total Count")
    submit_btn.click(fn=predict, inputs=input_img, outputs=[input_img, output_img, output_text])
    # Optional: add example images. Ensure these files are in your repo.
    gr.Examples(
        examples=[
            ["example1.jpg"],
            ["example2.jpg"],
        ],
        inputs=input_img,
        label="Try an example",
    )
# Launch the app
demo.launch()