saim1309 commited on
Commit
ff197f1
·
verified ·
1 Parent(s): 29ac506

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -0
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force CPU if needed
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+ from PIL import Image as PILImage
8
+ from pathlib import Path
9
+ import matplotlib.pyplot as plt
10
+ import io
11
+ from skimage.io import imread
12
+ from skimage.color import rgb2gray
13
+ from csbdeep.utils import normalize
14
+ from stardist.models import StarDist2D
15
+ from stardist.plot import render_label
16
+ #import MEDIARFormer
17
+ #import Predictor
18
+ from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
19
+
20
+ # Load SegFormer
21
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
22
+ processor_segformer = SegformerImageProcessor(do_reduce_labels=False)
23
+ model_segformer = SegformerForSemanticSegmentation.from_pretrained(
24
+ "nvidia/segformer-b0-finetuned-ade-512-512",
25
+ num_labels=8,
26
+ ignore_mismatched_sizes=True
27
+ )
28
+ model_segformer.load_state_dict(torch.load("trained_model_200.pt", map_location="cpu"))
29
+ model_segformer.eval()
30
+
31
+ # StarDist model
32
+ model_stardist = StarDist2D.from_pretrained('2D_versatile_fluo')
33
+
34
+ # Cellpose model
35
+ model_cellpose = cellpose_models.CellposeModel(gpu=False)
36
+
37
+ # Handle SegFormer prediction
38
+ def infer_segformer(image):
39
+ image = image.convert("RGB")
40
+ inputs = processor_segformer(images=image, return_tensors="pt")
41
+ with torch.no_grad():
42
+ logits = model_segformer(**inputs).logits
43
+ pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()
44
+
45
+ # Colorize
46
+ colors = np.array([[0,0,0], [255,0,0], [0,255,0], [0,0,255], [255,255,0], [255,0,255], [0,255,255], [128,128,128]])
47
+ color_mask = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
48
+ for c in range(8):
49
+ color_mask[pred_mask == c] = colors[c]
50
+ return image, Image.fromarray(color_mask)
51
+
52
+ # Handle StarDist prediction
53
+ def infer_stardist(image):
54
+ image_gray = rgb2gray(np.array(image)) if image.mode == 'RGB' else np.array(image)
55
+ labels, _ = model_stardist.predict_instances(normalize(image_gray))
56
+ overlay = render_label(labels, img=image_gray)
57
+ overlay = (overlay[..., :3] * 255).astype(np.uint8)
58
+ return image, Image.fromarray(overlay)
59
+
60
+ # Handle MEDIAR prediction
61
+ # def infer_mediar(image, temp_dir="temp_mediar"):
62
+ # os.makedirs(temp_dir, exist_ok=True)
63
+ # input_path = os.path.join(temp_dir, "input_image.tiff")
64
+ # output_path = os.path.join(temp_dir, "input_image_label.tiff")
65
+
66
+ # image.save(input_path)
67
+
68
+ # model_args = {
69
+ # "classes": 3,
70
+ # "decoder_channels": [1024, 512, 256, 128, 64],
71
+ # "decoder_pab_channels": 256,
72
+ # "encoder_name": 'mit_b5',
73
+ # "in_channels": 3
74
+ # }
75
+
76
+ # model = MEDIARFormer(**model_args)
77
+ # weights = torch.load("from_phase1.pth", map_location="cpu")
78
+ # model.load_state_dict(weights, strict=False)
79
+ # model.eval()
80
+
81
+ # predictor = Predictor(model, "cpu", temp_dir, temp_dir, algo_params={"use_tta": False})
82
+ # predictor.img_names = ["input_image.tiff"]
83
+ # _ = predictor.conduct_prediction()
84
+
85
+ # pred = imread(output_path)
86
+ # fig, ax = plt.subplots(figsize=(6, 6))
87
+ # ax.imshow(pred, cmap="cividis")
88
+ # ax.axis("off")
89
+
90
+ # buf = io.BytesIO()
91
+ # plt.savefig(buf, format="png")
92
+ # plt.close()
93
+ # buf.seek(0)
94
+
95
+ # return image, Image.open(buf)
96
+ # Handle Cellpose prediction
97
+ def infer_cellpose(image, temp_dir="temp_cellpose"):
98
+ os.makedirs(temp_dir, exist_ok=True)
99
+ input_path = os.path.join(temp_dir, "input_image.tif")
100
+ output_overlay = os.path.join(temp_dir, "overlay.png")
101
+
102
+ # Save image
103
+ image.save(input_path)
104
+ img = cellpose_io.imread(input_path)
105
+ masks, flows, styles = model_cellpose.eval(img, batch_size=1)
106
+
107
+ fig = plt.figure(figsize=(12,5))
108
+ cellpose_plot.show_segmentation(fig, img, masks, flows[0])
109
+ plt.tight_layout()
110
+ fig.savefig(output_overlay)
111
+ plt.close(fig)
112
+
113
+ return image, Image.open(output_overlay)
114
+
115
+ # Wrapper function
116
+ def segment(model_name, image):
117
+ # Gradio passes a PIL.Image without filename attribute
118
+ # Try to check format if available, else skip check
119
+ ext = None
120
+ if hasattr(image, 'format') and image.format is not None:
121
+ ext = image.format.lower()
122
+ if model_name == "Cellpose":
123
+ # Accept only TIFF images for Cellpose
124
+ if ext not in ["tiff", "tif", None]:
125
+ return None, f"❌ Cellpose only supports `.tif` or `.tiff` images."
126
+ # ...existing code...
127
+ if model_name == "SegFormer":
128
+ return infer_segformer(image)
129
+ elif model_name == "StarDist":
130
+ return infer_stardist(image)
131
+ # elif model_name == "MEDIAR":
132
+ # return infer_mediar(image)
133
+ elif model_name == "Cellpose":
134
+ return infer_cellpose(image)
135
+ else:
136
+ return None, f"❌ Unknown model: {model_name}"
137
+
138
+ with gr.Blocks(title="Cell Segmentation Explorer") as app:
139
+ gr.Markdown("## Cell Segmentation Explorer")
140
+ gr.Markdown("Choose a segmentation model, upload an appropriate image, and view the predicted mask.")
141
+
142
+ with gr.Row():
143
+ with gr.Column():
144
+ model_dropdown = gr.Dropdown(
145
+ choices=["SegFormer", "StarDist", "Cellpose"],
146
+ label="Select Segmentation Model",
147
+ value="SegFormer"
148
+ )
149
+ image_input = gr.Image(type="pil", label="Uploaded Image")
150
+ description_box = gr.Markdown("Accepted formats: `.png`, `.jpg`, `.tif`, `.tiff`.")
151
+ submit_btn = gr.Button("Submit")
152
+ clear_btn = gr.Button("Clear")
153
+ with gr.Column():
154
+ output_image = gr.Image(label="Segmentation Result")
155
+
156
+ def handle_submit(model_name, img):
157
+ if img is None:
158
+ return None
159
+ _, result = segment(model_name, img) # Only return the mask (segmentation result)
160
+ return result
161
+
162
+ submit_btn.click(
163
+ fn=handle_submit,
164
+ inputs=[model_dropdown, image_input],
165
+ outputs=output_image
166
+ )
167
+
168
+ clear_btn.click(
169
+ lambda: [None, None],
170
+ inputs=None,
171
+ outputs=[image_input, output_image]
172
+ )
173
+
174
+ # === SAMPLE IMAGES SECTION ===
175
+ gr.Markdown("---")
176
+ gr.Markdown("### Sample Images (click to use as input)")
177
+
178
+ # Original and resized thumbnails
179
+ original_sample_paths = [
180
+ "img1.png",
181
+ "img2.png",
182
+ "img3.png"
183
+ ]
184
+
185
+ resized_sample_paths = []
186
+ for idx, p in enumerate(original_sample_paths):
187
+ img = PILImage.open(p).resize((128, 128))
188
+ temp_path = f"/tmp/sample_resized_{idx}.png"
189
+ img.save(temp_path)
190
+ resized_sample_paths.append(temp_path)
191
+
192
+ sample_image_components = []
193
+ with gr.Row():
194
+ for i, img_path in enumerate(resized_sample_paths):
195
+ def load_full_image(idx=i): # Capture loop index properly
196
+ return PILImage.open(original_sample_paths[idx])
197
+
198
+ sample_img = gr.Image(value=img_path, type="pil", interactive=True, show_label=False)
199
+ sample_img.select(
200
+ fn=load_full_image,
201
+ inputs=[],
202
+ outputs=image_input
203
+ )
204
+ sample_image_components.append(sample_img)
205
+
206
+
207
+ app.launch()