etemkocaaslan commited on
Commit
161f7c0
·
verified ·
1 Parent(s): 03f5b96

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ import torchvision.models.detection as models
4
+ import torchvision.transforms as transforms
5
+ from torchvision.transforms import ToTensor, Compose
6
+ from torch.nn import Module
7
+ from io import BytesIO
8
+ import requests
9
+ from PIL import Image as Im, ImageDraw
10
+ import gradio as gr
11
+
12
+ OBJECT_DETECTION_MODELS = {
13
+ "fasterrcnn_resnet50_fpn": models.fasterrcnn_resnet50_fpn,
14
+ "maskrcnn_resnet50_fpn": models.maskrcnn_resnet50_fpn,
15
+ }
16
+
17
+ class ModelLoader:
18
+ def __init__(self, model_dict: dict):
19
+ self.model_dict = model_dict
20
+
21
+ def load_model(self, model_name: str) -> Module:
22
+ model_name_lower = model_name.lower()
23
+ if model_name_lower in self.model_dict:
24
+ model_class = self.model_dict[model_name_lower]
25
+ model = model_class(pretrained=True)
26
+ model.eval()
27
+ return model
28
+ else:
29
+ raise ValueError(f"Model {model_name} is not supported")
30
+
31
+ class Preprocessor:
32
+ def __init__(self, transform: Compose = Compose([ToTensor()])):
33
+ self.transform = transform
34
+
35
+ def preprocess(self, image: Im) -> Tensor:
36
+ return self.transform(image).unsqueeze(0)
37
+
38
+ class Postprocessor:
39
+ def __init__(self, threshold: float = 0.5):
40
+ self.threshold = threshold
41
+
42
+ def postprocess(self, image: Im, predictions: dict) -> Im:
43
+ draw = ImageDraw.Draw(image)
44
+ for box, score in zip(predictions['boxes'], predictions['scores']):
45
+ if score > self.threshold:
46
+ draw.rectangle(box.tolist(), outline="red", width=3)
47
+ draw.text((box[0], box[1]), f"{score:.2f}", fill="red")
48
+ return image
49
+
50
+ class ObjectDetection:
51
+ def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor, postprocessor: Postprocessor):
52
+ self.model_loader = model_loader
53
+ self.preprocessor = preprocessor
54
+ self.postprocessor = postprocessor
55
+
56
+ def detect(self, image: Im, selected_model: str) -> Im:
57
+ model = self.model_loader.load_model(selected_model)
58
+ input_tensor = self.preprocessor.preprocess(image)
59
+
60
+ if torch.cuda.is_available():
61
+ input_tensor = input_tensor.to("cuda")
62
+ model = model.to("cuda")
63
+
64
+ model.eval()
65
+ with torch.no_grad():
66
+ output = model(input_tensor)
67
+ return self.postprocessor.postprocess(image, output[0])
68
+
69
+ class GradioApp:
70
+ def __init__(self, object_detection: ObjectDetection):
71
+ self.detector = object_detection
72
+
73
+ def launch(self):
74
+ with gr.Blocks() as demo:
75
+ with gr.Row():
76
+ with gr.Column():
77
+ upload_image = gr.Image(type='pil', label="Upload Image")
78
+ self.model_dropdown = gr.Dropdown(choices=list(OBJECT_DETECTION_MODELS.keys()), label="Select Model")
79
+ detection_button = gr.Button("Detect")
80
+ with gr.Column():
81
+ output = gr.Image(type='pil', label="Detection")
82
+ detection_button.click(fn=self.detector.detect, inputs=[upload_image, self.model_dropdown], outputs=output)
83
+ demo.launch()
84
+
85
+ model_loader = ModelLoader(OBJECT_DETECTION_MODELS)
86
+ preprocessor = Preprocessor()
87
+ postprocessor = Postprocessor()
88
+ object_detection = ObjectDetection(model_loader, preprocessor, postprocessor)
89
+ app = GradioApp(object_detection)
90
+ app.launch()