Spaces:
Running
Running
import torch | |
import numpy as np | |
import time, os | |
import tifffile as tif | |
from datetime import datetime | |
from zipfile import ZipFile | |
from pytz import timezone | |
from transforms import get_pred_transforms | |
class BasePredictor: | |
def __init__( | |
self, | |
model, | |
device, | |
input_path, | |
output_path, | |
make_submission=False, | |
exp_name=None, | |
algo_params=None, | |
): | |
self.model = model | |
self.device = device | |
self.input_path = input_path | |
self.output_path = output_path | |
self.make_submission = make_submission | |
self.exp_name = exp_name | |
# Assign algoritm-specific arguments | |
if algo_params: | |
self.__dict__.update((k, v) for k, v in algo_params.items()) | |
# Prepare inference environments | |
self._setups() | |
def conduct_prediction(self): | |
self.model.to(self.device) | |
self.model.eval() | |
total_time = 0 | |
total_times = [] | |
for img_name in self.img_names: | |
img_data = self._get_img_data(img_name) | |
img_data = img_data.to(self.device) | |
start = time.time() | |
pred_mask = self._inference(img_data) | |
pred_mask = self._post_process(pred_mask.squeeze(0).cpu().numpy()) | |
self.write_pred_mask( | |
pred_mask, self.output_path, img_name, self.make_submission | |
) | |
end = time.time() | |
time_cost = end - start | |
total_times.append(time_cost) | |
total_time += time_cost | |
print( | |
f"Prediction finished: {img_name}; img size = {img_data.shape}; costing: {time_cost:.2f}s" | |
) | |
print(f"\n Total Time Cost: {total_time:.2f}s") | |
if self.make_submission: | |
fname = "%s.zip" % self.exp_name | |
os.makedirs("./submissions", exist_ok=True) | |
submission_path = os.path.join("./submissions", fname) | |
with ZipFile(submission_path, "w") as zipObj2: | |
pred_names = sorted(os.listdir(self.output_path)) | |
for pred_name in pred_names: | |
pred_path = os.path.join(self.output_path, pred_name) | |
zipObj2.write(pred_path) | |
print("\n>>>>> Submission file is saved at: %s\n" % submission_path) | |
return time_cost | |
def write_pred_mask(self, pred_mask, output_dir, image_name, submission=False): | |
# All images should contain at least 5 cells | |
if submission: | |
if not (np.max(pred_mask) > 5): | |
print("[!Caution] Only %d Cells Detected!!!\n" % np.max(pred_mask)) | |
file_name = image_name.split(".")[0] | |
file_name = file_name + "_label.tiff" | |
file_path = os.path.join(output_dir, file_name) | |
tif.imwrite(file_path, pred_mask, compression="zlib") | |
def _setups(self): | |
self.pred_transforms = get_pred_transforms() | |
os.makedirs(self.output_path, exist_ok=True) | |
now = datetime.now(timezone("Asia/Seoul")) | |
dt_string = now.strftime("%m%d_%H%M") | |
self.exp_name = ( | |
self.exp_name + dt_string if self.exp_name is not None else dt_string | |
) | |
self.img_names = sorted(os.listdir(self.input_path)) | |
def _get_img_data(self, img_name): | |
img_path = os.path.join(self.input_path, img_name) | |
img_data = self.pred_transforms(img_path) | |
img_data = img_data.unsqueeze(0) | |
return img_data | |
def _inference(self, img_data): | |
raise NotImplementedError | |
def _post_process(self, pred_mask): | |
raise NotImplementedError | |