saim1309 commited on
Commit
69b6522
·
verified ·
1 Parent(s): a4b3c40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -40
app.py CHANGED
@@ -13,8 +13,8 @@ from skimage.color import rgb2gray
13
  from csbdeep.utils import normalize
14
  from stardist.models import StarDist2D
15
  from stardist.plot import render_label
16
- #import MEDIARFormer
17
- #import Predictor
18
  from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
19
 
20
  # Load SegFormer
@@ -58,41 +58,41 @@ def infer_stardist(image):
58
  return image, Image.fromarray(overlay)
59
 
60
  # Handle MEDIAR prediction
61
- # def infer_mediar(image, temp_dir="temp_mediar"):
62
- # os.makedirs(temp_dir, exist_ok=True)
63
- # input_path = os.path.join(temp_dir, "input_image.tiff")
64
- # output_path = os.path.join(temp_dir, "input_image_label.tiff")
65
-
66
- # image.save(input_path)
67
-
68
- # model_args = {
69
- # "classes": 3,
70
- # "decoder_channels": [1024, 512, 256, 128, 64],
71
- # "decoder_pab_channels": 256,
72
- # "encoder_name": 'mit_b5',
73
- # "in_channels": 3
74
- # }
75
-
76
- # model = MEDIARFormer(**model_args)
77
- # weights = torch.load("from_phase1.pth", map_location="cpu")
78
- # model.load_state_dict(weights, strict=False)
79
- # model.eval()
80
-
81
- # predictor = Predictor(model, "cpu", temp_dir, temp_dir, algo_params={"use_tta": False})
82
- # predictor.img_names = ["input_image.tiff"]
83
- # _ = predictor.conduct_prediction()
84
-
85
- # pred = imread(output_path)
86
- # fig, ax = plt.subplots(figsize=(6, 6))
87
- # ax.imshow(pred, cmap="cividis")
88
- # ax.axis("off")
89
-
90
- # buf = io.BytesIO()
91
- # plt.savefig(buf, format="png")
92
- # plt.close()
93
- # buf.seek(0)
94
-
95
- # return image, Image.open(buf)
96
  # Handle Cellpose prediction
97
  def infer_cellpose(image, temp_dir="temp_cellpose"):
98
  os.makedirs(temp_dir, exist_ok=True)
@@ -128,8 +128,8 @@ def segment(model_name, image):
128
  return infer_segformer(image)
129
  elif model_name == "StarDist":
130
  return infer_stardist(image)
131
- # elif model_name == "MEDIAR":
132
- # return infer_mediar(image)
133
  elif model_name == "Cellpose":
134
  return infer_cellpose(image)
135
  else:
@@ -142,7 +142,7 @@ with gr.Blocks(title="Cell Segmentation Explorer") as app:
142
  with gr.Row():
143
  with gr.Column():
144
  model_dropdown = gr.Dropdown(
145
- choices=["SegFormer", "StarDist", "Cellpose"],
146
  label="Select Segmentation Model",
147
  value="SegFormer"
148
  )
 
13
  from csbdeep.utils import normalize
14
  from stardist.models import StarDist2D
15
  from stardist.plot import render_label
16
+ from MEDIARFormer import MEDIARFormer
17
+ from Predictor import Predictor
18
  from cellpose import models as cellpose_models, io as cellpose_io, plot as cellpose_plot
19
 
20
  # Load SegFormer
 
58
  return image, Image.fromarray(overlay)
59
 
60
  # Handle MEDIAR prediction
61
+ def infer_mediar(image, temp_dir="temp_mediar"):
62
+ os.makedirs(temp_dir, exist_ok=True)
63
+ input_path = os.path.join(temp_dir, "input_image.tiff")
64
+ output_path = os.path.join(temp_dir, "input_image_label.tiff")
65
+
66
+ image.save(input_path)
67
+
68
+ model_args = {
69
+ "classes": 3,
70
+ "decoder_channels": [1024, 512, 256, 128, 64],
71
+ "decoder_pab_channels": 256,
72
+ "encoder_name": 'mit_b5',
73
+ "in_channels": 3
74
+ }
75
+
76
+ model = MEDIARFormer(**model_args)
77
+ weights = torch.load("from_phase1.pth", map_location="cpu")
78
+ model.load_state_dict(weights, strict=False)
79
+ model.eval()
80
+
81
+ predictor = Predictor(model, "cpu", temp_dir, temp_dir, algo_params={"use_tta": False})
82
+ predictor.img_names = ["input_image.tiff"]
83
+ _ = predictor.conduct_prediction()
84
+
85
+ pred = imread(output_path)
86
+ fig, ax = plt.subplots(figsize=(6, 6))
87
+ ax.imshow(pred, cmap="cividis")
88
+ ax.axis("off")
89
+
90
+ buf = io.BytesIO()
91
+ plt.savefig(buf, format="png")
92
+ plt.close()
93
+ buf.seek(0)
94
+
95
+ return image, Image.open(buf)
96
  # Handle Cellpose prediction
97
  def infer_cellpose(image, temp_dir="temp_cellpose"):
98
  os.makedirs(temp_dir, exist_ok=True)
 
128
  return infer_segformer(image)
129
  elif model_name == "StarDist":
130
  return infer_stardist(image)
131
+ elif model_name == "MEDIAR":
132
+ return infer_mediar(image)
133
  elif model_name == "Cellpose":
134
  return infer_cellpose(image)
135
  else:
 
142
  with gr.Row():
143
  with gr.Column():
144
  model_dropdown = gr.Dropdown(
145
+ choices=["SegFormer", "StarDist", "MEDIAR", "Cellpose"],
146
  label="Select Segmentation Model",
147
  value="SegFormer"
148
  )