File size: 4,838 Bytes
2c85ae1
 
2ee95b8
 
 
2c85ae1
2ee95b8
 
2c85ae1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ee95b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c85ae1
2ee95b8
 
 
 
2c85ae1
2ee95b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c85ae1
2ee95b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import Compose
import requests
import random
import gradio as gr

# Registry of selectable architectures: lowercase lookup key -> torchvision
# constructor.  Insertion order is the order the names appear in the UI
# dropdowns, because the display list is derived from these keys.
image_prediction_models = {
    "resnet": models.resnet50,
    "alexnet": models.alexnet,
    "vgg": models.vgg16,
    "squeezenet": models.squeezenet1_0,
    "densenet": models.densenet161,
    "inception": models.inception_v3,
    "googlenet": models.googlenet,
    "shufflenet": models.shufflenet_v2_x1_0,
    "mobilenet": models.mobilenet_v2,
    "resnext": models.resnext50_32x4d,
    "wide_resnet": models.wide_resnet50_2,
    "mnasnet": models.mnasnet1_0,
    "efficientnet": models.efficientnet_b0,
    "regnet": models.regnet_y_400mf,
    "vit": models.vit_b_16,
    "convnext": models.convnext_tiny,
}

def load_pretrained_model(model_name):
    """Construct the registered torchvision model with pretrained weights.

    Args:
        model_name: Key of ``image_prediction_models``. Matching is
            case-insensitive, so capitalised display names such as
            ``"Resnet"`` resolve to the ``"resnet"`` entry.

    Returns:
        A newly constructed model with its ImageNet-1k (V1) weights loaded.
        The model is NOT switched to eval mode here; callers do that.

    Raises:
        ValueError: If ``model_name`` is not a string or does not name a
            registered model.
    """
    # A non-string (e.g. None from an empty dropdown) is reported as an
    # unknown model instead of failing with AttributeError on ``.lower()``.
    key = model_name.lower() if isinstance(model_name, str) else None
    model_class = image_prediction_models.get(key)
    if model_class is None:
        raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")

    try:
        # ``pretrained=True`` is deprecated since torchvision 0.13; the named
        # weights below are the checkpoint that legacy flag resolves to.
        return model_class(weights="IMAGENET1K_V1")
    except TypeError:
        # torchvision < 0.13 builders do not accept ``weights``.
        return model_class(pretrained=True)
    
def get_model_names(models_dict):
    """Return the keys of ``models_dict`` as display names.

    Every key is passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased); the mapping's iteration order is kept.
    """
    display_names = [key.capitalize() for key in models_dict]
    return display_names

# Capitalised model names shown in the model-selection dropdowns.
model_list = get_model_names(models_dict=image_prediction_models)

# Per-channel mean / std used to normalise input tensors; these are the values
# torchvision documents for its ImageNet-pretrained classification models.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

def preprocess(model_name):
    """Build the image-to-tensor transform pipeline for ``model_name``.

    Args:
        model_name: Name of the selected model. Compared case-insensitively,
            so both the registry key ``"inception"`` and the capitalised
            dropdown label ``"Inception"`` select the Inception pipeline.
            Anything else (including a non-string) gets the default pipeline.

    Returns:
        A ``transforms.Compose`` that resizes, center-crops, converts to a
        tensor and applies the module-level ``normalize`` transform.
        Inception uses a 342 px resize and 299 px crop; every other model
        uses a 256 px resize and 224 px crop.
    """
    wants_inception = isinstance(model_name, str) and model_name.lower() == 'inception'
    # The resize must stay larger than the crop: CenterCrop pads with zeros
    # when the image is smaller than the crop, so cropping 299 px out of a
    # 256 px resize would feed the model black borders. 342/299 keeps about
    # the same ratio as 256/224.
    resize_size, crop_size = (342, 299) if wants_inception else (256, 224)
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

# Class-index -> label table (one label per line), fetched once at import.
# The timeout keeps start-up from hanging forever on an unresponsive host, and
# raise_for_status() prevents an HTTP error page from being used as labels.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
labels = response.text.split("\n")

def postprocess_default(output, class_names=None, top_k=5):
    """Turn raw class scores into a ``{label: probability}`` mapping.

    Only the first sample of the batch (``output[0]``) is scored.

    Args:
        output: Batch of raw scores; ``output[0]`` must be a 1-D tensor with
            one score per class.
        class_names: Sequence mapping a class index to its label. Defaults to
            the module-level ``labels`` table.
        top_k: Maximum number of classes to report. Clamped to the number of
            classes so ``torch.topk`` cannot fail on small outputs.

    Returns:
        A dict ordered from most to least probable class, mapping each label
        to its softmax probability as a Python float.
    """
    if class_names is None:
        class_names = labels
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    k = min(top_k, probabilities.size(0))
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        class_names[index]: prob
        for prob, index in zip(top_prob.tolist(), top_catid.tolist())
    }

def postprocess_inception(output, class_names=None, top_k=5):
    """Turn Inception output into a ``{label: probability}`` mapping.

    Accepts both shapes the model can produce: a plain ``(batch, classes)``
    tensor (eval mode) or a tuple whose first element is that tensor (the
    train-mode ``(logits, aux_logits)`` pair). Only the first sample of the
    batch is scored.

    Args:
        output: Logits tensor, or a tuple with the main logits first.
        class_names: Sequence mapping a class index to its label. Defaults to
            the module-level ``labels`` table.
        top_k: Maximum number of classes to report, clamped to the number of
            classes.

    Returns:
        A dict ordered from most to least probable class, mapping each label
        to its softmax probability as a Python float.
    """
    if class_names is None:
        class_names = labels
    # Indexing a (1, classes) eval-mode tensor with [1] raises IndexError, and
    # on a train-mode tuple [1] would select the auxiliary head; always use
    # the main logits and then the first sample.
    logits = output[0] if isinstance(output, tuple) else output
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    k = min(top_k, probabilities.size(0))
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        class_names[index]: prob
        for prob, index in zip(top_prob.tolist(), top_catid.tolist())
    }

def classify_image(input_image, selected_model):
    """Classify one image with the selected pretrained model.

    Args:
        input_image: Image to classify; the UI supplies a PIL image.
        selected_model: Model name as shown in the dropdown (any casing).

    Returns:
        The ``{label: probability}`` dict produced by the matching
        post-processing function.

    Raises:
        ValueError: If no image or no model was supplied, or the model name
            is not registered (raised by ``load_pretrained_model``).
    """
    # Fail with a readable message instead of an AttributeError/TypeError deep
    # inside the pipeline when the user clicks "Classify" too early.
    if input_image is None:
        raise ValueError("No image provided. Upload or generate an image first.")
    if not selected_model:
        raise ValueError("No model selected. Choose a model from the dropdown.")

    # Lower-case once so preprocessing and post-processing agree on the name;
    # the dropdown supplies capitalised labels such as "Inception".
    model_key = selected_model.lower()

    # The normalisation stats are 3-channel, so grayscale / RGBA PIL images
    # are converted first. Inputs without a ``mode`` attribute pass through.
    if getattr(input_image, 'mode', 'RGB') != 'RGB':
        input_image = input_image.convert('RGB')

    preprocess_input = preprocess(model_name=model_key)
    input_batch = preprocess_input(input_image).unsqueeze(0)  # add batch dim
    model = load_pretrained_model(selected_model)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_batch = input_batch.to(device)
    model = model.to(device)

    model.eval()
    with torch.no_grad():
        output = model(input_batch)

    if model_key == 'inception':
        return postprocess_inception(output)
    return postprocess_default(output)

# CIFAR-10 test split, created on first use by get_random_image() and reused
# afterwards so the archive is not re-verified and re-parsed on every click.
_cifar10_test_set = None


def get_random_image():
    """Return a uniformly random image from the CIFAR-10 test split.

    The dataset is downloaded to ``./data`` if missing and loaded once per
    process. No transform is attached, so the dataset yields PIL images
    directly and no tensor round trip is needed.

    Returns:
        A PIL image.
    """
    global _cifar10_test_set
    if _cifar10_test_set is None:
        _cifar10_test_set = datasets.CIFAR10(root='./data', train=False, download=True)
    random_idx = random.randrange(len(_cifar10_test_set))
    image, _ = _cifar10_test_set[random_idx]
    return image

def generate_random_image():
    """Click handler: return whatever ``get_random_image`` produces."""
    return get_random_image()

def classify_generated_image(image, model):
    """Click handler: forward ``image`` and ``model`` to ``classify_image``."""
    confidences = classify_image(image, model)
    return confidences

# UI definition: two tabs, each pairing an image source with a model dropdown,
# a "Classify" button and a label output. Components are created inside the
# nested context managers, so statement order here determines the layout.
with gr.Blocks() as demo:
    with gr.Tabs():
        # Tab 1: classify an image supplied by the user.
        with gr.TabItem("Upload Image"):
            with gr.Row():
                with gr.Column():
                    # type='pil' makes the handler receive a PIL image.
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    model_dropdown_upload = gr.Dropdown(model_list, label="Select Model")
                    classify_button_upload = gr.Button("Classify")
                with gr.Column():
                    output_label_upload = gr.Label(num_top_classes=5)
            # Button click -> classify_image(image, model name) -> label widget.
            classify_button_upload.click(classify_image, inputs=[upload_image, model_dropdown_upload], outputs=output_label_upload)
        
        # Tab 2: fetch a random CIFAR-10 image, then classify it.
        with gr.TabItem("Generate Random Image"):
            with gr.Row():
                with gr.Column():
                    generate_button = gr.Button("Generate Random Image")
                    random_image_output = gr.Image(type='pil', label="Random CIFAR-10 Image")
                with gr.Column():
                    model_dropdown_random = gr.Dropdown(model_list, label="Select Model")
                    classify_button_random = gr.Button("Classify")
                    output_label_random = gr.Label(num_top_classes=5)
            # The generate handler takes no inputs and fills the image widget;
            # the classify handler then reads that same widget as its input.
            generate_button.click(generate_random_image, inputs=[], outputs=random_image_output)
            classify_button_random.click(classify_generated_image, inputs=[random_image_output, model_dropdown_random], outputs=output_label_random)

# Starts the Gradio server as soon as this module is executed.
demo.launch()