muhammadsalmanalfaridzi committed (verified)
Commit af6e415 · 1 Parent(s): 1f35d81

Update app.py

Files changed (1): app.py  +43 −106
app.py CHANGED
@@ -1,115 +1,52 @@
  import gradio as gr
- import tempfile
- import os
  import cv2
- from sahi import AutoDetectionModel
- from sahi.predict import get_sliced_prediction
  from inference import get_roboflow_model
- from PIL import Image
- import numpy as np
-
- # Initialize the detection model used with SAHI
- model_path = get_roboflow_model(model_id="nescafe-4base/46", api_key="Otg64Ra6wNOgDyjuhMYU")
- confidence_threshold = 0.6  # Confidence threshold
- sahi_device = 'cuda'  # Change to 'cpu' if not using a GPU
-
- # Load the YOLO model through SAHI
- sahi_model = AutoDetectionModel.from_pretrained(
-     model_type="yolov11",  # YOLO model type; adjust if a different YOLO model is used
-     model_path=model_path,
-     confidence_threshold=confidence_threshold,
-     device=sahi_device
- )
-
- # Object detection function using SAHI
- def detect_objects(image):
-     # Save the uploaded image to a temporary file
-     with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
-         image.save(temp_file, format="JPEG")
-         temp_file_path = temp_file.name
-
-     try:
-         # Run sliced prediction on the image with SAHI
-         results = get_sliced_prediction(
-             image=image,
-             detection_model=sahi_model,
-             slice_height=512,  # Slice size (adjust as needed)
-             slice_width=512,
-             overlap_height_ratio=0.2,
-             overlap_width_ratio=0.2
-         )
-
-         # Count the number of objects per class
-         class_count = {}
-         total_count = 0  # Total number of detected objects
-
-         # Draw bounding boxes on the image
-         output_image = np.array(image)  # Convert PIL Image to numpy array for OpenCV processing

-         for prediction in results.object_prediction_list:
-             bbox = prediction.bbox
-             class_name = prediction.category.name  # Object class name
-             confidence = prediction.score.value  # Prediction score
-
-             # Only draw the bounding box if the confidence score exceeds the threshold
-             if confidence >= confidence_threshold:
-                 # Draw the bounding box
-                 cv2.rectangle(output_image,
-                               (int(bbox.minx), int(bbox.miny)),
-                               (int(bbox.maxx), int(bbox.maxy)),
-                               (0, 255, 0), 2)  # Green box

-                 # Draw the label and score
-                 cv2.putText(output_image,
-                             f"{class_name} {confidence:.2f}",
-                             (int(bbox.minx), int(bbox.miny) - 10),
-                             cv2.FONT_HERSHEY_SIMPLEX, 0.9,
-                             (0, 255, 0), 2)

-                 # Count objects per class
-                 class_count[class_name] = class_count.get(class_name, 0) + 1
-                 total_count += 1  # Increment the total object count

-         # Build the result text from the counts
-         result_text = "Detected Objects:\n\n"
-         for class_name, count in class_count.items():
-             result_text += f"{class_name}: {count}\n"
-         result_text += f"\nTotal Objects: {total_count}"
-
-         # Convert output_image (numpy array) back to a PIL Image so it can be saved
-         output_image_pil = Image.fromarray(output_image)
-         output_image_path = "/tmp/prediction.jpg"
-         output_image_pil.save(output_image_path)  # Save the image with the predictions drawn on it
-
-     except Exception as err:
-         # Handle any other error
-         result_text = f"An error occurred: {err}"
-         output_image_path = temp_file_path  # Return the original image if an error occurs
-
-     # Remove the temporary file after prediction
-     os.remove(temp_file_path)
-
-     return output_image_path, result_text
-
- # Build the Gradio interface with a flexible layout
- with gr.Blocks() as iface:
-     with gr.Row():
-         with gr.Column():
-             input_image = gr.Image(type="pil", label="Input Image")
-         with gr.Column():
-             output_image = gr.Image(label="Detect Object")
-         with gr.Column():
-             output_text = gr.Textbox(label="Counting Object")
-
-     # Button to process the input
-     detect_button = gr.Button("Detect")
-
-     # Connect the button to the detection function
-     detect_button.click(
-         fn=detect_objects,
-         inputs=input_image,
-         outputs=[output_image, output_text]
-     )

- # Run the interface
  iface.launch()
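For context, the removed implementation is built on SAHI's sliced-prediction API. The sketch below is a minimal standalone version of that flow, not the committed code: it assumes a local YOLOv8 weights file and model_type="yolov8" (the commit instead passes the object returned by get_roboflow_model and model_type="yolov11", which not every SAHI release accepts), while the 512x512 tiles with 20% overlap mirror the settings above.

from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Hypothetical local weights file; adjust model_type/model_path to your setup.
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",
    model_path="weights/best.pt",
    confidence_threshold=0.6,
    device="cpu",  # or "cuda"
)

# Tile the image into 512x512 slices with 20% overlap, run the detector on
# each slice, and merge the per-slice predictions back into full-image boxes.
result = get_sliced_prediction(
    "sample.jpg",
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)

for pred in result.object_prediction_list:
    print(pred.category.name, round(pred.score.value, 2),
          (pred.bbox.minx, pred.bbox.miny, pred.bbox.maxx, pred.bbox.maxy))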
 
  import gradio as gr
+ import supervision as sv
+ import numpy as np
  import cv2
  from inference import get_roboflow_model
+
+ # Define the Roboflow model
+ model = get_roboflow_model(model_id="people-detection-general/5", api_key="API_KEY")
+
+ def callback(image_slice: np.ndarray) -> sv.Detections:
+     results = model.infer(image_slice)[0]
+     return sv.Detections.from_inference(results)
+
+ # Define the slicer
+ slicer = sv.InferenceSlicer(callback=callback)
+
+ def detect_objects(image):
+     image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)  # Convert from RGB (Gradio) to BGR (OpenCV)
+
+     # Run inference
+     sliced_detections = slicer(image=image)
+
+     # Annotating the image with boxes and labels
+     label_annotator = sv.LabelAnnotator()
+     box_annotator = sv.BoxAnnotator()
+
+     annotated_image = box_annotator.annotate(scene=image.copy(), detections=sliced_detections)
+     annotated_image = label_annotator.annotate(scene=annotated_image, detections=sliced_detections)
+
+     # Count detected objects per class
+     class_counts = {}
+     for detection in sliced_detections:
+         class_name = detection.class_name
+         class_counts[class_name] = class_counts.get(class_name, 0) + 1
+
+     # Total objects detected
+     total_count = sum(class_counts.values())
+
+     # Display results: annotated image and object counts
+     result_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)  # Convert back to RGB for Gradio
+     return result_image, class_counts, total_count
+
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=detect_objects,
+     inputs=gr.Image(type="pil"),
+     outputs=[gr.Image(type="pil"), gr.JSON(), gr.Number(label="Total Objects Detected")],
+     live=True
+ )
+
+ # Launch the Gradio interface
  iface.launch()
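One note on the new per-class counting loop: in current supervision releases, iterating an sv.Detections object yields plain tuples rather than objects with a class_name attribute, so a more defensive way to build the counts is to read the class names that Detections.from_inference typically stores in the data dictionary. The sketch below is a minimal alternative under that assumption; the helper name count_per_class is mine, not part of the commit.

from collections import Counter

import supervision as sv

def count_per_class(detections: sv.Detections) -> dict:
    # Prefer the human-readable class names that Detections.from_inference
    # usually stores under detections.data["class_name"]; fall back to the
    # numeric class_id array if that key is absent.
    names = detections.data.get("class_name")
    if names is None:
        names = [str(class_id) for class_id in detections.class_id]
    return dict(Counter(list(names)))

# Usage inside detect_objects, replacing the manual loop:
#   class_counts = count_per_class(sliced_detections)
#   total_count = len(sliced_detections)

If the 512x512 tiling of the removed version matters, sv.InferenceSlicer also appears to accept explicit tile settings (an assumption about the current supervision API), e.g. sv.InferenceSlicer(callback=callback, slice_wh=(512, 512), overlap_ratio_wh=(0.2, 0.2)).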