import gradio as gr
import torch
from torchvision.transforms import transforms
import numpy as np
from typing import Optional
import torch.nn as nn
import os

from utils import page_utils
class BasicBlock(nn.Module):
    """ResNet Basic Block.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    stride : int, optional
        Convolution stride size, by default 1
    identity_downsample : Optional[torch.nn.Module], optional
        Downsampling layer, by default None
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 identity_downsample: Optional[torch.nn.Module] = None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=3,
                               stride=stride,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # Downsample the identity branch when needed so its shape
        # matches the conv2 output before the residual addition.
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        x += identity
        x = self.relu(x)
        return x
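# Illustrative sanity check (assumed usage, not part of the app): a stride-2
# BasicBlock paired with a matching downsample halves the spatial size while
# changing the channel count, e.g.
#
#   downsample = nn.Sequential(
#       nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
#       nn.BatchNorm2d(128))
#   block = BasicBlock(64, 128, stride=2, identity_downsample=downsample)
#   block(torch.randn(1, 64, 56, 56)).shape  # torch.Size([1, 128, 28, 28])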
class ResNet18(nn.Module):
    """Construct ResNet-18 Model.

    Parameters
    ----------
    input_channels : int
        Number of input channels
    num_classes : int
        Number of class outputs
    """

    def __init__(self, input_channels: int, num_classes: int):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels,
                               64,
                               kernel_size=7,
                               stride=2,
                               padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3,
                                    stride=2,
                                    padding=1)
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)
        # Classification head (global pooling + fully connected layer)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int) -> nn.Module:
        """Downsampling block to reduce the feature sizes."""
        return nn.Sequential(
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.BatchNorm2d(out_channels)
        )

    def _make_layer(self, in_channels: int, out_channels: int, stride: int) -> nn.Module:
        """Create a sequential pair of basic blocks."""
        identity_downsample = None
        # Only the first block of a stage changes the resolution and channel
        # count, so only it needs a downsampling branch.
        if stride != 1:
            identity_downsample = self.identity_downsample(in_channels, out_channels)
        return nn.Sequential(
            BasicBlock(in_channels, out_channels, identity_downsample=identity_downsample, stride=stride),
            BasicBlock(out_channels, out_channels)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward computation."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
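# Quick shape check (illustrative): with a 224x224 RGB input the network
# produces one logit per class, e.g.
#
#   net = ResNet18(3, 7)
#   net(torch.randn(1, 3, 224, 224)).shape  # torch.Size([1, 7])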
model = ResNet18(3, 7)

checkpoint = torch.load('ham10000.ckpt', map_location=torch.device('cpu'))

# The checkpoint's state dict keys are prefixed with `net.`
# (e.g. `net.conv1.weight`), while our model's keys are not,
# so strip the prefix before loading.
state_dict = checkpoint['state_dict']
for key in list(state_dict.keys()):
    if 'net.' in key:
        state_dict[key.replace('net.', '')] = state_dict.pop(key)

model.load_state_dict(state_dict)
model.eval()
class_names = {
    'akk': 'Actinic Keratosis',
    'bcc': 'Basal Cell Carcinoma',
    'bkl': 'Benign Keratosis',
    'df': 'Dermatofibroma',
    'mel': 'Melanoma',
    'nv': 'Melanocytic Nevi',
    'vasc': 'Vascular Lesion'
}

examples_dir = "sample"
transformation_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Preprocess the input image.

    Note that the input image is in RGB mode.

    Parameters
    ----------
    image : np.ndarray
        Input image from callback.

    Returns
    -------
    torch.Tensor
        A normalized image tensor with a leading batch dimension.
    """
    image = transformation_pipeline(image)
    image = torch.unsqueeze(image, 0)
    return image
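# Example (illustrative, with a hypothetical dummy array): the pipeline
# yields a normalized, batched tensor at the model's expected resolution.
#
#   dummy = np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8)
#   preprocess_image(dummy).shape  # torch.Size([1, 3, 224, 224])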
def image_classifier(inp: Optional[np.ndarray] = None) -> dict:
    """Image Classifier Function.

    Parameters
    ----------
    inp : Optional[np.ndarray], optional
        Input image from callback, by default None

    Returns
    -------
    dict
        A dictionary mapping class names to their probabilities
    """
    # If the input is not valid, return zero scores for every class
    if inp is None:
        return {name: 0.0 for name in class_names.values()}

    # preprocess
    image = preprocess_image(inp)
    image = image.to(dtype=torch.float32)

    # inference (no gradients needed at serving time)
    with torch.no_grad():
        result = model(image)

    # postprocess
    result = torch.nn.functional.softmax(result, dim=1)  # apply softmax
    scores = result[0].numpy().tolist()  # take the first batch
    # class_names is assumed to be ordered to match the model's output logits
    labeled_result = {name: score for name, score in zip(class_names.values(), scores)}

    return labeled_result
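# Usage sketch (assumed): calling the function with an RGB ndarray returns
# one probability per class; the output shown is a placeholder, not real data.
#
#   dummy = np.random.randint(0, 255, (450, 600, 3), dtype=np.uint8)
#   image_classifier(dummy)  # {'Actinic Keratosis': ..., 'Melanoma': ..., ...}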
# Load the HTML description shown at the top of the page
with open('index.html', encoding="utf-8") as f:
    description = f.read()
# gradio code block for input and output
with gr.Blocks(theme=gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR).set(
            button_primary_background_fill="*primary_600",
            button_primary_background_fill_hover="*primary_500",
            button_primary_text_color="white",
        )) as app:
    with gr.Column():
        gr.HTML(description)

    with gr.Row():
        with gr.Column():
            inp_img = gr.Image()
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                process_btn = gr.Button(value="Process", variant="primary")
        with gr.Column():
            out_txt = gr.Label(label="Probabilities", num_top_classes=3)

    process_btn.click(image_classifier, inputs=inp_img, outputs=out_txt)
    clear_btn.click(lambda: (gr.update(value=None),
                             gr.update(value=None)),
                    inputs=None,
                    outputs=[inp_img, out_txt])

    gr.Markdown("## Image Examples")
    gr.Examples(
        examples=[os.path.join(examples_dir, "nv.jpeg"),
                  os.path.join(examples_dir, "bcc.jpeg"),
                  os.path.join(examples_dir, "bkl_1.jpeg"),
                  os.path.join(examples_dir, "akk.jpeg"),
                  os.path.join(examples_dir, "mel-_3_.jpeg"),
                  ],
        inputs=inp_img,
        outputs=out_txt,
        fn=image_classifier,
        cache_examples=False,
    )
    gr.Markdown(line_breaks=True, value='Author: M HAIKAL FEBRIAN P ([email protected]) <div class="row"><a href="https://github.com/HAikalfebrianp96?tab=repositories"><img alt="GitHub" src="https://img.shields.io/badge/haikal%20phona-000000?logo=github"></a></div>')
# demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
app.launch(share=True)