import base64
import io
from pathlib import Path

import hydra
import litserve as lit
import rootutils
import torch
from dotenv import load_dotenv, find_dotenv
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from torchvision import transforms

# Load environment variables
load_dotenv(find_dotenv(".env"))

# Set up the project root before importing from `src`, so the root lands on
# sys.path and the package resolves without being installed.
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
logger.info(f"Root directory set to: {root}")

from src.models.catdog_model import ViTTinyClassifier  # noqa: E402
from src.utils.logging_utils import setup_logger  # noqa: E402


class ImageClassifierAPI(lit.LitAPI):
    def __init__(self, cfg: DictConfig):
        """
        Initialize the API with Hydra configuration.
        """
        super().__init__()
        self.cfg = cfg

        # Validate required config keys (check presence rather than truthiness,
        # so valid falsy values such as 0 are not flagged as missing)
        required_keys = ["ckpt_path", "data.image_size", "labels"]
        missing_keys = [
            key for key in required_keys if OmegaConf.select(cfg, key) is None
        ]
        if missing_keys:
            logger.error(f"Missing required config keys: {missing_keys}")
            raise ValueError(f"Missing required config keys: {missing_keys}")
        logger.info("Configuration validated.")

    def setup(self, device):
        """Initialize the model and necessary components."""
        self.device = device
        logger.info("Setting up the model and components.")

        # Log the configuration for debugging
        logger.debug(f"Configuration passed to setup: {OmegaConf.to_yaml(self.cfg)}")

        # Load the model from checkpoint
        try:
            self.model = ViTTinyClassifier.load_from_checkpoint(
                checkpoint_path=self.cfg.ckpt_path
            )
            self.model = self.model.to(device).eval()
            logger.info("Model loaded and moved to device.")
        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {self.cfg.ckpt_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # Define transforms
        self.transforms = transforms.Compose(
            [
                transforms.Resize((self.cfg.data.image_size, self.cfg.data.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],  # ImageNet channel means
                    std=[0.229, 0.224, 0.225],  # ImageNet channel stds
                ),
            ]
        )
        logger.info("Transforms initialized.")

        # Load labels
        try:
            self.labels = self.cfg.labels
            logger.info(f"Labels loaded: {self.labels}")
        except Exception as e:
            logger.error(f"Error loading labels: {e}")
            raise ValueError("Failed to load labels from the configuration.") from e

    def decode_request(self, request):
        """Extract the base64-encoded image from a single request."""
        if not isinstance(request, dict) or "image" not in request:
            logger.error(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
            raise ValueError(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
        return request["image"]
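
    # Example request body this endpoint accepts (illustrative):
    #
    #     {"image": "<base64-encoded JPEG/PNG bytes>"}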

    def batch(self, inputs):
        """Decode a list of base64-encoded images into one batched tensor."""
        if not isinstance(inputs, list):
            raise ValueError("Input to batch must be a list.")

        batch_tensors = []
        try:
            for image_bytes in inputs:
                if not isinstance(image_bytes, str):  # Ensure input is a base64 string
                    raise ValueError(
                        f"Input must be a base64-encoded string, got: {type(image_bytes)}"
                    )

                # Decode base64 string to bytes
                img_bytes = base64.b64decode(image_bytes)

                # Convert bytes to PIL Image
                try:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                except Exception as img_error:
                    logger.error(f"Failed to process image: {img_error}")
                    raise

                # Apply transforms and add to batch
                tensor = self.transforms(image)
                batch_tensors.append(tensor)

            return torch.stack(batch_tensors).to(self.device)
        except Exception as e:
            logger.error(f"Error decoding image: {e}")
            raise ValueError("Failed to decode and process the images.") from e

    def predict(self, x):
        """Make predictions on the input batch."""
        with torch.inference_mode():
            outputs = self.model(x)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        logger.info("Prediction completed.")
        return probabilities

    def unbatch(self, output):
        """Unbatch the output."""
        return [output[i] for i in range(output.size(0))]

    def encode_response(self, output):
        """Convert a single sample's probabilities to an API response."""
        try:
            probs, indices = torch.topk(output, k=1)
            responses = {
                "predictions": [
                    {
                        "label": self.labels[idx.item()],
                        "probability": prob.item(),
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            logger.info("Response successfully encoded.")
            return responses
        except Exception as e:
            logger.error(f"Error encoding response: {e}")
            raise ValueError("Failed to encode the response.") from e


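# A minimal sketch of the `infer` config this script assumes; field names are
# inferred from the cfg accesses above, and the real configs/infer.yaml may
# differ:
#
#     ckpt_path: checkpoints/best.ckpt
#     labels: ["cat", "dog"]
#     data:
#       image_size: 224
#     paths:
#       log_dir: logs
#     server:
#       accelerator: auto
#       devices: auto
#       workers_per_device: 1
#       max_batch_size: 8
#       batch_timeout: 0.05
#       port: 8000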
@hydra.main(config_path="../configs", config_name="infer", version_base="1.3")
def main(cfg: DictConfig):
    # Initialize loguru
    setup_logger(Path(cfg.paths.log_dir) / "infer.log")
    logger.info("Starting the Image Classifier API server.")

    # Log configuration
    logger.info(f"Configuration: {OmegaConf.to_yaml(cfg)}")

    # Create the API instance with the Hydra config
    api = ImageClassifierAPI(cfg)

    # Configure the server
    server = lit.LitServer(
        api,
        accelerator=cfg.server.accelerator,
        max_batch_size=cfg.server.max_batch_size,
        batch_timeout=cfg.server.batch_timeout,
        devices=cfg.server.devices,
        workers_per_device=cfg.server.workers_per_device,
    )
    server.run(port=cfg.server.port)


if __name__ == "__main__":
    main()
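

# Hypothetical client sketch (kept as a comment so this module stays a pure
# server entry point). It assumes LitServe's default /predict route and
# cfg.server.port == 8000; adjust both to the actual config:
#
#     import base64
#     import requests
#
#     with open("cat.jpg", "rb") as f:
#         payload = {"image": base64.b64encode(f.read()).decode("utf-8")}
#
#     resp = requests.post("http://localhost:8000/predict", json=payload)
#     print(resp.json())
#     # -> {"predictions": [{"label": "cat", "probability": 0.99}]}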