# (Hugging Face Spaces page header — "Spaces: Running" — was scrape residue, not code.)
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from PIL import Image

from src.config import ClipSegMultiClassConfig
from src.model import ClipSegMultiClassModel
# === Load model ===
# Class vocabulary for the segmentation model; index 0 is background.
class_labels = ["background", "Pig", "Horse", "Sheep"]

# RGB color used to render each class id in the output mask.
label2color = {
    0: [0, 0, 0],      # background -> black
    1: [255, 0, 0],    # Pig        -> red
    2: [0, 255, 0],    # Horse      -> green
    3: [0, 0, 255],    # Sheep      -> blue
}

# NOTE(review): `config` is constructed here but never passed to anything
# below — presumably from_pretrained() restores its own config from the hub.
# Confirm whether this local object is needed at all.
config = ClipSegMultiClassConfig(
    class_labels=class_labels,
    label2color=label2color,
    model="CIDAS/clipseg-rd64-refined",
)

model = ClipSegMultiClassModel.from_pretrained("BioMike/clipsegmulticlass_v1")
model.eval()  # inference mode: disable dropout / training-only behavior
def colorize_mask(mask_tensor, label2color):
    """Convert a 2-D tensor of class ids into an RGB uint8 image.

    Args:
        mask_tensor: torch tensor of integer class ids; squeezed to 2-D.
        label2color: mapping of class id -> [R, G, B] (0-255 each).

    Returns:
        numpy array of shape (H, W, 3), dtype uint8.
    """
    ids = mask_tensor.squeeze().cpu().numpy()
    height, width = ids.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    # Paint every pixel belonging to a class with that class's color.
    for class_id, rgb_value in label2color.items():
        rgb[ids == class_id] = rgb_value
    return rgb
def segment_with_legend(input_img):
    """Segment an image and return a matplotlib figure with a legend.

    Args:
        input_img: a file path (str), a numpy HxWx3 array, or a PIL image
            (Gradio may deliver any of these).

    Returns:
        A matplotlib Figure showing the input blended with the colored
        class mask, with a legend mapping colors to class labels.
    """
    # Normalize the Gradio input to a PIL RGB image.
    if isinstance(input_img, str):
        input_img = Image.open(input_img).convert("RGB")
    elif isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img).convert("RGB")

    pred_mask = model.predict(input_img)
    color_mask = colorize_mask(pred_mask, label2color)

    # Resize the input to the mask's spatial size so the 50/50 blend aligns.
    resized = input_img.resize((color_mask.shape[1], color_mask.shape[0]))
    overlay = Image.blend(resized, Image.fromarray(color_mask), alpha=0.5)

    # Build the figure with the OO API instead of plt.subplots(): figures
    # created through pyplot stay registered in its global figure manager,
    # so a long-running Gradio server would leak one figure per request.
    fig = Figure(figsize=(8, 6))
    ax = fig.add_subplot(111)
    ax.imshow(overlay)
    ax.axis("off")

    # Legend: one circular marker per class, colored like the mask.
    legend_handles = [
        Line2D([0], [0], marker='o', color='w',
               label=label,
               markerfacecolor=np.array(color) / 255.0,
               markersize=10)
        for label, color in zip(class_labels, label2color.values())
    ]
    ax.legend(handles=legend_handles, loc='lower right', framealpha=0.8)
    return fig
if __name__ == "__main__":
    # Bug fix: `demo` was launched but never defined anywhere in the file,
    # which raised NameError on startup. Build the Gradio UI here.
    demo = gr.Interface(
        fn=segment_with_legend,
        inputs=gr.Image(type="pil", label="Input image"),
        outputs=gr.Plot(label="Segmentation overlay"),
        title="ClipSeg Multi-Class Segmentation",
        description="Segments Pig / Horse / Sheep regions and overlays a colored class mask.",
    )
    demo.launch()