File size: 3,382 Bytes
161f7c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from io import BytesIO
from typing import Optional

import gradio as gr
import requests
import torch
import torchvision.models.detection as models
import torchvision.transforms as transforms
from PIL import Image as Im, ImageDraw
from torch import Tensor
from torch.nn import Module
from torchvision.transforms import ToTensor, Compose

# Selectable detectors: dropdown label -> torchvision model constructor.
# Keys are lowercase because ModelLoader.load_model lower-cases the requested
# name before looking it up in this mapping.
OBJECT_DETECTION_MODELS = dict(
    fasterrcnn_resnet50_fpn=models.fasterrcnn_resnet50_fpn,
    maskrcnn_resnet50_fpn=models.maskrcnn_resnet50_fpn,
)

class ModelLoader:
    """Build pretrained detection models by name and reuse them across calls.

    Constructing a detector and loading its pretrained weights is expensive, and
    callers may request a model once per inference. Each model is therefore
    built on first use and cached. The cache holds at most one instance per
    registry entry.
    """

    def __init__(self, model_dict: dict):
        # Lowercase model name -> constructor accepting a ``pretrained`` keyword.
        self.model_dict = model_dict
        # Lowercase model name -> model already built and switched to eval mode.
        self._cache: dict = {}

    def load_model(self, model_name: str) -> Module:
        """Return the pretrained model registered under *model_name*.

        The lookup is case-insensitive. The first call for a given name builds
        the model with ``pretrained=True`` and puts it in eval mode. Later calls
        return that same instance.

        Raises:
            ValueError: if *model_name* is not a string or is not a key of
                ``model_dict``.
        """
        # A UI widget with nothing selected hands us None; report it as an
        # unsupported model instead of crashing on ``None.lower()``.
        if not isinstance(model_name, str):
            raise ValueError(f"Model {model_name} is not supported")

        key = model_name.lower()
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        try:
            model_class = self.model_dict[key]
        except KeyError:
            raise ValueError(f"Model {model_name} is not supported") from None

        model = model_class(pretrained=True)
        model.eval()
        self._cache[key] = model
        return model

class Preprocessor:
    """Convert a PIL image into a one-image batch tensor for a detection model."""

    def __init__(self, transform: Optional[Compose] = None):
        """
        Args:
            transform: Callable applied to the PIL image before batching.
                Defaults to a fresh ``Compose([ToTensor()])`` for each instance,
                rather than a single object shared by every ``Preprocessor``
                through a default argument.
        """
        self.transform = transform if transform is not None else Compose([ToTensor()])

    def preprocess(self, image: Im) -> Tensor:
        """Apply ``self.transform`` to *image* and prepend a batch dimension of size 1."""
        return self.transform(image).unsqueeze(0)

class Postprocessor:
    """Draw the confident detections from a detection model onto an image."""

    def __init__(self, threshold: float = 0.5):
        # Only detections whose score is strictly greater than this are drawn.
        self.threshold = threshold

    def postprocess(self, image: Im, predictions: dict) -> Im:
        """Return a copy of *image* annotated with detection boxes and scores.

        Args:
            image: The image the predictions were computed on. It is not
                modified.
            predictions: One per-image output dict from the detector. Only its
                ``'boxes'`` and ``'scores'`` tensors are read. Boxes are drawn as
                ``[x1, y1, x2, y2]`` pixel rectangles; this presumes the
                torchvision detection output format (TODO confirm for other
                models).

        Returns:
            A new image with a red rectangle and a score label for every
            detection whose score exceeds ``self.threshold``.
        """
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)
        # Convert to plain Python floats once. PIL wants numbers rather than
        # 0-d tensors, and .tolist() also copies GPU results back to the host.
        boxes = predictions['boxes'].tolist()
        scores = predictions['scores'].tolist()
        for box, score in zip(boxes, scores):
            if score > self.threshold:
                draw.rectangle(box, outline="red", width=3)
                draw.text((box[0], box[1]), f"{score:.2f}", fill="red")
        return annotated

class ObjectDetection:
    """Run one detection request: load model, preprocess, infer, draw results."""

    def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor, postprocessor: Postprocessor):
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def detect(self, image: Im, selected_model: str) -> Im:
        """Detect objects in *image* using the model named *selected_model*.

        Inference runs on CUDA when it is available, otherwise on the CPU.

        Returns:
            The image produced by the postprocessor for the first (only) item
            of the batch.

        Raises:
            ValueError: if no image or no model name was supplied, or (raised
                by the model loader) if the model name is not supported.
        """
        # UI callbacks pass None for empty inputs; fail with a clear message
        # instead of an obscure error deep inside the transform or loader.
        if image is None:
            raise ValueError("No image provided")
        if not selected_model:
            raise ValueError("No model selected")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = self.model_loader.load_model(selected_model).to(device)
        input_tensor = self.preprocessor.preprocess(image).to(device)

        # Defensive: a custom loader might return a model in training mode.
        model.eval()
        with torch.no_grad():
            output = model(input_tensor)
        return self.postprocessor.postprocess(image, output[0])

class GradioApp:
    """Gradio front end: upload an image, pick a model, view the detections."""

    def __init__(self, object_detection: ObjectDetection):
        self.detector = object_detection

    def launch(self, **launch_kwargs):
        """Build the UI and start serving it (this call blocks while serving).

        Args:
            **launch_kwargs: Forwarded unchanged to ``demo.launch`` (for example
                ``share=True``). With no arguments, behaves as before.
        """
        model_choices = list(OBJECT_DETECTION_MODELS.keys())
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    # Preselect the first model so "Detect" never sends an
                    # empty (None) model selection to the detector.
                    self.model_dropdown = gr.Dropdown(
                        choices=model_choices,
                        value=model_choices[0] if model_choices else None,
                        label="Select Model",
                    )
                    detection_button = gr.Button("Detect")
                with gr.Column():
                    output = gr.Image(type='pil', label="Detection")
            detection_button.click(
                fn=self.detector.detect,
                inputs=[upload_image, self.model_dropdown],
                outputs=output,
            )
        demo.launch(**launch_kwargs)

# Wire the detection pipeline together, then start the web UI.
model_loader = ModelLoader(OBJECT_DETECTION_MODELS)
preprocessor, postprocessor = Preprocessor(), Postprocessor()
object_detection = ObjectDetection(
    model_loader,
    preprocessor,
    postprocessor,
)
app = GradioApp(object_detection)
app.launch()