import base64
import io
from pathlib import Path

import hydra
import litserve as lit
import rootutils
import torch
from dotenv import find_dotenv, load_dotenv
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from torchvision import transforms

from src.models.catdog_model import ViTTinyClassifier
from src.utils.logging_utils import setup_logger

# Load environment variables
load_dotenv(find_dotenv(".env"))

# Set up the project root directory
root = rootutils.setup_root(__file__, indicator=".project-root")
logger.info(f"Root directory set to: {root}")


class ImageClassifierAPI(lit.LitAPI):
    def __init__(self, cfg: DictConfig):
        """Initialize the API with a Hydra configuration."""
        super().__init__()
        self.cfg = cfg

        # Validate required config keys. OmegaConf.select returns None for a
        # missing key, so compare against None rather than relying on
        # truthiness (which would also reject legitimate falsy values).
        required_keys = ["ckpt_path", "data.image_size", "labels"]
        missing_keys = [
            key for key in required_keys if OmegaConf.select(cfg, key) is None
        ]
        if missing_keys:
            logger.error(f"Missing required config keys: {missing_keys}")
            raise ValueError(f"Missing required config keys: {missing_keys}")
        logger.info(f"Configuration validated: {OmegaConf.to_yaml(cfg)}")

    def setup(self, device):
        """Initialize the model and necessary components."""
        self.device = device
        logger.info("Setting up the model and components.")
        # Log the configuration for debugging
        logger.debug(f"Configuration passed to setup: {OmegaConf.to_yaml(self.cfg)}")

        # Load the model from the checkpoint
        try:
            self.model = ViTTinyClassifier.load_from_checkpoint(
                checkpoint_path=self.cfg.ckpt_path
            )
            self.model = self.model.to(device).eval()
            logger.info("Model loaded and moved to device.")
        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {self.cfg.ckpt_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # Define preprocessing transforms (ImageNet mean/std)
        self.transforms = transforms.Compose(
            [
                transforms.Resize(
                    (self.cfg.data.image_size, self.cfg.data.image_size)
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        logger.info("Transforms initialized.")

        # Load class labels from the configuration
        try:
            self.labels = self.cfg.labels
            logger.info(f"Labels loaded: {self.labels}")
        except Exception as e:
            logger.error(f"Error loading labels: {e}")
            raise ValueError("Failed to load labels from the configuration.") from e

    def decode_request(self, request):
        """Extract the base64-encoded image from a single request."""
        if not isinstance(request, dict) or "image" not in request:
            logger.error(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
            raise ValueError(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
        return request["image"]

    def batch(self, inputs):
        """Decode a list of base64 strings into a single batched tensor."""
        if not isinstance(inputs, list):
            raise ValueError("Input to batch must be a list.")
        batch_tensors = []
        try:
            for image_bytes in inputs:
                # Ensure the input is a base64-encoded string
                if not isinstance(image_bytes, str):
                    raise ValueError(
                        f"Input must be a base64-encoded string, got: {type(image_bytes)}"
                    )
                # Decode the base64 string to raw bytes
                img_bytes = base64.b64decode(image_bytes)
                # Convert the bytes to a PIL image
                try:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                except Exception as img_error:
                    logger.error(f"Failed to process image: {img_error}")
                    raise
                # Apply transforms and add to the batch
                tensor = self.transforms(image)
                batch_tensors.append(tensor)
            return torch.stack(batch_tensors).to(self.device)
        except Exception as e:
            logger.error(f"Error decoding image: {e}")
            raise ValueError("Failed to decode and process the images.") from e

    def predict(self, x):
        """Run the model on the input batch and return class probabilities."""
        with torch.inference_mode():
            outputs = self.model(x)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        logger.info("Prediction completed.")
        return probabilities
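
    # predict returns a tensor of shape [batch_size, num_classes];
    # unbatch splits it back into one probability vector per request.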

    def unbatch(self, output):
        """Split the batched output into per-request results."""
        return [output[i] for i in range(output.size(0))]

    def encode_response(self, output):
        """Convert a single probability vector into an API response."""
        try:
            probs, indices = torch.topk(output, k=1)
            responses = {
                "predictions": [
                    {
                        "label": self.labels[idx.item()],
                        "probability": prob.item(),
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            logger.info("Response successfully encoded.")
            return responses
        except Exception as e:
            logger.error(f"Error encoding response: {e}")
            raise ValueError("Failed to encode the response.") from e


# NOTE: without a @hydra.main decorator, the bare main() call below would
# fail, since cfg is never injected. The config_path and config_name here
# are assumptions; adjust them to match the project's Hydra config layout.
@hydra.main(version_base="1.3", config_path="../configs", config_name="infer")
def main(cfg: DictConfig):
    # Initialize loguru
    setup_logger(Path(cfg.paths.log_dir) / "infer.log")
    logger.info("Starting the Image Classifier API server.")

    # Log the resolved configuration
    logger.info(f"Configuration: {OmegaConf.to_yaml(cfg)}")

    # Create the API instance with the Hydra config
    api = ImageClassifierAPI(cfg)

    # Configure and start the server
    server = lit.LitServer(
        api,
        accelerator=cfg.server.accelerator,
        max_batch_size=cfg.server.max_batch_size,
        batch_timeout=cfg.server.batch_timeout,
        devices=cfg.server.devices,
        workers_per_device=cfg.server.workers_per_device,
    )
    server.run(port=cfg.server.port)


if __name__ == "__main__":
    main()
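

# Example client (hypothetical sketch, kept as comments since server.run()
# blocks). LitServe exposes a POST /predict endpoint by default, and
# decode_request above expects a JSON body of the form
# {"image": "<base64 string>"}. The file name and port are illustrative.
#
#   import base64
#   import requests
#
#   with open("sample.jpg", "rb") as f:
#       payload = {"image": base64.b64encode(f.read()).decode("utf-8")}
#   resp = requests.post("http://localhost:8000/predict", json=payload)
#   print(resp.json())  # e.g. {"predictions": [{"label": "cat", "probability": 0.98}]}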