"""Gradio demo that runs torchvision object-detection models on an uploaded image."""

from io import BytesIO
from typing import Optional

import gradio as gr
import requests
import torch
import torchvision.models.detection as models
import torchvision.transforms as transforms
from PIL import Image as Im, ImageDraw
from torch import Tensor
from torch.nn import Module
from torchvision.transforms import ToTensor, Compose

# Maps the name shown in the UI to the torchvision factory that builds the model.
OBJECT_DETECTION_MODELS = {
    "fasterrcnn_resnet50_fpn": models.fasterrcnn_resnet50_fpn,
    "maskrcnn_resnet50_fpn": models.maskrcnn_resnet50_fpn,
}


class ModelLoader:
    """Build detection models by name and keep them for reuse."""

    def __init__(self, model_dict: dict):
        """Store the name -> factory mapping.

        Args:
            model_dict: Maps lower-case model names to callables that accept
                ``pretrained=True`` and return a model.
        """
        self.model_dict = model_dict
        # Built models keyed by lower-cased name, so repeated requests do not
        # rebuild the network each time.
        self._cache: dict = {}

    def load_model(self, model_name: str) -> Module:
        """Return the model registered under ``model_name`` in eval mode.

        The lookup is case-insensitive. The same model instance is returned
        on subsequent calls with the same name.

        Args:
            model_name: Name of a model present in ``model_dict``.

        Returns:
            The model, switched to evaluation mode.

        Raises:
            ValueError: If ``model_name`` is not a string or is not registered.
        """
        if not isinstance(model_name, str):
            # e.g. None when nothing was selected in the UI.
            raise ValueError(f"Model {model_name} is not supported")

        key = model_name.lower()
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        try:
            factory = self.model_dict[key]
        except KeyError:
            raise ValueError(f"Model {model_name} is not supported") from None

        # NOTE(review): newer torchvision releases prefer `weights=` over
        # `pretrained=`; kept as-is so older installs keep working.
        model = factory(pretrained=True)
        model.eval()
        self._cache[key] = model
        return model


class Preprocessor:
    """Turn a PIL image into a batched tensor for the model."""

    def __init__(self, transform: Optional[Compose] = None):
        """Store the transform pipeline.

        Args:
            transform: Pipeline applied to the image. Defaults to a fresh
                ``Compose([ToTensor()])`` per instance.
        """
        # Built here rather than in the signature so instances never share
        # one default pipeline object.
        self.transform = Compose([ToTensor()]) if transform is None else transform

    def preprocess(self, image: Im.Image) -> Tensor:
        """Apply the transform and add a leading batch dimension."""
        return self.transform(image).unsqueeze(0)


class Postprocessor:
    """Draw detections above a score threshold onto an image."""

    def __init__(self, threshold: float = 0.5):
        """Store the minimum score a detection needs to be drawn."""
        self.threshold = threshold

    def postprocess(self, image: Im.Image, predictions: dict) -> Im.Image:
        """Draw a red box and score label for each confident detection.

        Args:
            image: Image to draw on; it is modified in place.
            predictions: Mapping with ``'boxes'`` and ``'scores'`` entries that
                are iterated in lockstep. Each box supplies the coordinates
                passed to ``ImageDraw.rectangle``.

        Returns:
            The same ``image`` object, with the detections drawn on it.
        """
        draw = ImageDraw.Draw(image)
        for box, score in zip(predictions['boxes'], predictions['scores']):
            confidence = float(score)
            if confidence <= self.threshold:
                continue
            # Plain floats: the drawing calls should not receive tensor
            # scalars as coordinates.
            coords = [float(value) for value in box]
            draw.rectangle(coords, outline="red", width=3)
            draw.text((coords[0], coords[1]), f"{confidence:.2f}", fill="red")
        return image


class ObjectDetection:
    """Tie together model loading, preprocessing, inference and drawing."""

    def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor, postprocessor: Postprocessor):
        """Store the three collaborators used by :meth:`detect`."""
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def detect(self, image: Im.Image, selected_model: str) -> Im.Image:
        """Run the selected model on ``image`` and return an annotated copy.

        Args:
            image: Input PIL image. It is not modified.
            selected_model: Name of the model to run.

        Returns:
            An RGB copy of ``image`` with the detections drawn on it.

        Raises:
            ValueError: If ``image`` is ``None`` or the model name is not
                supported.
        """
        if image is None:
            raise ValueError("No image provided")

        model = self.model_loader.load_model(selected_model)

        # convert() returns a copy, so the caller's image stays untouched, and
        # RGB guarantees three channels for the model and a valid "red" colour
        # when drawing (uploads may be RGBA, grayscale or palette images).
        rgb_image = image.convert("RGB")
        input_tensor = self.preprocessor.preprocess(rgb_image)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        input_tensor = input_tensor.to(device)

        # Cheap safeguard in case a custom loader returns a model in train mode.
        model.eval()
        with torch.no_grad():
            output = model(input_tensor)

        # output[0]: predictions for the single image in the batch.
        return self.postprocessor.postprocess(rgb_image, output[0])


class GradioApp:
    """Gradio front end for an :class:`ObjectDetection` instance."""

    def __init__(self, object_detection: ObjectDetection):
        """Store the detector that handles button clicks."""
        self.detector = object_detection

    def _on_detect(self, image, selected_model):
        """Click handler: skip detection when no image has been uploaded."""
        if image is None:
            return None
        return self.detector.detect(image, selected_model)

    def launch(self):
        """Build the UI (upload, model dropdown, button, output) and start it."""
        model_names = list(OBJECT_DETECTION_MODELS)
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    self.model_dropdown = gr.Dropdown(
                        choices=model_names,
                        # Preselect a model so a click never sends None.
                        value=model_names[0] if model_names else None,
                        label="Select Model",
                    )
                    detection_button = gr.Button("Detect")
                with gr.Column():
                    output = gr.Image(type='pil', label="Detection")
            detection_button.click(
                fn=self._on_detect,
                inputs=[upload_image, self.model_dropdown],
                outputs=output,
            )
        demo.launch()


model_loader = ModelLoader(OBJECT_DETECTION_MODELS)
preprocessor = Preprocessor()
postprocessor = Postprocessor()
object_detection = ObjectDetection(model_loader, preprocessor, postprocessor)
app = GradioApp(object_detection)
# The app is started as soon as this file runs, as in the original script.
app.launch()