from __future__ import annotations

from io import BytesIO

import gradio as gr
import requests
import torch
import torchvision.models.detection as models
import torchvision.transforms as transforms
from PIL import Image as Im, ImageDraw
from torch import Tensor
from torch.nn import Module
from torchvision.transforms import ToTensor, Compose
# Registry of supported detectors: lowercase model name -> torchvision factory
# callable that builds the network.
OBJECT_DETECTION_MODELS = {
    "fasterrcnn_resnet50_fpn": models.fasterrcnn_resnet50_fpn,
    "maskrcnn_resnet50_fpn": models.maskrcnn_resnet50_fpn,
}
class ModelLoader:
    """Build detection models by name from a registry of factory callables."""

    def __init__(self, model_dict: dict):
        """Store the registry of model factories.

        Args:
            model_dict: Maps lowercase model names to callables that are
                invoked as ``factory(pretrained=True)`` and return a
                ``torch.nn.Module``.
        """
        self.model_dict = model_dict
        # Models already built, keyed by lowercase name, so that repeated
        # requests reuse the network instead of constructing it again.
        self._loaded: dict[str, Module] = {}

    def load_model(self, model_name: str) -> Module:
        """Return the model registered under ``model_name``, in eval mode.

        The lookup is case-insensitive. The first request for a name builds
        the model; later requests return the same instance.

        Args:
            model_name: Name of a model present in the registry.

        Returns:
            The model, switched to evaluation mode.

        Raises:
            ValueError: If ``model_name`` is not a string or is not a key of
                the registry.
        """
        if not isinstance(model_name, str):
            # e.g. ``None`` when no model was selected in the UI.
            raise ValueError(f"Model {model_name} is not supported")

        key = model_name.lower()
        try:
            factory = self.model_dict[key]
        except KeyError:
            raise ValueError(f"Model {model_name} is not supported") from None

        model = self._loaded.get(key)
        if model is None:
            # NOTE(review): recent torchvision releases document a `weights=`
            # argument in place of `pretrained=` -- confirm against the
            # installed version before switching.
            model = factory(pretrained=True)
            self._loaded[key] = model

        # Re-assert eval mode on every call in case a caller toggled it.
        model.eval()
        return model
class Preprocessor:
    """Turn an input image into a batched tensor for the model."""

    def __init__(self, transform: Compose | None = None):
        """Store the transform applied to each image.

        Args:
            transform: Callable applied to the image; its result must support
                ``unsqueeze``. When omitted, ``Compose([ToTensor()])`` is
                built for this instance (not shared between instances).
        """
        if transform is None:
            transform = Compose([ToTensor()])
        self.transform = transform

    def preprocess(self, image: Im.Image) -> Tensor:
        """Apply the transform and prepend a batch dimension of size one.

        Args:
            image: Image to convert.

        Returns:
            The transformed image with a new leading dimension.
        """
        tensor = self.transform(image)
        return tensor.unsqueeze(0)
class Postprocessor:
    """Draw detections whose score exceeds a threshold onto an image."""

    def __init__(self, threshold: float = 0.5):
        """Store the score threshold.

        Args:
            threshold: Detections are drawn only when their score is strictly
                greater than this value.
        """
        self.threshold = threshold

    def postprocess(self, image: Im.Image, predictions: dict) -> Im.Image:
        """Draw a red box and score label for each confident detection.

        The drawing is done in place on ``image``; the same object is
        returned.

        Args:
            image: Image to annotate.
            predictions: Mapping with ``'boxes'`` and ``'scores'`` entries
                that are iterated in lockstep; each box must provide
                ``tolist()``.

        Returns:
            ``image``, with the detections drawn on it.
        """
        draw = ImageDraw.Draw(image)
        for box, score in zip(predictions['boxes'], predictions['scores']):
            if not (score > self.threshold):
                continue
            # Convert once to plain Python numbers so Pillow receives floats
            # rather than tensor elements.
            corners = box.tolist()
            draw.rectangle(corners, outline="red", width=3)
            draw.text((corners[0], corners[1]), f"{float(score):.2f}", fill="red")
        return image
class ObjectDetection:
    """Run the load -> preprocess -> infer -> postprocess pipeline."""

    def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor, postprocessor: Postprocessor):
        """Store the three pipeline stages.

        Args:
            model_loader: Provides ``load_model(name)``.
            preprocessor: Provides ``preprocess(image)``.
            postprocessor: Provides ``postprocess(image, prediction)``.
        """
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def detect(self, image: Im.Image, selected_model: str) -> Im.Image:
        """Run the selected model on ``image`` and return the annotated result.

        Args:
            image: Image to analyse.
            selected_model: Name passed to the model loader.

        Returns:
            Whatever the postprocessor returns for ``image`` and the first
            element of the model output.

        Raises:
            ValueError: If ``image`` is ``None``; errors from the model
                loader propagate unchanged.
        """
        if image is None:
            # Fail before the (potentially expensive) model load.
            raise ValueError("No image provided")

        model = self.model_loader.load_model(selected_model)
        batch = self.preprocessor.preprocess(image)

        # Use the GPU when one is available; otherwise stay on the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        batch = batch.to(device)

        model.eval()
        with torch.no_grad():
            output = model(batch)

        # Only one image was submitted, so only the first result is used.
        return self.postprocessor.postprocess(image, output[0])
class GradioApp:
    """Gradio user interface around an ``ObjectDetection`` pipeline."""

    def __init__(self, object_detection: ObjectDetection):
        """Store the detector whose ``detect`` method backs the button.

        Args:
            object_detection: Pipeline used to process uploaded images.
        """
        self.detector = object_detection

    def launch(self):
        """Build the two-column UI and start the Gradio server."""
        model_names = list(OBJECT_DETECTION_MODELS.keys())
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    # Preselect the first model so that clicking "Detect"
                    # without touching the dropdown still sends a valid name.
                    self.model_dropdown = gr.Dropdown(
                        choices=model_names,
                        value=model_names[0] if model_names else None,
                        label="Select Model",
                    )
                    detection_button = gr.Button("Detect")
                with gr.Column():
                    output = gr.Image(type='pil', label="Detection")
            detection_button.click(
                fn=self.detector.detect,
                inputs=[upload_image, self.model_dropdown],
                outputs=output,
            )
        demo.launch()
# Wire the pipeline stages together and start the UI.
model_loader = ModelLoader(OBJECT_DETECTION_MODELS)
preprocessor = Preprocessor()
postprocessor = Postprocessor()
object_detection = ObjectDetection(model_loader, preprocessor, postprocessor)
app = GradioApp(object_detection)
app.launch()