Spaces:
Build error
Build error
| import torch | |
| from torch import Tensor as T | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| import torchvision.datasets as datasets | |
| from torchvision.transforms import Compose | |
| from torch.nn import Module | |
| from torch.nn.functional import softmax | |
| import requests | |
| from PIL import Image | |
| import random | |
| from gradio import Blocks, Tabs, TabItem, Row, Column, Image, Dropdown, Button, Label | |
# Predefined models available in torchvision.
# Keys are the lowercase names that ModelLoader looks up (it lowercases the
# requested name before the lookup); values are the torchvision constructor
# functions, which are only called when a model is actually requested.
# Insertion order is the order in which the names are listed to the user.
IMAGE_PREDICTION_MODELS = {
    'resnet': models.resnet50,
    'alexnet': models.alexnet,
    'vgg': models.vgg16,
    'squeezenet': models.squeezenet1_0,
    'densenet': models.densenet161,
    'inception': models.inception_v3,
    'googlenet': models.googlenet,
    'shufflenet': models.shufflenet_v2_x1_0,
    'mobilenet': models.mobilenet_v2,
    'resnext': models.resnext50_32x4d,
    'wide_resnet': models.wide_resnet50_2,
    'mnasnet': models.mnasnet1_0,
    'efficientnet': models.efficientnet_b0,
    'regnet': models.regnet_y_400mf,
    'vit': models.vit_b_16,
    'convnext': models.convnext_tiny
}
# Load a pretrained model from torchvision
class ModelLoader:
    """Build pretrained models by name from a registry of constructors.

    Args:
        model_dict: Maps a model name to a callable that accepts
            ``pretrained=True`` and returns the constructed model.
        max_cached: Number of constructed models kept for reuse (most
            recently used first). ``0`` disables caching. The default of 1
            avoids rebuilding the network when the same model is requested
            repeatedly without holding every architecture in memory.
    """

    def __init__(self, model_dict : dict, max_cached : int = 1):
        self.model_dict = model_dict
        self.max_cached = max_cached
        # Resolved registry key -> constructed model; dict insertion order
        # doubles as least-recently-used order (oldest first).
        self._cache = {}

    def _resolve_key(self, model_name):
        """Return the registry key matching ``model_name`` case-insensitively, or None."""
        if not isinstance(model_name, str):
            return None
        wanted = model_name.lower()
        if wanted in self.model_dict:
            return wanted
        # get_model_names() advertises key.capitalize(), so a registry key
        # that is not all-lowercase must still be reachable from that name.
        for key in self.model_dict:
            if isinstance(key, str) and key.lower() == wanted:
                return key
        return None

    def load_model(self, model_name : str) -> Module :
        """Return the pretrained model registered under ``model_name``.

        The same instance is returned for repeated requests while it is
        still in the cache.

        Raises:
            ValueError: If ``model_name`` does not match any registry key.
        """
        key = self._resolve_key(model_name)
        if key is None:
            raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")
        if key in self._cache:
            # Re-insert so this entry becomes the most recently used one.
            model = self._cache.pop(key)
            self._cache[key] = model
            return model
        # NOTE(review): recent torchvision releases deprecate `pretrained=`
        # in favour of `weights=`; kept as-is so the loaded weights do not change.
        model = self.model_dict[key](pretrained=True)
        if self.max_cached > 0:
            self._cache[key] = model
            while len(self._cache) > self.max_cached:
                # Evict the least recently used entry.
                self._cache.pop(next(iter(self._cache)))
        return model

    def get_model_names(self) -> list:
        """Return the registry keys, capitalized, in registry order."""
        return [name.capitalize() for name in self.model_dict]
# Preprocessor: Prepares image for model input
class Preprocessor:
    """Build the image transform pipeline matching a model's input size."""

    def __init__(self):
        # Per-channel mean/std applied after ToTensor (presumably the usual
        # ImageNet statistics -- confirm against the weights in use).
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def preprocess(self, model_name : str) -> Compose:
        """Return resize -> center-crop -> tensor -> normalize for ``model_name``.

        The name is compared case-insensitively, so the capitalized names
        shown in the UI (e.g. ``"Inception"``) select the right input size.
        """
        name = model_name.lower() if isinstance(model_name, str) else ''
        if name == 'inception':
            # Inception v3 takes 299x299 inputs. The resize must be larger
            # than the crop, otherwise CenterCrop pads the image with zeros;
            # 342 is the resize torchvision documents for this model.
            resize_size, input_size = 342, 299
        else:
            resize_size, input_size = 256, 224
        return transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            self.normalize,
        ])
# Postprocessor: Processes model output
class Postprocessor:
    """Turn raw model output into a ``{label: probability}`` mapping.

    Args:
        labels: Class names indexed by class id.
    """

    def __init__(self, labels : list):
        self.labels = labels

    def postprocess_default(self, output) -> dict:
        """Return the top-5 (or fewer) labels with softmax probabilities.

        ``output`` is a batch of logits; only the first batch element is used.
        """
        return self._top_confidences(self._first_sample_logits(output))

    def postprocess_inception(self, output) -> dict:
        """Same as :meth:`postprocess_default`, tolerant of Inception output.

        In eval mode the model returns a plain logits tensor, so indexing
        ``output[1]`` would fail for a batch of one. When a
        ``(logits, aux_logits)`` tuple is received, the primary logits
        (first element) are used.
        """
        return self._top_confidences(self._first_sample_logits(output))

    @staticmethod
    def _first_sample_logits(output) -> T:
        """Pick the primary logits of the first batch element."""
        if isinstance(output, tuple):
            # Tuple outputs carry the main logits first.
            output = output[0]
        return output[0]

    def _top_confidences(self, logits : T, k : int = 5) -> dict:
        """Map the ``k`` most probable class labels to their probabilities."""
        probabilities : T = softmax(logits, dim=0)
        # topk raises if k exceeds the number of classes.
        k = min(k, probabilities.size(0))
        top_prob, top_catid = torch.topk(probabilities, k)
        return {
            self.labels[class_id]: prob
            for prob, class_id in zip(top_prob.tolist(), top_catid.tolist())
        }
# ImageClassifier: Classifies images using a selected model
class ImageClassifier:
    """Run an image through a named model and return label confidences.

    Args:
        model_loader: Provides ``load_model(name)``.
        preprocessor: Provides ``preprocess(model_name=...)`` returning a
            callable that turns an image into a tensor.
        postprocessor: Provides ``postprocess_default(output)``.
    """

    def __init__(self, model_loader : ModelLoader, preprocessor: Preprocessor, postprocessor : Postprocessor):
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def classify(self, input_image : Image, selected_model : str) -> dict:
        """Classify ``input_image`` with ``selected_model``.

        Returns:
            The postprocessor's ``{label: probability}`` mapping.

        Raises:
            ValueError: If no image or no model name was supplied, or the
                model name is unknown to the loader.
        """
        # The UI passes None when nothing was uploaded / selected; fail with
        # a clear message instead of an error deep inside the transforms.
        if input_image is None:
            raise ValueError("No image provided: upload or generate an image before classifying.")
        if not selected_model:
            raise ValueError("No model selected: choose a model before classifying.")

        # The UI shows capitalized names; the preprocessor's size selection
        # compares against the lowercase form.
        model_key = selected_model.lower()

        # Normalization uses three channel statistics, so grayscale or RGBA
        # input has to be converted first.
        if getattr(input_image, 'mode', 'RGB') != 'RGB':
            input_image = input_image.convert('RGB')

        preprocess_input : Compose = self.preprocessor.preprocess(model_name=model_key)
        input_tensor : T = preprocess_input(input_image)
        input_batch = input_tensor.unsqueeze(0)  # add the batch dimension

        model = self.model_loader.load_model(selected_model)
        if torch.cuda.is_available():
            input_batch = input_batch.to('cuda')
            model.to('cuda')
        model.eval()
        with torch.no_grad():
            output = model(input_batch)

        # In eval mode the models return plain logits, including Inception,
        # so indexing an auxiliary output would fail. Unwrap defensively in
        # case a model still returns a (logits, aux) tuple.
        if isinstance(output, tuple):
            output = output[0]
        return self.postprocessor.postprocess_default(output)
# CIFAR10ImageProvider: Provides random images from CIFAR-10 dataset
class CIFAR10ImageProvider:
    """Serve random images from the CIFAR-10 test split.

    Args:
        dataset_root: Directory where the dataset is stored / downloaded.
        transform: Transform applied by the dataset to each image; its
            output must be accepted by ``transforms.ToPILImage``.
    """

    def __init__(self, dataset_root='./data', transform = transforms.ToTensor()):
        self.dataset_root = dataset_root
        self.transform = transform
        # Built on first use and then reused, so the download check and
        # dataset load do not run again on every request.
        self._dataset = None

    def get_random_image(self, resize_dim=(256, 256)) -> Image:
        """Return one random test image resized to ``resize_dim``.

        The dataset is created (and downloaded if missing) on the first
        call, using the root and transform set at that time.
        """
        if self._dataset is None:
            self._dataset = datasets.CIFAR10(
                root=self.dataset_root,
                train=False,
                download=True,
                transform=self.transform,
            )
        dataset = self._dataset
        image, _ = dataset[random.randrange(len(dataset))]
        # The dataset transform yields a tensor by default; convert back to
        # a PIL image so it can be resized and displayed.
        image = transforms.ToPILImage()(image)
        return image.resize(resize_dim)
# Interface
class GradioApp:
    """Two-tab Gradio front end: classify an uploaded or a random image.

    Args:
        image_classifier: Object whose ``classify(image, model_name)`` is
            wired to both "Classify" buttons.
        image_provider: Object whose ``get_random_image()`` is wired to the
            "Generate Random Image" button.
        model_list: Model names offered in both dropdowns.
    """

    def __init__(self, image_classifier : ImageClassifier, image_provider : CIFAR10ImageProvider, model_list : list):
        self.image_classifier = image_classifier
        self.image_provider = image_provider
        self.model_list = model_list

    def launch(self):
        """Build the interface and start the Gradio server."""
        classify = self.image_classifier.classify
        make_random = self.image_provider.get_random_image
        choices = self.model_list

        with Blocks() as ui:
            with Tabs():
                # Tab 1: the user supplies the image.
                with TabItem("Upload Image"):
                    with Row():
                        with Column():
                            uploaded = Image(type='pil', label="Upload Image")
                            upload_model = Dropdown(choices, label="Select Model")
                            upload_go = Button("Classify")
                        with Column():
                            upload_result = Label(num_top_classes=5)
                    upload_go.click(
                        classify,
                        inputs=[uploaded, upload_model],
                        outputs=upload_result,
                    )

                # Tab 2: the image comes from the provider.
                with TabItem("Generate Random Image"):
                    with Row():
                        with Column():
                            random_go = Button("Generate Random Image")
                            random_picture = Image(type='pil', label="Random CIFAR-10 Image")
                        with Column():
                            random_model = Dropdown(choices, label="Select Model")
                            random_classify = Button("Classify")
                            random_result = Label(num_top_classes=5)
                    random_go.click(make_random, inputs=[], outputs=random_picture)
                    random_classify.click(
                        classify,
                        inputs=[random_picture, random_model],
                        outputs=random_result,
                    )

        ui.launch()
# Main
if __name__ == "__main__":
    # Initialize
    model_loader = ModelLoader(IMAGE_PREDICTION_MODELS)
    preprocessor = Preprocessor()

    # Fetch the label file: line i names class i (the postprocessor indexes
    # this list by predicted class id). A timeout keeps startup from hanging
    # forever, and raise_for_status stops an HTTP error page from being
    # silently used as the label list.
    response = requests.get("https://git.io/JJkYN", timeout=30)
    response.raise_for_status()
    # splitlines() handles \r\n and does not leave a trailing empty label.
    labels = response.text.splitlines()

    postprocessor = Postprocessor(labels)
    image_classifier = ImageClassifier(model_loader, preprocessor, postprocessor)
    image_provider = CIFAR10ImageProvider()
    model_list = model_loader.get_model_names()

    # Launch
    app = GradioApp(image_classifier, image_provider, model_list)
    app.launch()