Spaces:
Running
Running
File size: 3,643 Bytes
29ac506 00ecfb9 8a9238b 29ac506 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import torch
import numpy as np
import time, os
import tifffile as tif
from datetime import datetime
from zipfile import ZipFile
from pytz import timezone
from transforms import get_pred_transforms
class BasePredictor:
    """Run a model over every file of an input directory and save label masks.

    This is a base class: subclasses must implement ``_inference`` and
    ``_post_process``; both raise ``NotImplementedError`` here.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        """Store the configuration and prepare the inference environment.

        Args:
            model: Object moved to ``device`` and switched to eval mode when
                ``conduct_prediction`` runs.
            device: Target passed to ``.to(...)`` for the model and the inputs.
            input_path: Directory whose entries are treated as input images.
            output_path: Directory the masks are written to (created if
                missing).
            make_submission: When true, predictions are also zipped into
                ``./submissions`` and a low-cell-count warning is enabled.
            exp_name: Optional prefix of the experiment name; a timestamp is
                appended to it in ``_setups``.
            algo_params: Optional mapping whose items are copied onto the
                instance as attributes.
        """
        self.model = model
        self.device = device
        self.input_path = input_path
        self.output_path = output_path
        self.make_submission = make_submission
        self.exp_name = exp_name

        # Assign algorithm-specific arguments as instance attributes.
        if algo_params:
            self.__dict__.update(algo_params.items())

        # Prepare inference environments.
        self._setups()

    @torch.no_grad()
    def conduct_prediction(self):
        """Predict a mask for every input image and write it to disk.

        The per-image timer starts after the image has been loaded and moved
        to the device, and stops after the mask has been written.

        Returns:
            The time in seconds spent on the last processed image, or ``0``
            when there are no input images.
            NOTE(review): the accumulated ``total_time`` is only printed, not
            returned - confirm with callers whether the total was intended.
        """
        self.model.to(self.device)
        self.model.eval()

        total_time = 0
        # Bound before the loop so that an empty input directory returns 0
        # instead of raising UnboundLocalError at the final ``return``.
        time_cost = 0

        for img_name in self.img_names:
            img_data = self._get_img_data(img_name)
            img_data = img_data.to(self.device)

            start = time.time()

            pred_mask = self._inference(img_data)
            pred_mask = self._post_process(pred_mask.squeeze(0).cpu().numpy())
            self.write_pred_mask(
                pred_mask, self.output_path, img_name, self.make_submission
            )

            time_cost = time.time() - start
            total_time += time_cost

            print(
                f"Prediction finished: {img_name}; img size = {img_data.shape}; costing: {time_cost:.2f}s"
            )

        print(f"\n Total Time Cost: {total_time:.2f}s")

        if self.make_submission:
            os.makedirs("./submissions", exist_ok=True)
            submission_path = os.path.join("./submissions", f"{self.exp_name}.zip")

            with ZipFile(submission_path, "w") as archive:
                # No arcname is passed, so each member is stored under the
                # path it was read from (output_path joined with the name).
                for pred_name in sorted(os.listdir(self.output_path)):
                    archive.write(os.path.join(self.output_path, pred_name))

            print(f"\n>>>>> Submission file is saved at: {submission_path}\n")

        return time_cost

    @staticmethod
    def _label_file_name(image_name):
        """Return ``<stem>_label.tiff`` where only the last extension is dropped."""
        # splitext keeps inner dots ("a.1.tif" -> "a.1"), so images that only
        # differ after the first dot no longer map to the same output file.
        stem, _ext = os.path.splitext(image_name)
        return stem + "_label.tiff"

    def write_pred_mask(self, pred_mask, output_dir, image_name, submission=False):
        """Write ``pred_mask`` as a zlib-compressed TIFF into ``output_dir``.

        Args:
            pred_mask: Array passed to ``np.max`` and ``tifffile.imwrite``.
            output_dir: Directory that receives the file.
            image_name: Input file name; the output name is derived from it.
            submission: When true, print a warning if the largest value in
                the mask is 5 or lower.
        """
        # All images should contain at least 5 cells.
        # NOTE(review): the check below also warns when the maximum is exactly
        # 5 - confirm whether ``>= 5`` was intended.
        if submission and not (np.max(pred_mask) > 5):
            print("[!Caution] Only %d Cells Detected!!!\n" % np.max(pred_mask))

        file_path = os.path.join(output_dir, self._label_file_name(image_name))
        tif.imwrite(file_path, pred_mask, compression="zlib")

    def _setups(self):
        """Create the output directory, stamp the experiment name, list inputs."""
        self.pred_transforms = get_pred_transforms()
        os.makedirs(self.output_path, exist_ok=True)

        # Timestamp (month, day, hour, minute) in the Asia/Seoul timezone.
        now = datetime.now(timezone("Asia/Seoul"))
        dt_string = now.strftime("%m%d_%H%M")

        # The timestamp is appended directly, without a separator.
        self.exp_name = (
            self.exp_name + dt_string if self.exp_name is not None else dt_string
        )

        # Every directory entry is treated as an input image, in sorted order.
        self.img_names = sorted(os.listdir(self.input_path))

    def _get_img_data(self, img_name):
        """Load one input through the prediction transforms.

        Returns:
            The transform output with a leading dimension of size 1 added.
        """
        img_path = os.path.join(self.input_path, img_name)
        img_data = self.pred_transforms(img_path)
        return img_data.unsqueeze(0)

    def _inference(self, img_data):
        """Run the model on ``img_data``; must be implemented by subclasses."""
        raise NotImplementedError

    def _post_process(self, pred_mask):
        """Convert raw model output into a mask; must be implemented by subclasses."""
        raise NotImplementedError
|