File size: 3,643 Bytes
29ac506
 
 
 
 
 
 
 
 
00ecfb9
 
8a9238b
29ac506
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import numpy as np
import time, os
import tifffile as tif

from datetime import datetime
from zipfile import ZipFile
from pytz import timezone


from transforms import get_pred_transforms



class BasePredictor:
    """Base class that runs a model over every file in a directory.

    For each entry in ``input_path`` the predictor loads the image through
    ``get_pred_transforms()``, calls ``_inference`` and ``_post_process``
    (both must be supplied by a subclass), and writes the resulting mask to
    ``output_path`` as a zlib-compressed TIFF. Optionally, all files in
    ``output_path`` are bundled into ``./submissions/<exp_name>.zip``.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        """Store the configuration and prepare the inference environment.

        Args:
            model: Object exposing ``to(device)`` and ``eval()``; it is passed
                to nothing else here, subclasses use it in ``_inference``.
            device: Target passed to ``model.to`` and ``Tensor.to``.
            input_path: Directory whose entries are treated as input images.
            output_path: Directory for predicted masks (created if missing).
            make_submission: When true, zip the outputs after prediction and
                print a warning for masks with too few labels.
            exp_name: Optional prefix for the experiment name; a timestamp is
                always appended (see ``_setups``).
            algo_params: Optional mapping whose items are set directly as
                instance attributes.
        """
        self.model = model
        self.device = device
        self.input_path = input_path
        self.output_path = output_path
        self.make_submission = make_submission
        self.exp_name = exp_name

        # Assign algorithm-specific arguments as instance attributes.
        if algo_params:
            self.__dict__.update(algo_params.items())

        # Prepare inference environments
        self._setups()

    @torch.no_grad()
    def conduct_prediction(self):
        """Predict a mask for every input image and optionally zip the results.

        Returns:
            The wall-clock time in seconds spent on the LAST processed image
            (inference + post-processing + writing; image loading is not
            included). Returns ``0.0`` when there are no input images.
        """
        self.model.to(self.device)
        self.model.eval()
        total_time = 0
        # Bound before the loop so an empty input directory does not raise
        # UnboundLocalError at the final ``return``.
        time_cost = 0.0

        for img_name in self.img_names:
            img_data = self._get_img_data(img_name)
            img_data = img_data.to(self.device)

            start = time.time()

            pred_mask = self._inference(img_data)
            # Drop the leading dimension added in ``_get_img_data`` and hand
            # a NumPy array to the post-processing step.
            pred_mask = self._post_process(pred_mask.squeeze(0).cpu().numpy())
            self.write_pred_mask(
                pred_mask, self.output_path, img_name, self.make_submission
            )

            end = time.time()

            time_cost = end - start
            total_time += time_cost
            print(
                f"Prediction finished: {img_name}; img size = {img_data.shape}; costing: {time_cost:.2f}s"
            )

        print(f"\n Total Time Cost: {total_time:.2f}s")

        if self.make_submission:
            fname = "%s.zip" % self.exp_name

            os.makedirs("./submissions", exist_ok=True)
            submission_path = os.path.join("./submissions", fname)

            with ZipFile(submission_path, "w") as zip_file:
                # Every entry of ``output_path`` is archived, not only the
                # files written during this run. No ``arcname`` is given, so
                # entries keep the ``output_path`` prefix inside the archive.
                for pred_name in sorted(os.listdir(self.output_path)):
                    zip_file.write(os.path.join(self.output_path, pred_name))

            print("\n>>>>> Submission file is saved at: %s\n" % submission_path)

        return time_cost

    def write_pred_mask(self, pred_mask, output_dir, image_name, submission=False):
        """Write ``pred_mask`` to ``<output_dir>/<stem>_label.tiff``.

        Args:
            pred_mask: Array handed to ``tifffile.imwrite``.
            output_dir: Destination directory.
            image_name: Source image file name; the text before its first
                ``"."`` becomes the output stem.
            submission: When true, print a caution if the largest value in
                ``pred_mask`` is below 5.
        """
        # All images should contain at least 5 cells.
        # NOTE(review): the maximum value is used as the cell count, which
        # presumes labels are consecutive integers starting at 1 - confirm.
        if submission:
            n_cells = np.max(pred_mask)
            # Exactly 5 satisfies "at least 5", so only warn below that.
            if n_cells < 5:
                print("[!Caution] Only %d Cells Detected!!!\n" % n_cells)

        # Keeps only the text before the FIRST dot, so names that differ
        # only after the first dot map to the same output file.
        file_name = image_name.split(".")[0]
        file_name = file_name + "_label.tiff"
        file_path = os.path.join(output_dir, file_name)

        tif.imwrite(file_path, pred_mask, compression="zlib")

    def _setups(self):
        """Create the output directory and derive run-level attributes.

        Sets ``pred_transforms``, ``img_names`` (sorted listing of
        ``input_path``) and appends an ``MMDD_HHMM`` timestamp in the
        Asia/Seoul time zone to ``exp_name``.
        """
        self.pred_transforms = get_pred_transforms()
        os.makedirs(self.output_path, exist_ok=True)

        now = datetime.now(timezone("Asia/Seoul"))
        dt_string = now.strftime("%m%d_%H%M")
        self.exp_name = (
            self.exp_name + dt_string if self.exp_name is not None else dt_string
        )

        self.img_names = sorted(os.listdir(self.input_path))

    def _get_img_data(self, img_name):
        """Load one image via ``pred_transforms`` and add a leading dimension.

        Args:
            img_name: File name relative to ``input_path``.

        Returns:
            The transformed image after ``unsqueeze(0)`` (presumably the
            batch dimension - confirm against the transforms pipeline).
        """
        img_path = os.path.join(self.input_path, img_name)
        img_data = self.pred_transforms(img_path)
        img_data = img_data.unsqueeze(0)

        return img_data

    def _inference(self, img_data):
        """Run the model on ``img_data``; must be implemented by subclasses."""
        raise NotImplementedError

    def _post_process(self, pred_mask):
        """Convert raw output to the final mask; must be implemented by subclasses."""
        raise NotImplementedError