Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -13,8 +13,8 @@ from skimage.color import rgb2gray
|
|
13 |
from csbdeep.utils import normalize
|
14 |
from stardist.models import StarDist2D
|
15 |
from stardist.plot import render_label
|
16 |
-
|
17 |
-
|
18 |
from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
|
19 |
|
20 |
# Load SegFormer
|
@@ -58,41 +58,41 @@ def infer_stardist(image):
|
|
58 |
return image, Image.fromarray(overlay)
|
59 |
|
60 |
# Handle MEDIAR prediction
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
# Handle Cellpose prediction
|
97 |
def infer_cellpose(image, temp_dir="temp_cellpose"):
|
98 |
os.makedirs(temp_dir, exist_ok=True)
|
@@ -128,8 +128,8 @@ def segment(model_name, image):
|
|
128 |
return infer_segformer(image)
|
129 |
elif model_name == "StarDist":
|
130 |
return infer_stardist(image)
|
131 |
-
|
132 |
-
|
133 |
elif model_name == "Cellpose":
|
134 |
return infer_cellpose(image)
|
135 |
else:
|
@@ -142,7 +142,7 @@ with gr.Blocks(title="Cell Segmentation Explorer") as app:
|
|
142 |
with gr.Row():
|
143 |
with gr.Column():
|
144 |
model_dropdown = gr.Dropdown(
|
145 |
-
choices=["SegFormer", "StarDist", "Cellpose"],
|
146 |
label="Select Segmentation Model",
|
147 |
value="SegFormer"
|
148 |
)
|
|
|
13 |
from csbdeep.utils import normalize
|
14 |
from stardist.models import StarDist2D
|
15 |
from stardist.plot import render_label
|
16 |
+
from MEDIARFormer import MEDIARFormer
|
17 |
+
from Predictor import Predictor
|
18 |
from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
|
19 |
|
20 |
# Load SegFormer
|
|
|
58 |
return image, Image.fromarray(overlay)
|
59 |
|
60 |
# Handle MEDIAR prediction
|
61 |
+
def infer_mediar(image, temp_dir="temp_mediar"):
|
62 |
+
os.makedirs(temp_dir, exist_ok=True)
|
63 |
+
input_path = os.path.join(temp_dir, "input_image.tiff")
|
64 |
+
output_path = os.path.join(temp_dir, "input_image_label.tiff")
|
65 |
+
|
66 |
+
image.save(input_path)
|
67 |
+
|
68 |
+
model_args = {
|
69 |
+
"classes": 3,
|
70 |
+
"decoder_channels": [1024, 512, 256, 128, 64],
|
71 |
+
"decoder_pab_channels": 256,
|
72 |
+
"encoder_name": 'mit_b5',
|
73 |
+
"in_channels": 3
|
74 |
+
}
|
75 |
+
|
76 |
+
model = MEDIARFormer(**model_args)
|
77 |
+
weights = torch.load("from_phase1.pth", map_location="cpu")
|
78 |
+
model.load_state_dict(weights, strict=False)
|
79 |
+
model.eval()
|
80 |
+
|
81 |
+
predictor = Predictor(model, "cpu", temp_dir, temp_dir, algo_params={"use_tta": False})
|
82 |
+
predictor.img_names = ["input_image.tiff"]
|
83 |
+
_ = predictor.conduct_prediction()
|
84 |
+
|
85 |
+
pred = imread(output_path)
|
86 |
+
fig, ax = plt.subplots(figsize=(6, 6))
|
87 |
+
ax.imshow(pred, cmap="cividis")
|
88 |
+
ax.axis("off")
|
89 |
+
|
90 |
+
buf = io.BytesIO()
|
91 |
+
plt.savefig(buf, format="png")
|
92 |
+
plt.close()
|
93 |
+
buf.seek(0)
|
94 |
+
|
95 |
+
return image, Image.open(buf)
|
96 |
# Handle Cellpose prediction
|
97 |
def infer_cellpose(image, temp_dir="temp_cellpose"):
|
98 |
os.makedirs(temp_dir, exist_ok=True)
|
|
|
128 |
return infer_segformer(image)
|
129 |
elif model_name == "StarDist":
|
130 |
return infer_stardist(image)
|
131 |
+
elif model_name == "MEDIAR":
|
132 |
+
return infer_mediar(image)
|
133 |
elif model_name == "Cellpose":
|
134 |
return infer_cellpose(image)
|
135 |
else:
|
|
|
142 |
with gr.Row():
|
143 |
with gr.Column():
|
144 |
model_dropdown = gr.Dropdown(
|
145 |
+
choices=["SegFormer", "StarDist", "MEDIAR", "Cellpose"],
|
146 |
label="Select Segmentation Model",
|
147 |
value="SegFormer"
|
148 |
)
|