muhammadsalmanalfaridzi committed on
Commit a300468 · verified · 1 Parent(s): 7a7b0dc

Update app.py

Files changed (1)
  1. app.py +70 -88
app.py CHANGED
@@ -1,121 +1,103 @@
import gradio as gr
- from roboflow import Roboflow
- import tempfile
import os
- from sahi.slicing import slice_image
- import numpy as np
import cv2

- # Initialize Roboflow (for model path)
rf = Roboflow(api_key="Otg64Ra6wNOgDyjuhMYU")
project = rf.workspace("alat-pelindung-diri").project("nescafe-4base")
model = project.version(16).model

- # Function to apply Non-Maximum Suppression (NMS)
- def apply_nms(predictions, iou_threshold=0.5):
-     boxes = []
-     scores = []
-     classes = []

-     # Extract boxes, scores, and class info
-     for prediction in predictions:
-         boxes.append(prediction['bbox'])
-         scores.append(prediction['confidence'])
-         classes.append(prediction['class'])
-
-     boxes = np.array(boxes)
-     scores = np.array(scores)
-     classes = np.array(classes)

-     # Perform NMS using OpenCV
-     indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), score_threshold=0.25, nms_threshold=iou_threshold)
-     nms_predictions = []

-     for i in indices.flatten():
-         nms_predictions.append({
-             'class': classes[i],
-             'bbox': boxes[i],
-             'confidence': scores[i]
-         })

-     return nms_predictions

- # Function for object detection using the Roboflow model
- def detect_objects(image):
-     # Save the image to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

-     # Slice the image into small tiles
-     slice_image_result = slice_image(
-         image=temp_file_path,
-         output_file_name="sliced_image",
-         output_dir="/tmp/sliced/",
-         slice_height=256,
-         slice_width=256,
-         overlap_height_ratio=0.1,
-         overlap_width_ratio=0.1
-     )
-
-     # Print to check the available attributes of the slice_image_result object
-     print(f"Slice result: {slice_image_result}")
-
-     # Try accessing the sliced image paths from the result object
-     try:
-         sliced_image_paths = slice_image_result.sliced_image_paths  # Assuming this is the correct attribute
-         print(f"Sliced image paths: {sliced_image_paths}")
-     except AttributeError:
-         print("Failed to access sliced_image_paths attribute.")
-         sliced_image_paths = []
-
-     # Collect all predictions from every image slice
-     all_predictions = []
-
-     # Run prediction on each image slice
-     for sliced_image_path in sliced_image_paths:
-         if isinstance(sliced_image_path, str):
-             predictions = model.predict(image_path=sliced_image_path).json()
-             all_predictions.extend(predictions['predictions'])
-         else:
-             print(f"Skipping invalid image path: {sliced_image_path}")
-
-     # Apply NMS to remove duplicate detections
-     postprocessed_predictions = apply_nms(all_predictions, iou_threshold=0.5)

-     # Annotate the image with the prediction results
-     annotated_image = model.annotate_image_with_predictions(temp_file_path, postprocessed_predictions)

-     # Save the annotated image
-     output_image_path = "/tmp/prediction.jpg"
-     annotated_image.save(output_image_path)

-     # Count the number of objects per class
-     class_count = {}
-     for detection in postprocessed_predictions:
-         class_name = detection['class']
        if class_name in class_count:
            class_count[class_name] += 1
        else:
            class_count[class_name] = 1

-     # Object count results
-     result_text = "Jumlah objek per kelas:\n"
    for class_name, count in class_count.items():
-         result_text += f"{class_name}: {count} objek\n"

-     # Remove the temporary file
    os.remove(temp_file_path)

-     return output_image_path, result_text

- # Build the Gradio interface
iface = gr.Interface(
-     fn=detect_objects,  # Function called when an image is uploaded
-     inputs=gr.Image(type="pil"),  # Image input
-     outputs=[gr.Image(), gr.Textbox()],  # Image and text outputs
-     live=True  # Show results live
)

- # Run the interface
- iface.launch()
 
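Note on the removed `apply_nms`: it reads a `'bbox'` key from each prediction, but the Roboflow hosted-API JSON returned by `model.predict(...).json()` describes each detection with center-based `x`, `y`, `width`, `height` keys, while `cv2.dnn.NMSBoxes` expects top-left `[x, y, w, h]` boxes plus a parallel list of scores. A minimal sketch of that conversion, assuming the standard Roboflow detection JSON; the helper name `nms_on_roboflow_predictions` is illustrative and not part of this commit:

import cv2
import numpy as np

def nms_on_roboflow_predictions(predictions, score_threshold=0.25, iou_threshold=0.5):
    # Roboflow reports center-based x/y plus width/height; cv2.dnn.NMSBoxes
    # expects top-left [x, y, w, h] boxes and a parallel list of scores.
    boxes, scores = [], []
    for p in predictions:
        x = p['x'] - p['width'] / 2
        y = p['y'] - p['height'] / 2
        boxes.append([int(x), int(y), int(p['width']), int(p['height'])])
        scores.append(float(p['confidence']))

    indices = cv2.dnn.NMSBoxes(boxes, scores, score_threshold, iou_threshold)
    if len(indices) == 0:
        return []
    # OpenCV may return the indices as an N x 1 array depending on the version.
    return [predictions[int(i)] for i in np.array(indices).reshape(-1)]

Called on `all_predictions`, this would return the subset of prediction dicts that survive NMS.
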
import gradio as gr
import os
+ import tempfile
+ import math
import cv2
+ import numpy as np
+ import supervision as sv
+ from roboflow import Roboflow

+ # Initialize Roboflow
rf = Roboflow(api_key="Otg64Ra6wNOgDyjuhMYU")
project = rf.workspace("alat-pelindung-diri").project("nescafe-4base")
model = project.version(16).model

+ # Helper function for SAHI (Supervision Slicing)
+ def calculate_tile_size(image_shape: tuple[int, int], tiles: tuple[int, int], overlap_ratio_wh: tuple[float, float] = (0.0, 0.0)):
+     w, h = image_shape
+     rows, columns = tiles
+     tile_width = math.ceil(w / columns * (1 + overlap_ratio_wh[0]))
+     tile_height = math.ceil(h / rows * (1 + overlap_ratio_wh[1]))
+     overlap_wh = (math.ceil(tile_width * overlap_ratio_wh[0]), math.ceil(tile_height * overlap_ratio_wh[1]))
+     return (tile_width, tile_height), overlap_wh

+ # Function to handle inference and tiles
+ def detect_objects(image):
+     # Convert PIL image to NumPy array (for OpenCV compatibility)
+     img = np.array(image)  # Gradio image is in PIL format, convert it to NumPy array
+     img_rgb = img  # Keep the image as RGB format, avoid unnecessary conversion to BGR

+     image_shape = (img.shape[1], img.shape[0])

+     # Parameters for slicing (tiles and overlap)
+     tiles = (8, 8)  # Use 8x8 tiles for better detection of small objects
+     overlap_ratio_wh = (0.2, 0.2)  # 20% overlap between tiles for better context
+     slice_wh, overlap_wh = calculate_tile_size(image_shape, tiles, overlap_ratio_wh)

+     # Generate offsets but don't visualize the tiles with rectangles (remove the drawing step)
+     offsets = sv.InferenceSlicer._generate_offset(image_shape, slice_wh, None, overlap_wh)
+     tiled_image = img_rgb.copy()

+     # Save the PIL image to a temporary file for Roboflow model prediction
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

+     # Annotate with Roboflow model predictions using the temporary file path
+     predictions = model.predict(temp_file_path, confidence=40, overlap=30).json()  # Adjusted confidence for small object detection
+     class_count = {}

+     # Define a color palette for different classes
+     color_palette = {
+         "bearbrand": (0, 255, 0),  # Green for class 1
+         "nescafe latte": (0, 0, 255),  # Red for class 2
+         "nescafe original": (255, 0, 0),  # Blue for class 3
+         "nescafe mocha": (0, 255, 255)  # Yellow for class 4
+         # "class_5": (255, 0, 255)  # Magenta for class 5
+         # You can add more colors based on the number of classes you have
+     }

+     # Draw bounding boxes with different colors and label classes
+     for prediction in predictions['predictions']:
+         x1 = int(prediction['x'] - prediction['width'] / 2)
+         y1 = int(prediction['y'] - prediction['height'] / 2)
+         x2 = int(prediction['x'] + prediction['width'] / 2)
+         y2 = int(prediction['y'] + prediction['height'] / 2)

+         class_name = prediction['class']
+
+         # Choose a color for the class, if the class is not in the palette, use white
+         box_color = color_palette.get(class_name, (255, 255, 255))
+
+         # Draw a bounding box around the detected object
+         cv2.rectangle(tiled_image, (x1, y1), (x2, y2), box_color, 2)  # Bounding box with thickness=2
+
+         # Put the class name label on the bounding box
+         cv2.putText(tiled_image, class_name, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, box_color, 2)  # Label
+
+         # Count the class occurrences
        if class_name in class_count:
            class_count[class_name] += 1
        else:
            class_count[class_name] = 1

+     # Create a result text to show class counts
+     result_text = "Object counts per class:\n"
    for class_name, count in class_count.items():
+         result_text += f"{class_name}: {count} objects\n"

+     # Remove the temporary file after processing
    os.remove(temp_file_path)

+     return result_text  # Only return result_text for object counting

+ # Gradio Interface
iface = gr.Interface(
+     fn=detect_objects,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Textbox(),  # Only output the text with object counts
+     live=True
)

+ # Launch Gradio app
+ iface.launch(debug=True)
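
Note on the added slicing helpers: `calculate_tile_size` and `sv.InferenceSlicer._generate_offset` compute the tile geometry, but `detect_objects` still runs a single prediction on the full image. If true SAHI-style tiled inference is the goal, supervision's public `InferenceSlicer` can drive the Roboflow model per tile and merge the results. A rough sketch under these assumptions: the installed supervision release accepts the `slice_wh` and `overlap_ratio_wh` keyword arguments, `sv.Detections.from_inference` accepts the Roboflow JSON response (older releases call this `from_roboflow`), `roboflow_callback` and the temporary slice path are illustrative rather than part of this commit, and `model` and `image` refer to the objects already defined in app.py:

import cv2
import numpy as np
import supervision as sv

def roboflow_callback(image_slice: np.ndarray) -> sv.Detections:
    # Write the RGB slice to disk and run the hosted Roboflow model on it.
    slice_path = "/tmp/_slice.jpg"  # illustrative temporary path
    cv2.imwrite(slice_path, cv2.cvtColor(image_slice, cv2.COLOR_RGB2BGR))
    result = model.predict(slice_path, confidence=40, overlap=30).json()
    return sv.Detections.from_inference(result)

# For a 1280x720 input, calculate_tile_size((1280, 720), (8, 8), (0.2, 0.2))
# returns ((192, 108), (39, 22)); the same tiling expressed via the public API:
slicer = sv.InferenceSlicer(
    callback=roboflow_callback,
    slice_wh=(192, 108),
    overlap_ratio_wh=(0.2, 0.2),
)
detections = slicer(image=np.array(image))  # merged detections across all tiles

The slicer handles the per-tile offsets and the merging of overlapping boxes internally, which is what the removed `apply_nms` and the unused `offsets` variable were approximating by hand.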