"""Gradio app for comparing cell-segmentation models on an uploaded image.

Supported models: a fine-tuned SegFormer (8 semantic classes), StarDist
(``2D_versatile_fluo``), MEDIAR (``MEDIARFormer`` + ``Predictor``) and
Cellpose. All inference runs on CPU.
"""

import contextlib
import functools
import io
import os
import tempfile
import warnings

import gradio as gr

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Force CPU if needed; must be set before torch is imported.

import torch
import numpy as np
from PIL import Image
from PIL import Image as PILImage
from pathlib import Path
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.color import rgb2gray
from csbdeep.utils import normalize
from stardist.models import StarDist2D
from stardist.plot import render_label
from MEDIARFormer import MEDIARFormer
from Predictor import Predictor
from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot

# Load SegFormer
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

processor_segformer = SegformerImageProcessor(do_reduce_labels=False)
model_segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512",
    num_labels=8,
    ignore_mismatched_sizes=True,
)
model_segformer.load_state_dict(torch.load("trained_model_200.pt", map_location="cpu"))
model_segformer.eval()

# One RGB colour per SegFormer class id (0..7); class 0 is drawn black.
SEGFORMER_PALETTE = np.array(
    [
        [0, 0, 0],
        [255, 0, 0],
        [0, 255, 0],
        [0, 0, 255],
        [255, 255, 0],
        [255, 0, 255],
        [0, 255, 255],
        [128, 128, 128],
    ],
    dtype=np.uint8,
)

# StarDist model
model_stardist = StarDist2D.from_pretrained('2D_versatile_fluo')

# Cellpose model
model_cellpose = cellpose_models.CellposeModel(gpu=False)

# Single-channel PIL modes that can be passed to StarDist without conversion.
_GRAYSCALE_MODES = frozenset({"L", "I", "I;16", "F"})

# Constructor arguments for the MEDIAR network.
MEDIAR_MODEL_ARGS = {
    "classes": 3,
    "decoder_channels": [1024, 512, 256, 128, 64],
    "decoder_pab_channels": 256,
    "encoder_name": 'mit_b5',
    "in_channels": 3,
}


# Handle SegFormer prediction
def infer_segformer(image):
    """Run SegFormer on *image* and return ``(rgb_image, colour_mask)``.

    The colour mask has the same width and height as the input image. Each
    pixel is coloured by its predicted class using ``SEGFORMER_PALETTE``.
    """
    image = image.convert("RGB")
    inputs = processor_segformer(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model_segformer(**inputs)
    # SegFormer predicts at a fraction of the processor's input resolution.
    # Upsample the logits back to the original (height, width); PIL's size is
    # (width, height), hence the reversal. Otherwise the mask would not line
    # up with the uploaded image.
    pred_mask = processor_segformer.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0].cpu().numpy()
    # Vectorised colourisation: every class id indexes a palette row.
    color_mask = SEGFORMER_PALETTE[pred_mask]
    return image, Image.fromarray(color_mask)


def _to_grayscale_array(image):
    """Return *image* as a 2-D numpy array suitable for StarDist.

    Single-channel modes are used as-is. Any other mode (RGB, RGBA, P, LA,
    CMYK, ...) is converted to RGB first and then reduced with ``rgb2gray``.
    """
    if image.mode in _GRAYSCALE_MODES:
        return np.array(image)
    return rgb2gray(np.array(image.convert("RGB")))


# Handle StarDist prediction
def infer_stardist(image):
    """Run StarDist instance segmentation and return ``(image, overlay)``.

    The overlay is StarDist's ``render_label`` blend of the instance labels
    over the grayscale input, converted to 8-bit RGB.
    """
    image_gray = _to_grayscale_array(image)
    labels, _ = model_stardist.predict_instances(normalize(image_gray))
    overlay = render_label(labels, img=image_gray)
    # render_label returns float RGBA; drop alpha and clip before the uint8 cast
    # so that out-of-range values cannot wrap around.
    overlay = (np.clip(overlay[..., :3], 0.0, 1.0) * 255).astype(np.uint8)
    return image, Image.fromarray(overlay)


@functools.lru_cache(maxsize=1)
def _load_mediar_model():
    """Build the MEDIAR network once, load ``from_phase1.pth`` and cache it.

    NOTE(review): the cached model is reused across requests. This assumes
    ``Predictor`` does not modify the model's weights — confirm.
    """
    model = MEDIARFormer(**MEDIAR_MODEL_ARGS)
    weights = torch.load("from_phase1.pth", map_location="cpu")
    # strict=False tolerates checkpoint/architecture differences. Warn about
    # them so that a wrong checkpoint does not silently produce garbage.
    incompatible = model.load_state_dict(weights, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"MEDIAR checkpoint mismatch: {len(incompatible.missing_keys)} missing, "
            f"{len(incompatible.unexpected_keys)} unexpected keys"
        )
    model.eval()
    return model


# Handle MEDIAR prediction
def infer_mediar(image, temp_dir="temp_mediar"):
    """Run MEDIAR on *image* and return ``(image, rendered_label_map)``.

    The image is written to ``<temp_dir>/input_image.tiff``. ``Predictor``
    writes ``<temp_dir>/input_image_label.tiff``, which is read back and
    rendered with the ``cividis`` colormap.
    """
    os.makedirs(temp_dir, exist_ok=True)
    input_path = os.path.join(temp_dir, "input_image.tiff")
    output_path = os.path.join(temp_dir, "input_image_label.tiff")
    image.save(input_path)
    # Remove the previous result first, so that a failed prediction raises an
    # error instead of silently returning the last request's mask.
    with contextlib.suppress(FileNotFoundError):
        os.remove(output_path)

    predictor = Predictor(
        _load_mediar_model(), "cpu", temp_dir, temp_dir, algo_params={"use_tta": False}
    )
    predictor.img_names = ["input_image.tiff"]
    predictor.conduct_prediction()
    pred = imread(output_path)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(pred, cmap="cividis")
    ax.axis("off")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return image, Image.open(buf)


# Handle Cellpose prediction
def infer_cellpose(image, temp_dir="temp_cellpose"):
    """Run Cellpose on *image* and return ``(image, segmentation_figure)``.

    The figure from ``cellpose.plot.show_segmentation`` is also saved to
    ``<temp_dir>/overlay.png``.
    """
    os.makedirs(temp_dir, exist_ok=True)
    input_path = os.path.join(temp_dir, "input_image.tif")
    output_overlay = os.path.join(temp_dir, "overlay.png")

    # Save the image and read it back through Cellpose's own reader.
    image.save(input_path)
    img = cellpose_io.imread(input_path)
    masks, flows, styles = model_cellpose.eval(img, batch_size=1)

    fig = plt.figure(figsize=(12, 5))
    cellpose_plot.show_segmentation(fig, img, masks, flows[0])
    plt.tight_layout()
    fig.savefig(output_overlay)
    plt.close(fig)

    # Load the overlay fully into memory and close the file. The next request
    # overwrites overlay.png, and a lazily-opened image would read that newer data.
    with Image.open(output_overlay) as overlay:
        return image, overlay.copy()


# Model name -> inference function; every entry returns (input_image, result_image).
_INFERENCE_FNS = {
    "SegFormer": infer_segformer,
    "StarDist": infer_stardist,
    "MEDIAR": infer_mediar,
    "Cellpose": infer_cellpose,
}


# Wrapper function
def segment(model_name, image):
    """Dispatch *image* to the model named *model_name*.

    Returns ``(input_image, result_image)`` on success. On failure it returns
    ``(None, error_message)``, where the message is a string.
    """
    # Gradio passes a PIL.Image without a filename attribute; use its format
    # when available, otherwise skip the check.
    fmt = getattr(image, "format", None)
    ext = fmt.lower() if fmt is not None else None

    # Accept only TIFF images for Cellpose.
    if model_name == "Cellpose" and ext not in ["tiff", "tif", None]:
        return None, "❌ Cellpose only supports `.tif` or `.tiff` images."

    infer = _INFERENCE_FNS.get(model_name)
    if infer is None:
        return None, f"❌ Unknown model: {model_name}"
    return infer(image)


with gr.Blocks(title="Cell Segmentation Explorer") as app:
    gr.Markdown("## Cell Segmentation Explorer")
    gr.Markdown("Choose a segmentation model, upload an appropriate image, and view the predicted mask.")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=["SegFormer", "StarDist", "MEDIAR", "Cellpose"],
                label="Select Segmentation Model",
                value="SegFormer",
            )
            image_input = gr.Image(type="pil", label="Uploaded Image")
            description_box = gr.Markdown("Accepted formats: `.png`, `.jpg`, `.tif`, `.tiff`.")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")

    def handle_submit(model_name, img):
        """Run segmentation and return only the result image for the output pane."""
        if img is None:
            return None
        _, result = segment(model_name, img)
        # segment() reports failures as a message string. gr.Image cannot
        # display a string, so show it to the user as a Gradio error instead.
        if isinstance(result, str):
            raise gr.Error(result)
        return result

    submit_btn.click(
        fn=handle_submit,
        inputs=[model_dropdown, image_input],
        outputs=output_image,
    )

    clear_btn.click(
        lambda: [None, None],
        inputs=None,
        outputs=[image_input, output_image],
    )

    # === SAMPLE IMAGES SECTION ===
    gr.Markdown("---")
    gr.Markdown("### Sample Images (click to use as input)")

    # Original images and their 128x128 thumbnails.
    original_sample_paths = [
        "img1.png",
        "img2.png",
        "img3.png",
    ]
    thumbnail_dir = tempfile.gettempdir()
    resized_sample_paths = []
    for idx, p in enumerate(original_sample_paths):
        with PILImage.open(p) as sample:
            thumbnail = sample.resize((128, 128))
        temp_path = os.path.join(thumbnail_dir, f"sample_resized_{idx}.png")
        thumbnail.save(temp_path)
        resized_sample_paths.append(temp_path)

    sample_image_components = []
    with gr.Row():
        for i, img_path in enumerate(resized_sample_paths):
            def load_full_image(idx=i):
                """Load the full-size sample image; the default argument captures the loop index."""
                with PILImage.open(original_sample_paths[idx]) as full:
                    return full.copy()

            sample_img = gr.Image(value=img_path, type="pil", interactive=True, show_label=False)
            sample_img.select(
                fn=load_full_image,
                inputs=[],
                outputs=image_input,
            )
            sample_image_components.append(sample_img)

if __name__ == "__main__":
    app.launch()