muhammadsalmanalfaridzi commited on
Commit
b83e3be
·
verified ·
1 Parent(s): cd1aca7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -73
app.py CHANGED
@@ -1,26 +1,24 @@
1
  import gradio as gr
2
- from dotenv import load_dotenv
3
- from roboflow import Roboflow
4
  import tempfile
5
  import os
6
- import requests
7
- import numpy as np # Import numpy to handle image slices
8
- from sahi.predict import get_sliced_prediction # SAHI slicing inference
9
- import supervision as sv # For annotating images with results
10
-
11
- # Muat variabel lingkungan dari file .env
12
- load_dotenv()
13
- api_key = os.getenv("ROBOFLOW_API_KEY")
14
- workspace = os.getenv("ROBOFLOW_WORKSPACE")
15
- project_name = os.getenv("ROBOFLOW_PROJECT")
16
- model_version = int(os.getenv("ROBOFLOW_MODEL_VERSION"))
17
-
18
- # Inisialisasi Roboflow menggunakan data yang diambil dari secrets
19
- rf = Roboflow(api_key=api_key)
20
- project = rf.workspace(workspace).project(project_name)
21
- model = project.version(model_version).model
22
-
23
- # Fungsi untuk menangani input dan output gambar
24
  def detect_objects(image):
25
  # Simpan gambar yang diupload sebagai file sementara
26
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
@@ -28,85 +26,64 @@ def detect_objects(image):
28
  temp_file_path = temp_file.name
29
 
30
  try:
31
- # Perform sliced inference with SAHI using InferenceSlicer
32
- def callback(image_slice: np.ndarray) -> sv.Detections:
33
- results = model.infer(image_slice)[0] # Perform inference on each slice
34
- return sv.Detections.from_inference(results)
35
-
36
- # Configure the SAHI Slicer with specific slice dimensions and overlap
37
- slicer = sv.InferenceSlicer(
38
- callback=callback,
39
- slice_wh=(320, 320), # Adjust slice dimensions as needed
40
- overlap_wh=(64, 64), # Adjust overlap in pixels (DO NOT use overlap_ratio_wh here)
41
- overlap_filter=sv.OverlapFilter.NON_MAX_SUPPRESSION, # Filter overlapping detections
42
- iou_threshold=0.5, # Intersection over Union threshold for NMS
43
  )
44
 
45
- # Run slicing-based inference
46
- detections = slicer(image)
47
-
48
- # Annotate the results on the image
49
- box_annotator = sv.BoxAnnotator()
50
- label_annotator = sv.LabelAnnotator()
51
-
52
- annotated_image = box_annotator.annotate(
53
- scene=image.copy(), detections=detections)
54
-
55
- annotated_image = label_annotator.annotate(
56
- scene=annotated_image, detections=detections)
57
-
58
- # Save the annotated image
59
- output_image_path = "/tmp/prediction_visual.png"
60
- annotated_image.save(output_image_path)
61
-
62
- # Count the number of detected objects per class
63
  class_count = {}
64
- total_count = 0
65
 
66
- for prediction in detections:
67
- class_name = prediction.class_id # or prediction.class_name if available
68
  class_count[class_name] = class_count.get(class_name, 0) + 1
69
- total_count += 1 # Increment the total object count
70
 
71
- # Create a result text with object counts
72
  result_text = "Detected Objects:\n\n"
73
  for class_name, count in class_count.items():
74
  result_text += f"{class_name}: {count}\n"
75
- result_text += f"\nTotal objects detected: {total_count}"
 
 
 
 
76
 
77
- except requests.exceptions.HTTPError as http_err:
78
- # Handle HTTP errors
79
- result_text = f"HTTP error occurred: {http_err}"
80
- output_image_path = temp_file_path # Return the original image in case of error
81
  except Exception as err:
82
- # Handle other errors
83
  result_text = f"An error occurred: {err}"
84
- output_image_path = temp_file_path # Return the original image in case of error
85
 
86
- # Clean up temporary files
87
  os.remove(temp_file_path)
88
-
89
  return output_image_path, result_text
90
 
91
- # Create the Gradio interface
92
  with gr.Blocks() as iface:
93
  with gr.Row():
94
  with gr.Column():
95
  input_image = gr.Image(type="pil", label="Input Image")
96
  with gr.Column():
97
- output_image = gr.Image(label="Detected Objects")
98
  with gr.Column():
99
- output_text = gr.Textbox(label="Object Count")
100
-
101
- # Button to trigger object detection
102
- detect_button = gr.Button("Detect Objects")
103
-
104
- # Link the button to the detect_objects function
105
  detect_button.click(
106
  fn=detect_objects,
107
  inputs=input_image,
108
  outputs=[output_image, output_text]
109
  )
110
 
111
- # Launch the interface
112
  iface.launch()
 
1
  import gradio as gr
 
 
2
  import tempfile
3
  import os
4
+ from sahi import AutoDetectionModel
5
+ from sahi.predict import get_sliced_prediction
6
+ from PIL import Image
7
+
8
# Initialise the SAHI detection model once at module import (start-up).
import torch  # sahi's YOLO backends require torch, so it is always installed

model_path = "best.pt"  # local YOLO weights — presumably shipped with the app; TODO confirm
confidence_threshold = 0.6  # detections below this confidence are discarded
# BUG FIX: the device was hard-coded to 'cuda', which raises at start-up on
# CPU-only hosts (the common case for free Spaces). Detect GPU availability
# and fall back to CPU automatically.
sahi_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the YOLO model through SAHI's generic wrapper.
# NOTE(review): "yolov11" must be a model_type accepted by the installed sahi
# release (newer versions expose YOLO11 weights under "ultralytics") — verify
# against requirements.txt.
sahi_model = AutoDetectionModel.from_pretrained(
    model_type="yolov11",
    model_path=model_path,
    confidence_threshold=confidence_threshold,
    device=sahi_device,
)
20
+
21
def detect_objects(image):
    """Run sliced (SAHI) object detection on an uploaded image.

    Parameters
    ----------
    image : PIL.Image.Image
        Image from the Gradio input component (``type="pil"``).

    Returns
    -------
    tuple[str, str]
        Path of the image to display (annotated prediction on success, the
        original upload on failure) and a per-class detection summary.
    """
    # Persist the upload so there is a real file to fall back to on errors.
    # Convert to RGB first: JPEG cannot encode RGBA/P-mode uploads.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    try:
        # Sliced inference: the image is split into overlapping 512x512 tiles
        # so small objects survive the model's input downscaling.
        results = get_sliced_prediction(
            image=image,
            detection_model=sahi_model,
            slice_height=512,
            slice_width=512,
            overlap_height_ratio=0.2,
            overlap_width_ratio=0.2,
        )

        # Tally detections per class name.
        class_count = {}
        total_count = 0
        for prediction in results.object_prediction_list:
            class_name = prediction.category.name
            class_count[class_name] = class_count.get(class_name, 0) + 1
            total_count += 1

        result_text = "Detected Objects:\n\n"
        for class_name, count in class_count.items():
            result_text += f"{class_name}: {count}\n"
        result_text += f"\nTotal Objects: {total_count}"

        # BUG FIX: PredictionResult has no .save() method, so the success path
        # always raised AttributeError and fell into the error handler.
        # export_visuals() is the SAHI API that draws the boxes and writes the
        # annotated image to disk as "<file_name>.png".
        results.export_visuals(export_dir="/tmp/", file_name="prediction")
        output_image_path = "/tmp/prediction.png"
    except Exception as err:
        result_text = f"An error occurred: {err}"
        output_image_path = temp_file_path  # show the unmodified upload

    # BUG FIX: the temp file was removed unconditionally, so on the error path
    # the function returned a path that had just been deleted and Gradio could
    # not display anything. Keep the file when it is the one being returned.
    if output_image_path != temp_file_path:
        os.remove(temp_file_path)

    return output_image_path, result_text
67
 
68
# Build the Gradio interface: three side-by-side columns (input image,
# annotated output, per-class counts) plus a button wired to detect_objects.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            output_image = gr.Image(label="Detect Object")
        with gr.Column():
            output_text = gr.Textbox(label="Counting Object")

    # Clicking the button runs detection and fills both outputs.
    detect_button = gr.Button("Detect")
    detect_button.click(
        fn=detect_objects,
        inputs=input_image,
        outputs=[output_image, output_text],
    )

# Start serving the app.
iface.launch()