File size: 2,230 Bytes
1bfa982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315f9a8
74ae720
 
 
 
 
 
 
 
 
 
 
315f9a8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from src.model import ClipSegMultiClassModel
from src.config import ClipSegMultiClassConfig

# === Load model ===
# Class index -> human-readable name; index 0 is the background class.
class_labels = ["background", "Pig", "Horse", "Sheep"]
# Class index -> RGB color used when rendering the predicted mask.
# NOTE(review): ordering pairs with class_labels (1=Pig red, 2=Horse green,
# 3=Sheep blue) — keep the two in sync.
label2color = {
    0: [0, 0, 0],
    1: [255, 0, 0],
    2: [0, 255, 0],
    3: [0, 0, 255],
}

# Project-local config wrapping the base CLIPSeg checkpoint name plus the
# label/color maps above.
config = ClipSegMultiClassConfig(
    class_labels=class_labels,
    label2color=label2color,
    model="CIDAS/clipseg-rd64-refined",
)

# Downloads/loads the fine-tuned multi-class weights from the Hugging Face Hub
# at import time; requires network access on first run.
model = ClipSegMultiClassModel.from_pretrained("BioMike/clipsegmulticlass_v1")

# Inference only — disable dropout/batch-norm training behavior.
model.eval()

def colorize_mask(mask_tensor, label2color):
    mask = mask_tensor.squeeze().cpu().numpy()
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in label2color.items():
        color_mask[mask == class_id] = color
    return color_mask 

def segment_with_legend(input_img):
    """Segment an image and return a matplotlib figure: overlay + class legend.

    Args:
        input_img: a file path (str), an RGB/HWC numpy array, or a PIL Image.

    Returns:
        matplotlib.figure.Figure showing the input blended 50/50 with the
        colorized predicted mask, with a class-color legend.
    """
    if isinstance(input_img, str):
        input_img = Image.open(input_img)
    elif isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img)

    # Bug fix: convert unconditionally. Previously only the str/ndarray
    # branches converted to RGB, so a PIL image passed straight from Gradio
    # (e.g. an RGBA PNG) made Image.blend raise "images do not match"
    # against the RGB color mask.
    input_img = input_img.convert("RGB")

    pred_mask = model.predict(input_img)
    color_mask = colorize_mask(pred_mask, label2color)
    # PIL.resize takes (width, height); the mask array is (height, width, 3).
    resized_input = input_img.resize((color_mask.shape[1], color_mask.shape[0]))
    overlay = Image.blend(resized_input, Image.fromarray(color_mask), alpha=0.5)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.imshow(overlay)
    ax.axis("off")

    # One dot-marker legend entry per class, colored to match the mask.
    # Relies on class_labels and label2color being index-aligned.
    legend_patches = [
        plt.Line2D([0], [0], marker='o', color='w',
                   label=label,
                   markerfacecolor=np.array(color) / 255.0,
                   markersize=10)
        for label, color in zip(class_labels, label2color.values())
    ]
    ax.legend(handles=legend_patches, loc='lower right', framealpha=0.8)

    return fig

# Gradio UI: a single image input mapped through segment_with_legend to a
# matplotlib figure output.
demo = gr.Interface(
    fn=segment_with_legend,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Plot(label="Segmentation with Legend"),
    title="ClipSeg MultiClass Demo",
    description="Upload an image containing pigs, sheep, or horses. The model will segment the animals and colorize them. \
<br><br><b>Legend:</b> Red = Pig, Green = Horse, Blue = Sheep."
)



# Launch the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()