import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import Compose
import requests
import random
import gradio as gr

# Predefined ImageNet classifiers available in torchvision, keyed by the
# lower-case name used for lookup (the UI displays these capitalized).
image_prediction_models = {
    'resnet': models.resnet50,
    'alexnet': models.alexnet,
    'vgg': models.vgg16,
    'squeezenet': models.squeezenet1_0,
    'densenet': models.densenet161,
    'inception': models.inception_v3,
    'googlenet': models.googlenet,
    'shufflenet': models.shufflenet_v2_x1_0,
    'mobilenet': models.mobilenet_v2,
    'resnext': models.resnext50_32x4d,
    'wide_resnet': models.wide_resnet50_2,
    'mnasnet': models.mnasnet1_0,
    'efficientnet': models.efficientnet_b0,
    'regnet': models.regnet_y_400mf,
    'vit': models.vit_b_16,
    'convnext': models.convnext_tiny,
}

# Text file with one class label per line, indexed by model output position.
IMAGENET_LABELS_URL = "https://git.io/JJkYN"


class ModelLoader:
    """Instantiate (and cache) pretrained torchvision classifiers by name."""

    def __init__(self, model_dict):
        """
        Args:
            model_dict: mapping from lower-case model name to a torchvision
                model builder function (e.g. ``models.resnet50``).
        """
        self.model_dict = model_dict
        # Loaded models keyed by lower-case name, so weights are loaded once
        # rather than on every classification request.
        self._cache = {}

    @staticmethod
    def _build_pretrained(model_fn):
        """Build *model_fn* with its original ImageNet-1k (V1) weights."""
        try:
            # torchvision >= 0.13 multi-weight API. "IMAGENET1K_V1" is the
            # checkpoint the deprecated ``pretrained=True`` flag selected.
            return model_fn(weights="IMAGENET1K_V1")
        except TypeError:
            # Older torchvision builders do not accept ``weights``.
            return model_fn(pretrained=True)

    def load_model(self, model_name):
        """Return the pretrained model registered under *model_name*.

        The lookup is case-insensitive. Repeated calls with the same name
        return the same (cached) model instance.

        Raises:
            ValueError: if *model_name* is not a registered model.
        """
        model_name_lower = str(model_name).lower()
        if model_name_lower not in self.model_dict:
            raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")
        if model_name_lower not in self._cache:
            self._cache[model_name_lower] = self._build_pretrained(self.model_dict[model_name_lower])
        return self._cache[model_name_lower]

    def get_model_names(self):
        """Return the registered model names, capitalized for display."""
        return [name.capitalize() for name in self.model_dict.keys()]


class Preprocessor:
    """Build the image transform expected by the pretrained classifiers."""

    def __init__(self):
        # ImageNet per-channel mean/std used when the weights were trained.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def preprocess(self, model_name):
        """Return a transform mapping a PIL image to a normalized tensor.

        Inception v3 uses 299x299 center crops; every other model uses
        224x224. The shorter side is first resized to crop / 0.875 (256 for
        224, 342 for 299) so the center crop never exceeds the image and
        does not get padded. *model_name* is matched case-insensitively,
        because the UI passes capitalized names such as ``"Inception"``.
        """
        is_inception = str(model_name).lower() == 'inception'
        crop_size = 299 if is_inception else 224
        resize_size = 342 if is_inception else 256
        return Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            self.normalize,
        ])


class Postprocessor:
    """Turn raw model output into a ``{label: probability}`` mapping."""

    def __init__(self, labels, top_k=5):
        """
        Args:
            labels: class labels indexed by model output position.
            top_k: number of highest-probability classes to report.
        """
        self.labels = labels
        self.top_k = top_k

    @staticmethod
    def _logits(output):
        """Extract the main logits tensor from a model's forward output."""
        # Inception v3 / GoogLeNet may return namedtuples exposing ``.logits``
        # (e.g. in training mode). In eval mode they return a plain tensor.
        logits = getattr(output, 'logits', output)
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        return logits

    def _top_confidences(self, output):
        """Softmax the first batch item and return its top-k label scores."""
        probabilities = torch.nn.functional.softmax(self._logits(output)[0], dim=0)
        k = min(self.top_k, probabilities.numel())
        top_prob, top_catid = torch.topk(probabilities, k)
        return {self.labels[cat]: prob
                for prob, cat in zip(top_prob.tolist(), top_catid.tolist())}

    def postprocess_default(self, output):
        """Return top-k ``{label: probability}`` for a standard classifier."""
        return self._top_confidences(output)

    def postprocess_inception(self, output):
        """Return top-k ``{label: probability}`` for Inception v3 output.

        Previously this indexed ``output[1]``, which raises ``IndexError`` for
        the (1, num_classes) tensor Inception returns in eval mode. It now
        extracts the main logits and reads batch row 0, like the default path.
        """
        return self._top_confidences(output)


class ImageClassifier:
    """Classify a PIL image with a user-selected pretrained model."""

    def __init__(self, model_loader, preprocessor, postprocessor):
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def classify(self, input_image, selected_model):
        """Return the top-k ``{label: probability}`` for *input_image*.

        Raises:
            ValueError: if no image or model is given, or the model is unknown.
        """
        if input_image is None:
            raise ValueError("Please provide an image to classify.")
        if not selected_model:
            raise ValueError("Please select a model.")

        model = self.model_loader.load_model(selected_model)
        model.to(self.device)
        model.eval()

        transform = self.preprocessor.preprocess(model_name=selected_model)
        # Force 3 channels: RGBA, grayscale or palette uploads would otherwise
        # yield tensors the 3-channel Normalize and the model cannot accept.
        if hasattr(input_image, 'convert'):
            input_image = input_image.convert('RGB')
        input_batch = transform(input_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = model(input_batch)

        if str(selected_model).lower() == 'inception':
            return self.postprocessor.postprocess_inception(output)
        return self.postprocessor.postprocess_default(output)


class CIFAR10ImageProvider:
    """Serve random images from the CIFAR-10 test split."""

    def __init__(self, dataset_root='./data'):
        self.dataset_root = dataset_root
        self._dataset = None  # loaded lazily on first request

    def _get_dataset(self):
        """Load (downloading if needed) the CIFAR-10 test split once."""
        if self._dataset is None:
            # With no transform the dataset yields PIL images directly, which
            # avoids a PIL -> tensor -> PIL round trip.
            self._dataset = datasets.CIFAR10(root=self.dataset_root, train=False, download=True)
        return self._dataset

    def get_random_image(self):
        """Return a uniformly random CIFAR-10 test image as a PIL image."""
        dataset = self._get_dataset()
        image, _ = dataset[random.randrange(len(dataset))]
        return image


class GradioApp:
    """Gradio UI with an upload tab and a random-CIFAR-10 tab."""

    def __init__(self, image_classifier, image_provider, model_list):
        self.image_classifier = image_classifier
        self.image_provider = image_provider
        self.model_list = model_list

    def _classify(self, image, model_name):
        """Run the classifier and show input problems to the user."""
        try:
            return self.image_classifier.classify(image, model_name)
        except ValueError as err:
            raise gr.Error(str(err)) from err

    def launch(self):
        """Build the Blocks layout and start the Gradio server."""
        default_model = self.model_list[0] if self.model_list else None
        with gr.Blocks() as demo:
            with gr.Tabs():
                with gr.TabItem("Upload Image"):
                    with gr.Row():
                        with gr.Column():
                            upload_image = gr.Image(type='pil', label="Upload Image")
                            model_dropdown_upload = gr.Dropdown(self.model_list, value=default_model, label="Select Model")
                            classify_button_upload = gr.Button("Classify")
                        with gr.Column():
                            output_label_upload = gr.Label(num_top_classes=5)
                    classify_button_upload.click(self._classify,
                                                 inputs=[upload_image, model_dropdown_upload],
                                                 outputs=output_label_upload)
                with gr.TabItem("Generate Random Image"):
                    with gr.Row():
                        with gr.Column():
                            generate_button = gr.Button("Generate Random Image")
                            random_image_output = gr.Image(type='pil', label="Random CIFAR-10 Image")
                        with gr.Column():
                            model_dropdown_random = gr.Dropdown(self.model_list, value=default_model, label="Select Model")
                            classify_button_random = gr.Button("Classify")
                            output_label_random = gr.Label(num_top_classes=5)
                    generate_button.click(self.image_provider.get_random_image,
                                          inputs=[], outputs=random_image_output)
                    classify_button_random.click(self._classify,
                                                 inputs=[random_image_output, model_dropdown_random],
                                                 outputs=output_label_random)
        demo.launch()


def load_imagenet_labels(url=IMAGENET_LABELS_URL, timeout=10):
    """Download the class-label list (one label per line) from *url*.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the server does not answer within *timeout* s.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return [line.strip() for line in response.text.splitlines()]


def main():
    """Wire the components together and launch the Gradio app."""
    model_loader = ModelLoader(image_prediction_models)
    preprocessor = Preprocessor()
    postprocessor = Postprocessor(load_imagenet_labels())
    image_classifier = ImageClassifier(model_loader, preprocessor, postprocessor)
    image_provider = CIFAR10ImageProvider()
    model_list = model_loader.get_model_names()

    app = GradioApp(image_classifier, image_provider, model_list)
    app.launch()


if __name__ == "__main__":
    main()