|
import torch |
|
from torch import Tensor |
|
import torchvision.models.detection as models |
|
import torchvision.transforms as transforms |
|
from torchvision.transforms import ToTensor, Compose |
|
from torch.nn import Module |
|
from io import BytesIO |
|
import requests |
|
from PIL import Image as Im, ImageDraw |
|
import gradio as gr |
|
|
|
# Registry of selectable detectors: lower-case model name -> torchvision
# constructor. The keys double as the choices offered in the UI dropdown.
OBJECT_DETECTION_MODELS = dict(
    fasterrcnn_resnet50_fpn=models.fasterrcnn_resnet50_fpn,
    maskrcnn_resnet50_fpn=models.maskrcnn_resnet50_fpn,
)
|
|
|
class ModelLoader:
    """Instantiate detection models by name from a registry of constructors.

    Each model is built at most once per loader: the instance is cached under
    its lower-cased name so repeated requests reuse it instead of rebuilding
    the network (and re-reading its weights) on every call.
    """

    def __init__(self, model_dict: dict):
        """Store the registry and start with an empty model cache.

        Args:
            model_dict: Maps lower-case model names to callables that accept
                ``pretrained=True`` and return a model exposing ``eval()``.
        """
        self.model_dict = model_dict
        # name (lower-case) -> already constructed model, in eval mode.
        self._loaded_models = {}

    def load_model(self, model_name: str) -> Module:
        """Return the model registered under ``model_name`` in eval mode.

        The lookup is case-insensitive. The first request for a name builds
        the model; later requests return the same cached instance.

        Args:
            model_name: Name of a model present in the registry.

        Returns:
            The constructed model, switched to evaluation mode.

        Raises:
            ValueError: If ``model_name`` is not a string (e.g. ``None`` when
                nothing was selected) or is not a key of the registry.
        """
        if not isinstance(model_name, str):
            # Without this guard a missing selection surfaces as an opaque
            # AttributeError from ``.lower()``.
            raise ValueError(
                f"Model name must be a string, got {type(model_name).__name__}"
            )

        key = model_name.lower()
        if key in self._loaded_models:
            return self._loaded_models[key]

        try:
            factory = self.model_dict[key]
        except KeyError:
            raise ValueError(f"Model {model_name} is not supported") from None

        # NOTE(review): recent torchvision releases deprecate ``pretrained`` in
        # favour of ``weights``; the flag is kept so existing registries work.
        model = factory(pretrained=True)
        model.eval()
        self._loaded_models[key] = model
        return model
|
|
|
class Preprocessor:
    """Turn an input image into a batched tensor for a detection model."""

    def __init__(self, transform: "Compose | None" = None):
        """Store the transform applied to every image.

        Args:
            transform: Callable mapping an image to a tensor. When omitted, a
                fresh ``Compose([ToTensor()])`` is created for this instance
                rather than sharing one object built at definition time.
        """
        if transform is None:
            transform = Compose([ToTensor()])
        self.transform = transform

    def preprocess(self, image: "Im.Image") -> Tensor:
        """Apply the transform and prepend a batch dimension.

        Args:
            image: The image to convert (``Im`` is the ``PIL.Image`` module,
                so the image type is ``Im.Image``).

        Returns:
            The transformed image with an extra leading dimension of size 1.
        """
        tensor = self.transform(image)
        return tensor.unsqueeze(0)
|
|
|
class Postprocessor:
    """Overlay sufficiently confident detections on an image."""

    def __init__(self, threshold: float = 0.5):
        """Remember the score a detection must exceed to be drawn."""
        self.threshold = threshold

    def postprocess(self, image: Im, predictions: dict) -> Im:
        """Draw a red box and score label for each detection above threshold.

        The drawing happens directly on ``image``; the same object is returned.

        Args:
            image: Image to annotate in place.
            predictions: Mapping with ``'boxes'`` and ``'scores'`` entries that
                are iterated in lockstep.

        Returns:
            ``image`` with the accepted detections drawn on it.
        """
        canvas = ImageDraw.Draw(image)
        detections = zip(predictions['boxes'], predictions['scores'])
        for bbox, confidence in detections:
            # Skip anything that does not strictly exceed the threshold.
            if not confidence > self.threshold:
                continue
            canvas.rectangle(bbox.tolist(), outline="red", width=3)
            label_anchor = (bbox[0], bbox[1])
            canvas.text(label_anchor, f"{confidence:.2f}", fill="red")
        return image
|
|
|
class ObjectDetection:
    """Run the load -> preprocess -> infer -> postprocess detection pipeline."""

    def __init__(
        self,
        model_loader: "ModelLoader",
        preprocessor: "Preprocessor",
        postprocessor: "Postprocessor",
    ):
        """Store the collaborators used by :meth:`detect`.

        Args:
            model_loader: Provides ``load_model(name)``.
            preprocessor: Provides ``preprocess(image)`` returning a tensor.
            postprocessor: Provides ``postprocess(image, predictions)``.
        """
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def detect(self, image: "Im.Image", selected_model: str) -> "Im.Image":
        """Detect objects in ``image`` with the model named ``selected_model``.

        Args:
            image: Image to analyse.
            selected_model: Name handed to the model loader.

        Returns:
            Whatever the postprocessor returns for ``image`` and the first
            element of the model output.

        Raises:
            ValueError: If ``image`` is ``None`` (for example when detection is
                requested before an image was supplied).
        """
        if image is None:
            # Fail early with a clear message instead of an unrelated error
            # from deep inside the transform.
            raise ValueError("No image provided for detection")

        model = self.model_loader.load_model(selected_model)
        input_tensor = self.preprocessor.preprocess(image)

        # Use the GPU when one is available; moving to "cpu" is a no-op for
        # objects that already live there.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        input_tensor = input_tensor.to(device)
        model = model.to(device)

        # Kept even though loaders may already do this: inference must never
        # run in training mode.
        model.eval()
        with torch.no_grad():
            output = model(input_tensor)
        return self.postprocessor.postprocess(image, output[0])
|
|
|
class GradioApp:
    """Gradio front end for an :class:`ObjectDetection` pipeline."""

    def __init__(self, object_detection: ObjectDetection):
        """Keep the detector whose ``detect`` method backs the Detect button."""
        self.detector = object_detection

    def launch(self):
        """Build the two-column UI and start the Gradio server.

        The left column holds the image upload, the model dropdown and the
        Detect button; the right column shows the annotated result.
        """
        model_names = list(OBJECT_DETECTION_MODELS.keys())
        # Preselect a model so the callback never receives ``None`` for the
        # model name just because the user did not touch the dropdown.
        default_model = model_names[0] if model_names else None

        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    self.model_dropdown = gr.Dropdown(
                        choices=model_names,
                        value=default_model,
                        label="Select Model",
                    )
                    detection_button = gr.Button("Detect")
                with gr.Column():
                    output = gr.Image(type='pil', label="Detection")
            # Event listeners have to be registered inside the Blocks context.
            detection_button.click(
                fn=self.detector.detect,
                inputs=[upload_image, self.model_dropdown],
                outputs=output,
            )
        demo.launch()
|
|
|
# Assemble the pipeline stages, wire them into the detector, then start the UI.
postprocessor = Postprocessor()
preprocessor = Preprocessor()
model_loader = ModelLoader(OBJECT_DETECTION_MODELS)

object_detection = ObjectDetection(
    model_loader=model_loader,
    preprocessor=preprocessor,
    postprocessor=postprocessor,
)

app = GradioApp(object_detection=object_detection)
app.launch()