Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU if needed
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from PIL import Image as PILImage
|
8 |
+
from pathlib import Path
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import io
|
11 |
+
from skimage.io import imread
|
12 |
+
from skimage.color import rgb2gray
|
13 |
+
from csbdeep.utils import normalize
|
14 |
+
from stardist.models import StarDist2D
|
15 |
+
from stardist.plot import render_label
|
16 |
+
#import MEDIARFormer
|
17 |
+
#import Predictor
|
18 |
+
from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
|
19 |
+
|
20 |
+
# Load SegFormer
|
21 |
+
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
|
22 |
+
processor_segformer = SegformerImageProcessor(do_reduce_labels=False)
|
23 |
+
model_segformer = SegformerForSemanticSegmentation.from_pretrained(
|
24 |
+
"nvidia/segformer-b0-finetuned-ade-512-512",
|
25 |
+
num_labels=8,
|
26 |
+
ignore_mismatched_sizes=True
|
27 |
+
)
|
28 |
+
model_segformer.load_state_dict(torch.load("trained_model_200.pt", map_location="cpu"))
|
29 |
+
model_segformer.eval()
|
30 |
+
|
31 |
+
# StarDist model
|
32 |
+
model_stardist = StarDist2D.from_pretrained('2D_versatile_fluo')
|
33 |
+
|
34 |
+
# Cellpose model
|
35 |
+
model_cellpose = cellpose_models.CellposeModel(gpu=False)
|
36 |
+
|
37 |
+
# Handle SegFormer prediction
|
38 |
+
def infer_segformer(image):
|
39 |
+
image = image.convert("RGB")
|
40 |
+
inputs = processor_segformer(images=image, return_tensors="pt")
|
41 |
+
with torch.no_grad():
|
42 |
+
logits = model_segformer(**inputs).logits
|
43 |
+
pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()
|
44 |
+
|
45 |
+
# Colorize
|
46 |
+
colors = np.array([[0,0,0], [255,0,0], [0,255,0], [0,0,255], [255,255,0], [255,0,255], [0,255,255], [128,128,128]])
|
47 |
+
color_mask = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
|
48 |
+
for c in range(8):
|
49 |
+
color_mask[pred_mask == c] = colors[c]
|
50 |
+
return image, Image.fromarray(color_mask)
|
51 |
+
|
52 |
+
# Handle StarDist prediction
|
53 |
+
def infer_stardist(image):
|
54 |
+
image_gray = rgb2gray(np.array(image)) if image.mode == 'RGB' else np.array(image)
|
55 |
+
labels, _ = model_stardist.predict_instances(normalize(image_gray))
|
56 |
+
overlay = render_label(labels, img=image_gray)
|
57 |
+
overlay = (overlay[..., :3] * 255).astype(np.uint8)
|
58 |
+
return image, Image.fromarray(overlay)
|
59 |
+
|
60 |
+
# Handle MEDIAR prediction
|
61 |
+
# def infer_mediar(image, temp_dir="temp_mediar"):
|
62 |
+
# os.makedirs(temp_dir, exist_ok=True)
|
63 |
+
# input_path = os.path.join(temp_dir, "input_image.tiff")
|
64 |
+
# output_path = os.path.join(temp_dir, "input_image_label.tiff")
|
65 |
+
|
66 |
+
# image.save(input_path)
|
67 |
+
|
68 |
+
# model_args = {
|
69 |
+
# "classes": 3,
|
70 |
+
# "decoder_channels": [1024, 512, 256, 128, 64],
|
71 |
+
# "decoder_pab_channels": 256,
|
72 |
+
# "encoder_name": 'mit_b5',
|
73 |
+
# "in_channels": 3
|
74 |
+
# }
|
75 |
+
|
76 |
+
# model = MEDIARFormer(**model_args)
|
77 |
+
# weights = torch.load("from_phase1.pth", map_location="cpu")
|
78 |
+
# model.load_state_dict(weights, strict=False)
|
79 |
+
# model.eval()
|
80 |
+
|
81 |
+
# predictor = Predictor(model, "cpu", temp_dir, temp_dir, algo_params={"use_tta": False})
|
82 |
+
# predictor.img_names = ["input_image.tiff"]
|
83 |
+
# _ = predictor.conduct_prediction()
|
84 |
+
|
85 |
+
# pred = imread(output_path)
|
86 |
+
# fig, ax = plt.subplots(figsize=(6, 6))
|
87 |
+
# ax.imshow(pred, cmap="cividis")
|
88 |
+
# ax.axis("off")
|
89 |
+
|
90 |
+
# buf = io.BytesIO()
|
91 |
+
# plt.savefig(buf, format="png")
|
92 |
+
# plt.close()
|
93 |
+
# buf.seek(0)
|
94 |
+
|
95 |
+
# return image, Image.open(buf)
|
96 |
+
# Handle Cellpose prediction
|
97 |
+
def infer_cellpose(image, temp_dir="temp_cellpose"):
|
98 |
+
os.makedirs(temp_dir, exist_ok=True)
|
99 |
+
input_path = os.path.join(temp_dir, "input_image.tif")
|
100 |
+
output_overlay = os.path.join(temp_dir, "overlay.png")
|
101 |
+
|
102 |
+
# Save image
|
103 |
+
image.save(input_path)
|
104 |
+
img = cellpose_io.imread(input_path)
|
105 |
+
masks, flows, styles = model_cellpose.eval(img, batch_size=1)
|
106 |
+
|
107 |
+
fig = plt.figure(figsize=(12,5))
|
108 |
+
cellpose_plot.show_segmentation(fig, img, masks, flows[0])
|
109 |
+
plt.tight_layout()
|
110 |
+
fig.savefig(output_overlay)
|
111 |
+
plt.close(fig)
|
112 |
+
|
113 |
+
return image, Image.open(output_overlay)
|
114 |
+
|
115 |
+
# Wrapper function
|
116 |
+
def segment(model_name, image):
|
117 |
+
# Gradio passes a PIL.Image without filename attribute
|
118 |
+
# Try to check format if available, else skip check
|
119 |
+
ext = None
|
120 |
+
if hasattr(image, 'format') and image.format is not None:
|
121 |
+
ext = image.format.lower()
|
122 |
+
if model_name == "Cellpose":
|
123 |
+
# Accept only TIFF images for Cellpose
|
124 |
+
if ext not in ["tiff", "tif", None]:
|
125 |
+
return None, f"❌ Cellpose only supports `.tif` or `.tiff` images."
|
126 |
+
# ...existing code...
|
127 |
+
if model_name == "SegFormer":
|
128 |
+
return infer_segformer(image)
|
129 |
+
elif model_name == "StarDist":
|
130 |
+
return infer_stardist(image)
|
131 |
+
# elif model_name == "MEDIAR":
|
132 |
+
# return infer_mediar(image)
|
133 |
+
elif model_name == "Cellpose":
|
134 |
+
return infer_cellpose(image)
|
135 |
+
else:
|
136 |
+
return None, f"❌ Unknown model: {model_name}"
|
137 |
+
|
138 |
+
with gr.Blocks(title="Cell Segmentation Explorer") as app:
|
139 |
+
gr.Markdown("## Cell Segmentation Explorer")
|
140 |
+
gr.Markdown("Choose a segmentation model, upload an appropriate image, and view the predicted mask.")
|
141 |
+
|
142 |
+
with gr.Row():
|
143 |
+
with gr.Column():
|
144 |
+
model_dropdown = gr.Dropdown(
|
145 |
+
choices=["SegFormer", "StarDist", "Cellpose"],
|
146 |
+
label="Select Segmentation Model",
|
147 |
+
value="SegFormer"
|
148 |
+
)
|
149 |
+
image_input = gr.Image(type="pil", label="Uploaded Image")
|
150 |
+
description_box = gr.Markdown("Accepted formats: `.png`, `.jpg`, `.tif`, `.tiff`.")
|
151 |
+
submit_btn = gr.Button("Submit")
|
152 |
+
clear_btn = gr.Button("Clear")
|
153 |
+
with gr.Column():
|
154 |
+
output_image = gr.Image(label="Segmentation Result")
|
155 |
+
|
156 |
+
def handle_submit(model_name, img):
|
157 |
+
if img is None:
|
158 |
+
return None
|
159 |
+
_, result = segment(model_name, img) # Only return the mask (segmentation result)
|
160 |
+
return result
|
161 |
+
|
162 |
+
submit_btn.click(
|
163 |
+
fn=handle_submit,
|
164 |
+
inputs=[model_dropdown, image_input],
|
165 |
+
outputs=output_image
|
166 |
+
)
|
167 |
+
|
168 |
+
clear_btn.click(
|
169 |
+
lambda: [None, None],
|
170 |
+
inputs=None,
|
171 |
+
outputs=[image_input, output_image]
|
172 |
+
)
|
173 |
+
|
174 |
+
# === SAMPLE IMAGES SECTION ===
|
175 |
+
gr.Markdown("---")
|
176 |
+
gr.Markdown("### Sample Images (click to use as input)")
|
177 |
+
|
178 |
+
# Original and resized thumbnails
|
179 |
+
original_sample_paths = [
|
180 |
+
"img1.png",
|
181 |
+
"img2.png",
|
182 |
+
"img3.png"
|
183 |
+
]
|
184 |
+
|
185 |
+
resized_sample_paths = []
|
186 |
+
for idx, p in enumerate(original_sample_paths):
|
187 |
+
img = PILImage.open(p).resize((128, 128))
|
188 |
+
temp_path = f"/tmp/sample_resized_{idx}.png"
|
189 |
+
img.save(temp_path)
|
190 |
+
resized_sample_paths.append(temp_path)
|
191 |
+
|
192 |
+
sample_image_components = []
|
193 |
+
with gr.Row():
|
194 |
+
for i, img_path in enumerate(resized_sample_paths):
|
195 |
+
def load_full_image(idx=i): # Capture loop index properly
|
196 |
+
return PILImage.open(original_sample_paths[idx])
|
197 |
+
|
198 |
+
sample_img = gr.Image(value=img_path, type="pil", interactive=True, show_label=False)
|
199 |
+
sample_img.select(
|
200 |
+
fn=load_full_image,
|
201 |
+
inputs=[],
|
202 |
+
outputs=image_input
|
203 |
+
)
|
204 |
+
sample_image_components.append(sample_img)
|
205 |
+
|
206 |
+
|
207 |
+
app.launch()
|