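# Header residue from the Hugging Face file viewer (kept as comments so the
# module stays importable):
# BioMike's picture
# Update app.py
# 74ae720 verified
# raw
# history blame
# 2.23 kB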
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from src.model import ClipSegMultiClassModel
from src.config import ClipSegMultiClassConfig
# === Load model ===
# Human-readable class names; position i names class id i.
class_labels = ["background", "Pig", "Horse", "Sheep"]

# Class id -> RGB color painted for that class in the predicted mask.
label2color = dict(
    enumerate(
        (
            [0, 0, 0],  # 0: background (black)
            [255, 0, 0],  # 1: Pig (red)
            [0, 255, 0],  # 2: Horse (green)
            [0, 0, 255],  # 3: Sheep (blue)
        )
    )
)

# NOTE(review): `config` is built here but is not passed to from_pretrained()
# below; presumably the checkpoint carries its own config — confirm before
# assuming edits to class_labels/label2color change the model itself.
config = ClipSegMultiClassConfig(
    model="CIDAS/clipseg-rd64-refined",
    class_labels=class_labels,
    label2color=label2color,
)

# Load the pretrained checkpoint once at import time and switch the model to
# evaluation mode for inference.
model = ClipSegMultiClassModel.from_pretrained("BioMike/clipsegmulticlass_v1")
model.eval()
def colorize_mask(mask_tensor, label2color):
    """Convert a per-pixel class-id mask into an RGB color image.

    Args:
        mask_tensor: Integer class ids laid out as ``(..., H, W)``. This may be
            a ``torch.Tensor`` (on any device, with or without autograd) or
            anything ``np.asarray`` accepts. Leading singleton axes such as
            batch/channel dims of size 1 are dropped.
        label2color: Mapping ``class_id -> [R, G, B]`` (0-255).

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` and dtype ``uint8``. Pixels whose
        class id is not a key of ``label2color`` stay black.

    Raises:
        ValueError: if the mask cannot be reduced to two dimensions.
    """
    if isinstance(mask_tensor, torch.Tensor):
        # detach() so tensors that still require grad can be converted.
        mask = mask_tensor.detach().cpu().numpy()
    else:
        mask = np.asarray(mask_tensor)

    # Strip only *leading* singleton axes (batch/channel). A blanket squeeze()
    # would also collapse a 1-pixel-tall or 1-pixel-wide mask to 1-D.
    while mask.ndim > 2 and mask.shape[0] == 1:
        mask = mask[0]
    if mask.ndim > 2:
        # Fall back to removing any remaining singleton axes, e.g. (H, W, 1).
        mask = mask.squeeze()
    if mask.ndim != 2:
        raise ValueError(
            f"expected a mask reducible to (H, W), got shape {mask.shape}"
        )

    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in label2color.items():
        color_mask[mask == class_id] = color
    return color_mask
def segment_with_legend(input_img):
    """Segment an image and return a figure with a colored overlay and legend.

    Args:
        input_img: The image to segment. This can be a PIL image, a file path
            (``str``) or a NumPy array. Every form is coerced to RGB so it can
            be blended with the RGB color mask. ``None`` (no upload) is
            rejected with a user-facing ``gr.Error``.

    Returns:
        A ``matplotlib.figure.Figure`` showing the input (resized to the mask's
        resolution) blended 50/50 with the colorized prediction. The legend
        maps each class color to its label.
    """
    # Figure is created directly (not via pyplot) so it is not registered in
    # pyplot's global figure list, which would otherwise grow by one figure
    # per request in a long-running server.
    from matplotlib.figure import Figure

    if input_img is None:
        raise gr.Error("Please upload an image first.")
    if isinstance(input_img, str):
        with Image.open(input_img) as opened:
            input_img = opened.convert("RGB")
    elif isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img).convert("RGB")
    elif isinstance(input_img, Image.Image) and input_img.mode != "RGB":
        # Image.blend requires both images to share a mode; RGBA/L/P uploads
        # would otherwise fail against the RGB color mask.
        input_img = input_img.convert("RGB")

    # Inference only; this is harmless if predict() already disables autograd.
    with torch.no_grad():
        pred_mask = model.predict(input_img)
    color_mask = colorize_mask(pred_mask, label2color)

    mask_h, mask_w = color_mask.shape[:2]
    overlay = Image.blend(
        input_img.resize((mask_w, mask_h)),
        Image.fromarray(color_mask),
        alpha=0.5,
    )

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.imshow(overlay)
    ax.axis("off")

    # Pair each color with its label by class id, not by dict/list order.
    legend_handles = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label=class_labels[class_id],
            markerfacecolor=np.array(color) / 255.0,
            markersize=10,
        )
        for class_id, color in label2color.items()
    ]
    ax.legend(handles=legend_handles, loc="lower right", framealpha=0.8)
    return fig
# Gradio UI: a single image input rendered through segment_with_legend into a
# matplotlib plot.
demo = gr.Interface(
    fn=segment_with_legend,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Plot(label="Segmentation with Legend"),
    title="ClipSeg MultiClass Demo",
    # Implicit concatenation; the resulting text is identical to a single
    # literal and the legend line matches the colors in label2color.
    description=(
        "Upload an image containing pigs, sheep, or horses. "
        "The model will segment the animals and colorize them. "
        "<br><br><b>Legend:</b> Red = Pig, Green = Horse, Blue = Sheep."
    ),
)

if __name__ == "__main__":
    demo.launch()