Snearec commited on
Commit
0d85e45
·
1 Parent(s): b326c15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -151
app.py CHANGED
@@ -1,156 +1,84 @@
1
- import io
2
  import gradio as gr
3
- import matplotlib.pyplot as plt
4
- import requests, validators
5
  import torch
6
- import pathlib
7
- from PIL import Image
8
- from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
9
-
10
- import os
11
-
12
- # colors for visualization
13
- COLORS = [
14
- [0.000, 0.447, 0.741],
15
- [0.850, 0.325, 0.098],
16
- [0.929, 0.694, 0.125],
17
- [0.494, 0.184, 0.556],
18
- [0.466, 0.674, 0.188],
19
- [0.301, 0.745, 0.933]
20
- ]
21
-
22
- def make_prediction(img, feature_extractor, model):
23
- inputs = feature_extractor(img, return_tensors="pt")
24
- outputs = model(**inputs)
25
- img_size = torch.tensor([tuple(reversed(img.size))])
26
- processed_outputs = feature_extractor.post_process(outputs, img_size)
27
- return processed_outputs[0]
28
-
29
- def fig2img(fig):
30
- buf = io.BytesIO()
31
- fig.savefig(buf)
32
- buf.seek(0)
33
- img = Image.open(buf)
34
- return img
35
-
36
-
37
- def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
38
- keep = output_dict["scores"] > threshold
39
- boxes = output_dict["boxes"][keep].tolist()
40
- scores = output_dict["scores"][keep].tolist()
41
- labels = output_dict["labels"][keep].tolist()
42
- if id2label is not None:
43
- labels = [id2label[x] for x in labels]
44
-
45
- plt.figure(figsize=(16, 10))
46
- plt.imshow(pil_img)
47
- ax = plt.gca()
48
- colors = COLORS * 100
49
- for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
50
- ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
51
- ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
52
- plt.axis("off")
53
- return fig2img(plt.gcf())
54
-
55
- def detect_objects(model_name,url_input,image_input,threshold):
56
-
57
- #Extract model and feature extractor
58
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
59
-
60
- if 'detr' in model_name:
61
-
62
- model = DetrForObjectDetection.from_pretrained(model_name)
63
-
64
- elif 'yolos' in model_name:
65
-
66
- model = YolosForObjectDetection.from_pretrained(model_name)
67
-
68
- if validators.url(url_input):
69
- image = Image.open(requests.get(url_input, stream=True).raw)
70
-
71
- elif image_input:
72
- image = image_input
73
-
74
- #Make prediction
75
- processed_outputs = make_prediction(image, feature_extractor, model)
76
-
77
- #Visualize prediction
78
- viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
79
-
80
- return viz_img
81
 
82
- def set_example_image(example: list) -> dict:
83
- return gr.Image.update(value=example[0])
84
 
85
- def set_example_url(example: list) -> dict:
86
- return gr.Textbox.update(value=example[0])
87
-
88
-
89
- title = """<h1 id="title">Object Detection App with DETR and YOLOS</h1>"""
90
-
91
- description = """
92
- Links to HuggingFace Models:
93
-
94
- - [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)
95
- - [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101)
96
- - [hustvl/yolos-small](https://huggingface.co/hustvl/yolos-small)
97
- - [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
98
- - [Snearec/detectorMalezasYolo8](https://huggingface.co/Snearec/detectorMalezasYolo8)
99
-
100
- """
101
-
102
- models = ["facebook/detr-resnet-50","facebook/detr-resnet-101",'hustvl/yolos-small','hustvl/yolos-tiny','[Snearec/detectorMalezasYolo8]']
103
- urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
104
-
105
- twitter_link = """
106
- [![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
107
- """
108
-
109
- css = '''
110
- h1#title {
111
- text-align: center;
112
- }
113
- '''
114
- demo = gr.Blocks(css=css)
115
-
116
- with demo:
117
- gr.Markdown(title)
118
- gr.Markdown(description)
119
- gr.Markdown(twitter_link)
120
- options = gr.Dropdown(choices=models,label='Select Object Detection Model',show_label=True)
121
- slider_input = gr.Slider(minimum=0.2,maximum=1,value=0.7,label='Prediction Threshold')
122
-
123
- with gr.Tabs():
124
- with gr.TabItem('Image URL'):
125
- with gr.Row():
126
- url_input = gr.Textbox(lines=2,label='Enter valid image URL here..')
127
- img_output_from_url = gr.Image(shape=(650,650))
128
-
129
- with gr.Row():
130
- example_url = gr.Dataset(components=[url_input],samples=[[str(url)] for url in urls])
131
-
132
- url_but = gr.Button('Detect')
133
-
134
- with gr.TabItem('Image Upload'):
135
- with gr.Row():
136
- img_input = gr.Image(type='pil')
137
- img_output_from_upload= gr.Image(shape=(650,650))
138
-
139
- with gr.Row():
140
- example_images = gr.Dataset(components=[img_input],
141
- samples=[[path.as_posix()]
142
- for path in sorted(pathlib.Path('images').rglob('*.JPG'))])
143
-
144
- img_but = gr.Button('Detect')
145
-
146
-
147
- url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_url,queue=True)
148
- img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_upload,queue=True)
149
- example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input])
150
- example_url.click(fn=set_example_url,inputs=[example_url],outputs=[url_input])
151
-
152
-
153
- gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-object-detection-with-detr-and-yolos)")
154
 
155
-
156
- demo.launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
2
  import torch
3
+ from sahi.prediction import ObjectPrediction
4
+ from sahi.utils.cv import visualize_object_predictions, read_image
5
+ from ultralyticsplus import YOLO
6
+
7
+ # Images
8
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/kadirnar/dethub/main/data/images/highway.jpg', 'highway.jpg')
9
+ torch.hub.download_url_to_file('https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg', 'highway1.jpg')
10
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/obss/sahi/main/tests/data/small-vehicles1.jpeg', 'small-vehicles1.jpeg')
11
+
12
+ def yolov8_inference(
13
+ image: gr.inputs.Image = None,
14
+ model_path: gr.inputs.Dropdown = None,
15
+ image_size: gr.inputs.Slider = 640,
16
+ conf_threshold: gr.inputs.Slider = 0.25,
17
+ iou_threshold: gr.inputs.Slider = 0.45,
18
+ ):
19
+ """
20
+ YOLOv8 inference function
21
+ Args:
22
+ image: Input image
23
+ model_path: Path to the model
24
+ image_size: Image size
25
+ conf_threshold: Confidence threshold
26
+ iou_threshold: IOU threshold
27
+ Returns:
28
+ Rendered image
29
+ """
30
+ model = YOLO(model_path)
31
+ model.conf = conf_threshold
32
+ model.iou = iou_threshold
33
+ results = model.predict(image, imgsz=image_size, return_outputs=True)
34
+ object_prediction_list = []
35
+ for _, image_results in enumerate(results):
36
+ if len(image_results)!=0:
37
+ image_predictions_in_xyxy_format = image_results['det']
38
+ for pred in image_predictions_in_xyxy_format:
39
+ x1, y1, x2, y2 = (
40
+ int(pred[0]),
41
+ int(pred[1]),
42
+ int(pred[2]),
43
+ int(pred[3]),
44
+ )
45
+ bbox = [x1, y1, x2, y2]
46
+ score = pred[4]
47
+ category_name = model.model.names[int(pred[5])]
48
+ category_id = pred[5]
49
+ object_prediction = ObjectPrediction(
50
+ bbox=bbox,
51
+ category_id=int(category_id),
52
+ score=score,
53
+ category_name=category_name,
54
+ )
55
+ object_prediction_list.append(object_prediction)
56
+
57
+ image = read_image(image)
58
+ output_image = visualize_object_predictions(image=image, object_prediction_list=object_prediction_list)
59
+ return output_image['image']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
61
 
62
+ inputs = [
63
+ gr.inputs.Image(type="filepath", label="Input Image"),
64
+ gr.inputs.Dropdown(["kadirnar/yolov8n-v8.0", "kadirnar/yolov8m-v8.0", "kadirnar/yolov8l-v8.0", "kadirnar/yolov8x-v8.0", "kadirnar/yolov8x6-v8.0"],
65
+ default="kadirnar/yolov8m-v8.0", label="Model"),
66
+ gr.inputs.Slider(minimum=320, maximum=1280, default=640, step=32, label="Image Size"),
67
+ gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.25, step=0.05, label="Confidence Threshold"),
68
+ gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.45, step=0.05, label="IOU Threshold"),
69
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ outputs = gr.outputs.Image(type="filepath", label="Output Image")
72
+ title = "Ultralytics YOLOv8: State-of-the-Art YOLO Models"
73
+
74
+ examples = [['highway.jpg', 'kadirnar/yolov8m-v8.0', 640, 0.25, 0.45], ['highway1.jpg', 'kadirnar/yolov8l-v8.0', 640, 0.25, 0.45], ['small-vehicles1.jpeg', 'kadirnar/yolov8x-v8.0', 1280, 0.25, 0.45]]
75
+ demo_app = gr.Interface(
76
+ fn=yolov8_inference,
77
+ inputs=inputs,
78
+ outputs=outputs,
79
+ title=title,
80
+ examples=examples,
81
+ cache_examples=True,
82
+ theme='huggingface',
83
+ )
84
+ demo_app.launch(debug=True, enable_queue=True)