# app.py -- commit 161f7c0 ("Create app.py", verified), user etemkocaaslan, 3.38 kB.
# (Hosting-page navigation residue preserved for reference: "raw", "history", "blame".)
import torch
from torch import Tensor
import torchvision.models.detection as models
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Compose
from torch.nn import Module
from io import BytesIO
import requests
from PIL import Image as Im, ImageDraw
import gradio as gr
# Registry of selectable detectors: model name -> torchvision callable that
# builds the model. The names double as the choices shown in the UI dropdown.
OBJECT_DETECTION_MODELS = dict(
    fasterrcnn_resnet50_fpn=models.fasterrcnn_resnet50_fpn,
    maskrcnn_resnet50_fpn=models.maskrcnn_resnet50_fpn,
)
class ModelLoader:
    """Build detection models by name from a ``name -> factory`` mapping.

    Each model is built at most once per loader instance; later requests for
    the same name reuse the cached object instead of rebuilding it.
    """

    def __init__(self, model_dict: dict):
        """Store the registry of model factories.

        Args:
            model_dict: Mapping from model name to a callable that accepts
                ``pretrained=True`` and returns the model. Names are
                lower-cased before lookup, so keys are expected to be
                lower-case.
        """
        self.model_dict = model_dict
        # Already-built models keyed by lower-cased name. Entries are not
        # invalidated if ``model_dict`` is modified afterwards.
        self._loaded_models = {}

    def load_model(self, model_name: str) -> Module:
        """Return the model registered under ``model_name``, in eval mode.

        Args:
            model_name: Name of the model; matched after lower-casing.

        Returns:
            The model instance. Repeated calls with the same name return the
            same object.

        Raises:
            ValueError: If ``model_name`` is not a string (e.g. ``None`` when
                nothing was selected) or is not present in the registry.
        """
        # Guard first: calling ``.lower()`` on a non-string would raise an
        # unrelated AttributeError instead of the documented ValueError.
        if not isinstance(model_name, str):
            raise ValueError(f"Model {model_name} is not supported")

        key = model_name.lower()
        model = self._loaded_models.get(key)
        if model is None:
            try:
                factory = self.model_dict[key]
            except KeyError:
                raise ValueError(f"Model {model_name} is not supported") from None
            # NOTE(review): newer torchvision releases deprecate ``pretrained=``
            # in favour of ``weights=``; kept unchanged so existing installs
            # keep working -- confirm against the targeted torchvision version.
            model = factory(pretrained=True)
            self._loaded_models[key] = model

        # Re-assert eval mode on every return in case a caller switched a
        # cached model back to training mode.
        model.eval()
        return model
class Preprocessor:
    """Convert an input image into a batched tensor for a detection model."""

    def __init__(self, transform: "Compose | None" = None):
        """Store the transform applied to each image.

        Args:
            transform: Callable applied to the image; its result must support
                ``unsqueeze``. When omitted (or ``None``), a new
                ``Compose([ToTensor()])`` is created for this instance.
        """
        # Build the default here rather than in the signature, so the pipeline
        # is not created at class-definition time and shared by all instances.
        self.transform = Compose([ToTensor()]) if transform is None else transform

    def preprocess(self, image: "Im.Image") -> Tensor:
        """Apply the transform to ``image`` and prepend a batch axis of size 1.

        Args:
            image: Input accepted by the configured transform.

        Returns:
            The transformed tensor with an extra leading dimension.
        """
        return self.transform(image).unsqueeze(0)
class Postprocessor:
    """Annotate an image with the detections that pass a score threshold."""

    def __init__(self, threshold: float = 0.5):
        # Only detections scoring strictly above this value are drawn.
        self.threshold = threshold

    def postprocess(self, image: Im, predictions: dict) -> Im:
        """Draw a red box and score label for each confident detection.

        The drawing is done directly on ``image`` (it is modified in place)
        and the same object is returned.

        Args:
            image: Image to draw on.
            predictions: Mapping with ``'boxes'`` and ``'scores'`` entries that
                are iterated in lockstep; each box must support ``tolist()``
                and indexing.

        Returns:
            The input ``image`` with the annotations drawn on it.
        """
        canvas = ImageDraw.Draw(image)
        detections = zip(predictions['boxes'], predictions['scores'])
        for bbox, confidence in detections:
            if not confidence > self.threshold:
                continue
            canvas.rectangle(bbox.tolist(), outline="red", width=3)
            # The label is anchored at the first two box coordinates.
            label_origin = (bbox[0], bbox[1])
            canvas.text(label_origin, f"{confidence:.2f}", fill="red")
        return image
class ObjectDetection:
    """Run the load -> preprocess -> infer -> annotate pipeline on one image."""

    def __init__(self, model_loader: "ModelLoader", preprocessor: "Preprocessor", postprocessor: "Postprocessor"):
        """Store the collaborators used by :meth:`detect`.

        Args:
            model_loader: Provides ``load_model(name)``.
            preprocessor: Provides ``preprocess(image)`` returning the model input.
            postprocessor: Provides ``postprocess(image, prediction)``.
        """
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def detect(self, image: "Im.Image", selected_model: str) -> "Im.Image":
        """Detect objects in ``image`` with the named model.

        Args:
            image: Image to analyse.
            selected_model: Name passed to the model loader.

        Returns:
            Whatever the postprocessor returns for ``image`` and the first
            element of the model output.

        Raises:
            ValueError: If ``image`` is ``None`` (nothing was supplied).
        """
        # Fail fast with a clear message before the (potentially expensive)
        # model load; otherwise the error only surfaces inside the transform.
        if image is None:
            raise ValueError("No image provided for detection")

        model = self.model_loader.load_model(selected_model)
        input_tensor = self.preprocessor.preprocess(image)

        # Use the GPU when one is available; model and input must share a device.
        if torch.cuda.is_available():
            input_tensor = input_tensor.to("cuda")
            model = model.to("cuda")

        # Kept even if the loader already did it: any injected loader is then
        # guaranteed to yield an inference-mode model here.
        model.eval()
        with torch.no_grad():
            output = model(input_tensor)

        # Only one image is submitted per call, so only the first entry of the
        # model output is used.
        return self.postprocessor.postprocess(image, output[0])
class GradioApp:
    """Gradio front end that wires an image upload and model picker to a detector."""

    def __init__(self, object_detection: ObjectDetection):
        """Keep the detector whose ``detect`` method handles button clicks."""
        self.detector = object_detection

    def launch(self):
        """Build the two-column UI and start the Gradio server."""
        model_names = list(OBJECT_DETECTION_MODELS.keys())
        # Preselect the first model so that pressing "Detect" without touching
        # the dropdown does not submit an empty selection to the detector.
        default_model = model_names[0] if model_names else None

        with gr.Blocks() as demo:
            with gr.Row():
                # Left column: inputs and the trigger button.
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    self.model_dropdown = gr.Dropdown(
                        choices=model_names,
                        value=default_model,
                        label="Select Model",
                    )
                    detection_button = gr.Button("Detect")
                # Right column: annotated result.
                with gr.Column():
                    output = gr.Image(type='pil', label="Detection")
            detection_button.click(
                fn=self.detector.detect,
                inputs=[upload_image, self.model_dropdown],
                outputs=output,
            )
        demo.launch()
# Wire the pipeline together: registry -> loader, default pre/post-processing,
# detector, then the UI around the detector.
model_loader = ModelLoader(model_dict=OBJECT_DETECTION_MODELS)
preprocessor = Preprocessor()
postprocessor = Postprocessor()
object_detection = ObjectDetection(
    model_loader=model_loader,
    preprocessor=preprocessor,
    postprocessor=postprocessor,
)
app = GradioApp(object_detection=object_detection)

# Starts the Gradio server as soon as this module is executed.
app.launch()