Cell-Segmentation / BasePredictor.py
saim1309's picture
Update BasePredictor.py
00ecfb9 verified
import torch
import numpy as np
import time, os
import tifffile as tif
from datetime import datetime
from zipfile import ZipFile
from pytz import timezone
from transforms import get_pred_transforms
class BasePredictor:
    """Base driver for running a model over every entry of a directory.

    Subclasses implement ``_inference`` (forward pass on the batched tensor
    returned by ``_get_img_data``) and ``_post_process`` (turn the raw
    prediction into a label mask). ``conduct_prediction`` runs the loop:
    load each entry of ``input_path``, predict, post-process, write a
    ``*_label.tiff`` mask into ``output_path``. It can optionally zip the
    outputs for submission.
    """

    # Submission-mode warning threshold used by ``write_pred_mask``; can be
    # overridden per instance through ``algo_params``.
    min_expected_cells = 5

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        """Store the run configuration and prepare the I/O environment.

        Args:
            model: object exposing ``.to()`` and ``.eval()`` (e.g. a
                ``torch.nn.Module``). ``conduct_prediction`` calls both.
            device: target passed to ``.to()`` for the model and inputs.
            input_path: directory whose entries are predicted in sorted order.
            output_path: directory for predicted masks; created if missing.
            make_submission: if True, warn on low cell counts and zip the
                outputs into ``./submissions/<exp_name>.zip``.
            exp_name: optional prefix; a ``%m%d_%H%M`` timestamp is appended.
            algo_params: optional mapping whose items become instance
                attributes (algorithm-specific settings for subclasses).
        """
        self.model = model
        self.device = device
        self.input_path = input_path
        self.output_path = output_path
        self.make_submission = make_submission
        self.exp_name = exp_name

        # Expose algorithm-specific settings as plain instance attributes.
        if algo_params:
            self.__dict__.update(algo_params.items())

        # Prepare inference environment (transforms, dirs, image list).
        self._setups()

    @torch.no_grad()
    def conduct_prediction(self):
        """Predict every image in ``input_path`` and save the label masks.

        The per-image time covers inference, post-processing and writing the
        mask. Loading the image and moving it to ``device`` are excluded.

        Returns:
            float: seconds spent on the *last* image processed, or ``0.0``
            when there were no images.
        """
        self.model.to(self.device)
        self.model.eval()

        # Bound before the loop so an empty input directory returns 0.0
        # instead of raising UnboundLocalError at the ``return``.
        time_cost = 0.0
        total_time = 0.0

        for img_name in self.img_names:
            img_data = self._get_img_data(img_name).to(self.device)

            start = time.perf_counter()
            pred_mask = self._inference(img_data)
            # squeeze(0) drops the leading size-1 (batch) axis before the
            # prediction is handed to _post_process as a numpy array.
            pred_mask = self._post_process(pred_mask.squeeze(0).cpu().numpy())
            self.write_pred_mask(
                pred_mask, self.output_path, img_name, self.make_submission
            )
            time_cost = time.perf_counter() - start
            total_time += time_cost

            print(
                f"Prediction finished: {img_name}; img size = {img_data.shape}; costing: {time_cost:.2f}s"
            )

        print(f"\n Total Time Cost: {total_time:.2f}s")

        if self.make_submission:
            submission_path = self._write_submission_zip()
            print("\n>>>>> Submission file is saved at: %s\n" % submission_path)

        return time_cost

    def _write_submission_zip(self):
        """Zip every file currently in ``output_path`` into ``./submissions``.

        Returns:
            str: path of the written ``<exp_name>.zip`` archive.
        """
        os.makedirs("./submissions", exist_ok=True)
        submission_path = os.path.join("./submissions", f"{self.exp_name}.zip")
        with ZipFile(submission_path, "w") as zip_file:
            for pred_name in sorted(os.listdir(self.output_path)):
                # Default arcname is the joined path, so archive entries keep
                # the output_path prefix (leading separators are stripped).
                zip_file.write(os.path.join(self.output_path, pred_name))
        return submission_path

    def write_pred_mask(self, pred_mask, output_dir, image_name, submission=False):
        """Save ``pred_mask`` as ``<stem>_label.tiff`` with zlib compression.

        ``<stem>`` is everything before the first ``.`` in ``image_name``.

        In submission mode, a warning is printed when the maximum label is
        below ``min_expected_cells``. The maximum label is used as the cell
        count, which assumes labels are consecutive integers starting at 1
        (TODO: confirm for every ``_post_process``).
        """
        if submission:
            # All images should contain at least `min_expected_cells` cells.
            # np.max raises on an empty array, so treat that as zero cells.
            n_cells = np.max(pred_mask) if np.size(pred_mask) else 0
            if n_cells < self.min_expected_cells:
                print("[!Caution] Only %d Cells Detected!!!\n" % n_cells)

        file_name = image_name.split(".")[0] + "_label.tiff"
        file_path = os.path.join(output_dir, file_name)
        tif.imwrite(file_path, pred_mask, compression="zlib")

    def _setups(self):
        """Build transforms, create ``output_path``, timestamp ``exp_name``,
        and list the input images."""
        self.pred_transforms = get_pred_transforms()
        os.makedirs(self.output_path, exist_ok=True)

        # The timestamp uses a hard-coded Asia/Seoul timezone.
        now = datetime.now(timezone("Asia/Seoul"))
        dt_string = now.strftime("%m%d_%H%M")
        self.exp_name = (
            self.exp_name + dt_string if self.exp_name is not None else dt_string
        )

        # Every directory entry is treated as an input image, in sorted order.
        self.img_names = sorted(os.listdir(self.input_path))

    def _get_img_data(self, img_name):
        """Load ``img_name`` from ``input_path`` through ``pred_transforms``
        and prepend a size-1 batch axis."""
        img_path = os.path.join(self.input_path, img_name)
        img_data = self.pred_transforms(img_path)
        return img_data.unsqueeze(0)

    def _inference(self, img_data):
        """Run the model on a batched input tensor; implemented by subclasses."""
        raise NotImplementedError

    def _post_process(self, pred_mask):
        """Convert a raw numpy prediction into a label mask; implemented by
        subclasses."""
        raise NotImplementedError