import gradio as gr
import supervision as sv
import numpy as np
import cv2
from inference import get_roboflow_model

# Replace with your actual Roboflow model ID and API key
model_id = "nescafe-4base/46"
api_key = "YOUR_ROBOFLOW_API_KEY"  # Never commit a real API key to source control

# Load the Roboflow model using the get_roboflow_model function
model = get_roboflow_model(model_id=model_id, api_key=api_key)
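# (Assumption: get_roboflow_model fetches the hosted model weights via the
# Roboflow API and returns an object whose infer() method runs locally.)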

# Define the callback function for the SAHI slicer
def callback(image_slice: np.ndarray) -> sv.Detections:
    # infer() returns a list with one response per input image, so take
    # the first element for this single slice
    results = model.infer(image_slice)[0]

    # Convert the Roboflow response into the sv.Detections object that
    # InferenceSlicer expects the callback to return
    return sv.Detections.from_inference(results)

# Initialize the SAHI Inference Slicer
slicer = sv.InferenceSlicer(callback=callback)
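# NOTE: the slicer above uses supervision's default slice size and overlap.
# For large images with small objects it can help to set the tile size
# explicitly; a sketch, assuming the current supervision API:
# slicer = sv.InferenceSlicer(callback=callback, slice_wh=(640, 640))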

# Function to handle image processing, inference, and annotation
def process_image(image):
    # Convert the PIL image to OpenCV format (BGR)
    image = np.array(image)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Run inference using SAHI (splitting the image into slices)
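    # (the slicer tiles the image, runs the callback on each tile, and
    # merges overlapping detections back into full-image coordinates)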
    sliced_detections = slicer(image=image)

    # Annotate the detections with bounding boxes and labels
    label_annotator = sv.LabelAnnotator()
    box_annotator = sv.BoxAnnotator()
    
    annotated_image = box_annotator.annotate(scene=image.copy(), detections=sliced_detections)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=sliced_detections)

    # Convert the annotated image back to RGB for display in Gradio
    result_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)

    # Count detected objects per class; class names live in the Detections
    # data dictionary rather than on the individual iterated items
    class_count = {}
    for class_name in sliced_detections.data.get("class_name", []):
        key = str(class_name)  # numpy strings -> plain str for JSON output
        class_count[key] = class_count.get(key, 0) + 1

    total_count = len(sliced_detections)

    return result_image, class_count, total_count

# Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[gr.Image(type="numpy", label="Annotated Image"),
             gr.JSON(label="Object Count"),
             gr.Number(label="Total Objects Detected")],
    live=True  # Re-run inference automatically whenever the input changes
)

# Launch the Gradio interface
iface.launch()
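# (To expose the app beyond localhost, launch(share=True) creates a
# temporary public URL; by default the interface is served locally.)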