# app.py by muhammadsalmanalfaridzi (commit b83e3be, "Update app.py", 3.13 kB).
import os
import tempfile
from collections import Counter

import gradio as gr
import numpy as np
from PIL import Image
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import visualize_object_predictions
# Detection model configuration, loaded once at import time through SAHI.
model_path = "best.pt"  # path to the local YOLO weights file
confidence_threshold = 0.6  # confidence threshold passed to the SAHI model


def _pick_device():
    """Return "cuda" when PyTorch reports a usable GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


# A hard-coded "cuda" crashed the app on CPU-only hosts. The device is now
# auto-detected, and the SAHI_DEVICE environment variable can force a specific
# one (e.g. "cpu" or "cuda:1").
sahi_device = os.environ.get("SAHI_DEVICE") or _pick_device()

# Load the YOLO weights as a SAHI detection model.
sahi_model = AutoDetectionModel.from_pretrained(
    # NOTE(review): confirm the installed SAHI version accepts "yolov11" as a
    # model_type for these weights; adjust it if a different YOLO family is used.
    model_type="yolov11",
    model_path=model_path,
    confidence_threshold=confidence_threshold,
    device=sahi_device,
)
def _format_counts(object_predictions):
    """Build the per-class count report shown in the "Counting Object" box.

    Args:
        object_predictions: iterable of SAHI object predictions; only
            ``prediction.category.name`` is read from each item.

    Returns:
        The header "Detected Objects:" and a blank line, then one
        "<class>: <count>" line per class in first-seen order, then a blank
        line and "Total Objects: <n>".
    """
    counts = Counter(p.category.name for p in object_predictions)
    per_class = "".join(f"{name}: {count}\n" for name, count in counts.items())
    total = sum(counts.values())
    return f"Detected Objects:\n\n{per_class}\nTotal Objects: {total}"


# Object detection with SAHI sliced inference.
def detect_objects(image):
    """Run sliced SAHI detection on *image* and count detections per class.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when
            "Detect" is clicked without an uploaded image.

    Returns:
        A ``(image, text)`` pair for the two Gradio outputs. On success the
        image is a numpy array with the predictions drawn on it. On failure it
        is the unmodified input image (``None`` if there was no input) and the
        text describes the problem.
    """
    if image is None:
        return None, "Please upload an image first."

    try:
        results = get_sliced_prediction(
            image=image,
            detection_model=sahi_model,
            slice_height=512,  # slice size, tune to the object scale
            slice_width=512,
            overlap_height_ratio=0.2,
            overlap_width_ratio=0.2,
        )
        result_text = _format_counts(results.object_prediction_list)
        # SAHI's PredictionResult has no .save(), so the old call always raised.
        # Draw the boxes in memory instead. This also avoids a fixed /tmp path
        # that concurrent requests would overwrite.
        annotated = visualize_object_predictions(
            image=np.array(image),
            object_prediction_list=results.object_prediction_list,
        )["image"]
    except Exception as err:
        # UI boundary: report the error and show the original upload.
        return image, f"An error occurred: {err}"

    return annotated, result_text
# Gradio UI: one row of three columns (input image, annotated output image,
# per-class count text), followed by the Detect button.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
        with gr.Column():
            output_image = gr.Image(label="Detect Object")
        with gr.Column():
            output_text = gr.Textbox(label="Counting Object")

    detect_button = gr.Button("Detect")
    # Clicking calls detect_objects(input) and fills both output components.
    detect_button.click(
        fn=detect_objects,
        inputs=[input_image],
        outputs=[output_image, output_text],
    )

# Start the Gradio app.
iface.launch()