"""Gradio demo: multi-class animal segmentation with a ClipSeg-based model.

Loads a pretrained ``ClipSegMultiClassModel``, runs it on an uploaded image and
shows the colour-coded prediction blended over the input, together with a
legend mapping colours to class names.
"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image

from src.config import ClipSegMultiClassConfig
from src.model import ClipSegMultiClassModel

# === Load model ===
# Position i in ``class_labels`` is the name for class id i; the legend below
# relies on this correspondence.
class_labels = ["background", "Pig", "Horse", "Sheep"]
# RGB colour (0-255 per channel) painted for each class id.
label2color = {
    0: [0, 0, 0],
    1: [255, 0, 0],
    2: [0, 255, 0],
    3: [0, 0, 255],
}

# NOTE(review): this config is constructed but never passed to the model
# below; ``from_pretrained`` presumably uses the checkpoint's own config.
# Confirm whether it should be wired in or removed.
config = ClipSegMultiClassConfig(
    class_labels=class_labels,
    label2color=label2color,
    model="CIDAS/clipseg-rd64-refined",
)
model = ClipSegMultiClassModel.from_pretrained("BioMike/clipsegmulticlass_v1")
model.eval()


def colorize_mask(mask_tensor, label2color):
    """Convert a class-id mask into an RGB image.

    Args:
        mask_tensor: Class-id mask, either a ``torch.Tensor`` or anything
            ``np.asarray`` accepts. If it has more than two dimensions,
            size-1 axes (e.g. a batch or channel axis) are squeezed away.
        label2color: Mapping ``class_id -> [R, G, B]`` with 0-255 channels.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``. Pixels whose id is not a key
        of ``label2color`` stay black.

    Raises:
        ValueError: if the mask is not 2-D after squeezing.
    """
    if isinstance(mask_tensor, torch.Tensor):
        mask = mask_tensor.detach().cpu().numpy()
    else:
        mask = np.asarray(mask_tensor)
    # Only squeeze when there are extra axes, so a genuine 2-D mask with a
    # size-1 side (1 x W or H x 1) is not collapsed to 1-D.
    if mask.ndim > 2:
        mask = np.squeeze(mask)
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D class-id mask, got shape {mask.shape}")

    color_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for class_id, color in label2color.items():
        color_mask[mask == class_id] = color
    return color_mask


def _to_rgb_image(input_img):
    """Normalise a Gradio image input (path, array or PIL image) to RGB PIL."""
    if input_img is None:
        # Clicking submit without an upload reaches the handler as None.
        raise gr.Error("Please upload an image.")
    if isinstance(input_img, str):
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(input_img) as img:
            return img.convert("RGB")
    if isinstance(input_img, np.ndarray):
        return Image.fromarray(input_img).convert("RGB")
    # PIL uploads may be RGBA / L / P; Image.blend requires both images to
    # share a mode, so always convert.
    return input_img.convert("RGB")


def segment_with_legend(input_img):
    """Segment ``input_img`` and return a matplotlib figure with a legend.

    Args:
        input_img: Image as a file path, a numpy array, or a PIL image.

    Returns:
        ``matplotlib.figure.Figure`` showing the input blended 50/50 with the
        colour-coded prediction, plus a legend of class names.

    Raises:
        gradio.Error: if no image was provided.
    """
    image = _to_rgb_image(input_img)

    # Inference only; no autograd graph is needed.
    with torch.no_grad():
        pred_mask = model.predict(image)
    color_mask = colorize_mask(pred_mask, label2color)

    # Image.blend needs equal sizes: bring the input to the mask's resolution.
    mask_h, mask_w = color_mask.shape[:2]
    overlay = Image.blend(
        image.resize((mask_w, mask_h)),
        Image.fromarray(color_mask),
        alpha=0.5,
    )

    # Use the object-oriented Figure API instead of plt.subplots(): pyplot
    # keeps every figure it creates alive until explicitly closed, which
    # leaks one figure per request in a long-running server.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.imshow(overlay)
    ax.axis("off")

    # Look colours up by class id rather than relying on dict value order.
    legend_handles = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label=label,
            markerfacecolor=np.array(label2color.get(class_id, (0, 0, 0))) / 255.0,
            markersize=10,
        )
        for class_id, label in enumerate(class_labels)
    ]
    ax.legend(handles=legend_handles, loc="lower right", framealpha=0.8)
    return fig


demo = gr.Interface(
    fn=segment_with_legend,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Plot(label="Segmentation with Legend"),
    title="ClipSeg MultiClass Demo",
    description=(
        "Upload an image containing pigs, sheep, or horses. "
        "The model will segment the animals and colorize them.\n\n"
        "Legend: Red = Pig, Green = Horse, Blue = Sheep."
    ),
)

if __name__ == "__main__":
    demo.launch()