muhammadsalmanalfaridzi commited on
Commit
ed878d7
·
verified ·
1 Parent(s): a300468

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -70
app.py CHANGED
@@ -1,103 +1,125 @@
1
  import gradio as gr
2
- import os
3
  import tempfile
4
- import math
5
- import cv2
6
  import numpy as np
7
- import supervision as sv
8
- from roboflow import Roboflow
9
 
10
- # Initialize Roboflow
11
  rf = Roboflow(api_key="Otg64Ra6wNOgDyjuhMYU")
12
  project = rf.workspace("alat-pelindung-diri").project("nescafe-4base")
13
  model = project.version(16).model
14
 
15
- # Helper function for SAHI (Supervision Slicing)
16
- def calculate_tile_size(image_shape: tuple[int, int], tiles: tuple[int, int], overlap_ratio_wh: tuple[float, float] = (0.0, 0.0)):
17
- w, h = image_shape
18
- rows, columns = tiles
19
- tile_width = math.ceil(w / columns * (1 + overlap_ratio_wh[0]))
20
- tile_height = math.ceil(h / rows * (1 + overlap_ratio_wh[1]))
21
- overlap_wh = (math.ceil(tile_width * overlap_ratio_wh[0]), math.ceil(tile_height * overlap_ratio_wh[1]))
22
- return (tile_width, tile_height), overlap_wh
23
-
24
- # Function to handle inference and tiles
25
- def detect_objects(image):
26
- # Convert PIL image to NumPy array (for OpenCV compatibility)
27
- img = np.array(image) # Gradio image is in PIL format, convert it to NumPy array
28
- img_rgb = img # Keep the image as RGB format, avoid unnecessary conversion to BGR
29
 
30
- image_shape = (img.shape[1], img.shape[0])
 
 
 
 
31
 
32
- # Parameters for slicing (tiles and overlap)
33
- tiles = (8, 8) # Use 8x8 tiles for better detection of small objects
34
- overlap_ratio_wh = (0.2, 0.2) # 20% overlap between tiles for better context
35
- slice_wh, overlap_wh = calculate_tile_size(image_shape, tiles, overlap_ratio_wh)
36
 
37
- # Generate offsets but don't visualize the tiles with rectangles (remove the drawing step)
38
- offsets = sv.InferenceSlicer._generate_offset(image_shape, slice_wh, None, overlap_wh)
39
- tiled_image = img_rgb.copy()
40
 
41
- # Save the PIL image to a temporary file for Roboflow model prediction
42
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
43
- image.save(temp_file, format="JPEG")
44
- temp_file_path = temp_file.name
45
 
46
- # Annotate with Roboflow model predictions using the temporary file path
47
- predictions = model.predict(temp_file_path, confidence=40, overlap=30).json() # Adjusted confidence for small object detection
48
- class_count = {}
49
 
50
- # Define a color palette for different classes
51
- color_palette = {
52
- "bearbrand": (0, 255, 0), # Green for class 1
53
- "nescafe latte": (0, 0, 255), # Red for class 2
54
- "nescafe original": (255, 0, 0), # Blue for class 3
55
- "nescafe mocha": (0, 255, 255) # Yellow for class 4
56
- #"class_5": (255, 0, 255) # Magenta for class 5
57
- # You can add more colors based on the number of classes you have
58
- }
59
 
60
- # Draw bounding boxes with different colors and label classes
61
- for prediction in predictions['predictions']:
62
- x1 = int(prediction['x'] - prediction['width'] / 2)
63
- y1 = int(prediction['y'] - prediction['height'] / 2)
64
- x2 = int(prediction['x'] + prediction['width'] / 2)
65
- y2 = int(prediction['y'] + prediction['height'] / 2)
66
 
67
- class_name = prediction['class']
 
 
 
 
 
68
 
69
- # Choose a color for the class, if the class is not in the palette, use white
70
- box_color = color_palette.get(class_name, (255, 255, 255))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # Draw a bounding box around the detected object
73
- cv2.rectangle(tiled_image, (x1, y1), (x2, y2), box_color, 2) # Bounding box with thickness=2
74
 
75
- # Put the class name label on the bounding box
76
- cv2.putText(tiled_image, class_name, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, box_color, 2) # Label
 
77
 
78
- # Count the class occurrences
 
 
 
79
  if class_name in class_count:
80
  class_count[class_name] += 1
81
  else:
82
  class_count[class_name] = 1
83
 
84
- # Create a result text to show class counts
85
- result_text = "Object counts per class:\n"
86
  for class_name, count in class_count.items():
87
- result_text += f"{class_name}: {count} objects\n"
88
 
89
- # Remove the temporary file after processing
90
  os.remove(temp_file_path)
91
 
92
- return result_text # Only return result_text for object counting
93
 
94
- # Gradio Interface
95
  iface = gr.Interface(
96
- fn=detect_objects,
97
- inputs=gr.Image(type="pil"),
98
- outputs=gr.Textbox(), # Only output the text with object counts
99
- live=True
100
  )
101
 
102
- # Launch Gradio app
103
- iface.launch(debug=True)
 
1
  import gradio as gr
2
+ from roboflow import Roboflow
3
  import tempfile
4
+ import os
5
+ from sahi.slicing import slice_image
6
  import numpy as np
7
+ import cv2
 
8
 
9
# Initialize the Roboflow client and load the trained detection model
# (workspace "alat-pelindung-diri", project "nescafe-4base", version 16).
# NOTE(review): the API key was hard-coded; prefer supplying it via the
# ROBOFLOW_API_KEY environment variable. The original literal is kept as a
# fallback so existing deployments keep working.
rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY", "Otg64Ra6wNOgDyjuhMYU"))
project = rf.workspace("alat-pelindung-diri").project("nescafe-4base")
model = project.version(16).model
13
 
14
# Non-Maximum Suppression: drop overlapping duplicate detections (typical
# when overlapping image slices each detect the same object), keeping the
# highest-confidence box in each overlapping cluster.
def apply_nms(predictions, iou_threshold=0.5):
    """Apply class-agnostic NMS to a list of detections.

    Parameters:
        predictions: list of dicts with keys 'class', 'bbox' ([x, y, w, h],
            top-left corner plus size) and 'confidence' (float score).
        iou_threshold: a box whose IoU with an already-kept, higher-scoring
            box exceeds this value is suppressed.

    Returns:
        Filtered list of the same dicts, ordered by descending confidence.
        Detections scoring below 0.25 are dropped (mirrors the original
        cv2.dnn.NMSBoxes score_threshold=0.25 pre-filter).
    """
    # Guard: the original crashed here when no slice produced a detection.
    if not predictions:
        return []

    boxes = np.array([p['bbox'] for p in predictions], dtype=float)
    scores = np.array([p['confidence'] for p in predictions], dtype=float)

    # Pre-filter weak detections, then order the survivors best-first so the
    # greedy pass always keeps the strongest box of each cluster.
    candidates = np.flatnonzero(scores >= 0.25)
    candidates = candidates[np.argsort(-scores[candidates])]

    # Corner coordinates and areas, derived once for all boxes.
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 0] + boxes[:, 2]
    y2 = boxes[:, 1] + boxes[:, 3]
    areas = boxes[:, 2] * boxes[:, 3]

    kept = []
    for i in candidates:
        suppressed = False
        for j in kept:
            # Intersection-over-union between candidate i and kept box j.
            inter_w = max(0.0, min(x2[i], x2[j]) - max(x1[i], x1[j]))
            inter_h = max(0.0, min(y2[i], y2[j]) - max(y1[i], y1[j]))
            inter = inter_w * inter_h
            union = areas[i] + areas[j] - inter
            if union > 0 and inter / union > iou_threshold:
                suppressed = True
                break
        if not suppressed:
            kept.append(i)

    # Return the original dicts (not numpy-wrapped copies) for the survivors.
    return [predictions[i] for i in kept]
46
 
47
# Detect objects in an uploaded image: slice it into tiles (SAHI), run the
# Roboflow model on each tile, merge overlapping detections with NMS, and
# return an annotated image plus per-class counts.
def detect_objects(image):
    """Run sliced inference on *image*.

    Parameters:
        image: PIL.Image supplied by the Gradio input component.

    Returns:
        (output_image_path, result_text): path to the annotated JPEG and a
        human-readable per-class object count.
    """
    # Persist the PIL image so SAHI and the Roboflow SDK can read it from disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    try:
        # Slice the image into small overlapping tiles for better recall on
        # small objects.
        slice_image_result = slice_image(
            image=temp_file_path,
            output_file_name="sliced_image",
            output_dir="/tmp/sliced/",
            slice_height=256,
            slice_width=256,
            overlap_height_ratio=0.1,
            overlap_width_ratio=0.1,
        )

        try:
            sliced_image_paths = slice_image_result.sliced_image_paths
        except AttributeError:
            print("Failed to access sliced_image_paths attribute.")
            sliced_image_paths = []

        # Collect predictions from every tile, converted into the
        # {'class', 'bbox': [x, y, w, h], 'confidence'} shape that apply_nms
        # reads. (The original passed raw Roboflow dicts, which have
        # x/y/width/height keys but no 'bbox' -> KeyError in apply_nms.)
        # NOTE(review): tile-local coordinates are NOT shifted back to
        # full-image coordinates here, so boxes from different tiles can
        # land in the wrong place on the stitched image. TODO: offset each
        # bbox by its tile's origin.
        all_predictions = []
        for sliced_image_path in sliced_image_paths:
            if not isinstance(sliced_image_path, str):
                print(f"Skipping invalid image path: {sliced_image_path}")
                continue
            predictions = model.predict(image_path=sliced_image_path).json()
            for pred in predictions['predictions']:
                # Roboflow reports center x/y plus width/height; convert to a
                # top-left-anchored [x, y, w, h] box.
                all_predictions.append({
                    'class': pred['class'],
                    'bbox': [
                        pred['x'] - pred['width'] / 2,
                        pred['y'] - pred['height'] / 2,
                        pred['width'],
                        pred['height'],
                    ],
                    'confidence': pred['confidence'],
                })

        # Merge duplicate detections produced by overlapping tiles.
        postprocessed_predictions = apply_nms(all_predictions, iou_threshold=0.5)

        # NOTE(review): annotate_image_with_predictions is assumed to exist on
        # the Roboflow model object -- confirm against the installed SDK.
        annotated_image = model.annotate_image_with_predictions(
            temp_file_path, postprocessed_predictions)

        # Save the annotated result where Gradio can serve it.
        output_image_path = "/tmp/prediction.jpg"
        annotated_image.save(output_image_path)

        # Count detected objects per class.
        class_count = {}
        for detection in postprocessed_predictions:
            class_name = detection['class']
            class_count[class_name] = class_count.get(class_name, 0) + 1

        result_text = "Jumlah objek per kelas:\n"
        for class_name, count in class_count.items():
            result_text += f"{class_name}: {count} objek\n"
    finally:
        # Always remove the temporary upload, even if slicing/inference fails
        # (the original leaked the file on any exception).
        os.remove(temp_file_path)

    return output_image_path, result_text
115
 
116
# Gradio UI: a single image input mapped to two outputs -- the annotated
# image and a text box with the per-class object counts.
iface = gr.Interface(
    fn=detect_objects,            # invoked for every uploaded image
    inputs=gr.Image(type="pil"),  # hand the upload to us as a PIL image
    outputs=[gr.Image(), gr.Textbox()],
    live=True,                    # re-run automatically when the input changes
)

# Start the web application.
iface.launch()