# Source: HuggingFace Space app.py by muhammadsalmanalfaridzi
# (revision 83eff9f, verified; file size 2.75 kB)
# NOTE: the lines above were scraped page chrome ("raw / history / blame");
# they are preserved here as comments so the file is valid Python.
import os

import cv2
import gradio as gr
import numpy as np
import supervision as sv

from inference import get_roboflow_model
# Roboflow model configuration.
# SECURITY NOTE: the API key below is committed to source control — it is
# effectively leaked and should be rotated.  Prefer supplying it via the
# ROBOFLOW_API_KEY environment variable; the literal is kept only as a
# backward-compatible fallback so existing deployments keep working.
model_id = "nescafe-4base/46"  # Roboflow model ID in "project/version" form
api_key = os.environ.get("ROBOFLOW_API_KEY", "Otg64Ra6wNOgDyjuhMYU")

# Load the Roboflow model once at startup so every request reuses it.
model = get_roboflow_model(model_id=model_id, api_key=api_key)
# Callback invoked by sv.InferenceSlicer on each image slice.
def callback(image_slice: np.ndarray) -> sv.Detections:
    """Run model inference on one SAHI slice and return merged detections.

    Parameters
    ----------
    image_slice : np.ndarray
        BGR image crop produced by ``sv.InferenceSlicer``.

    Returns
    -------
    sv.Detections
        All detections for this slice; ``sv.Detections.empty()`` when the
        model returns nothing.
    """
    results = model.infer(image_slice)

    # Some inference wrappers return a (predictions, metadata) tuple.
    if isinstance(results, tuple):
        results = results[0]

    # Normalise single responses and batched lists onto one code path.
    if not isinstance(results, list):
        results = [results]

    # BUG FIX: the original extended a plain Python list with the *rows* of
    # each Detections object and returned that list; InferenceSlicer
    # requires a single sv.Detections, so convert each result and merge.
    per_result = [sv.Detections.from_inference(result) for result in results]
    if not per_result:
        return sv.Detections.empty()
    return sv.Detections.merge(per_result)
# Initialize the SAHI Inference Slicer
# Splits large images into overlapping tiles, runs `callback` on each tile,
# and merges the per-tile detections (default slice size / overlap are used).
slicer = sv.InferenceSlicer(callback=callback)
# Function to handle image processing, inference, and annotation
def process_image(image):
    """Run sliced (SAHI) inference on *image* and return annotated results.

    Parameters
    ----------
    image : PIL.Image.Image
        RGB image uploaded through the Gradio UI.

    Returns
    -------
    tuple
        ``(annotated_rgb_ndarray, per_class_count_dict, total_detections)``.
    """
    # Convert PIL (RGB) to the OpenCV BGR layout used by the pipeline.
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Slice the image, run inference per slice, and merge the detections.
    sliced_detections = slicer(image=frame)

    # Draw bounding boxes first, then labels on top of them.
    annotated = sv.BoxAnnotator().annotate(
        scene=frame.copy(), detections=sliced_detections
    )
    annotated = sv.LabelAnnotator().annotate(
        scene=annotated, detections=sliced_detections
    )

    # Convert back to RGB for display in Gradio.
    result_image = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

    # BUG FIX: iterating sv.Detections yields plain tuples, which have no
    # ``.class_name`` attribute, so the original loop raised AttributeError.
    # Class names live in ``Detections.data["class_name"]`` (populated by
    # ``from_inference``); fall back to numeric class ids if absent.
    names = None
    if getattr(sliced_detections, "data", None):
        names = sliced_detections.data.get("class_name")
    if names is None:
        names = sliced_detections.class_id
    class_count = {}
    for name in map(str, names if names is not None else []):
        class_count[name] = class_count.get(name, 0) + 1
    total_count = len(sliced_detections)

    return result_image, class_count, total_count
# Gradio UI wiring: upload an image, receive the annotated image plus
# per-class and total object counts.  ``live=True`` re-runs inference on
# every input change.
_image_input = gr.Image(type="pil", label="Upload Image")
_result_outputs = [
    gr.Image(type="pil", label="Annotated Image"),
    gr.JSON(label="Object Count"),
    gr.Number(label="Total Objects Detected"),
]

iface = gr.Interface(
    fn=process_image,
    inputs=_image_input,
    outputs=_result_outputs,
    live=True,
)

# Start the Gradio web server.
iface.launch()