File size: 3,445 Bytes
d4c3acc
 
246f207
ceb95cf
 
 
c21a752
648d371
 
 
 
 
b842d10
e3205c6
 
b842d10
 
 
 
 
648d371
ceb95cf
 
 
 
 
 
 
e3205c6
ceb95cf
 
d5aac08
aca9f11
 
 
 
 
 
e3205c6
ceb95cf
b6fa5d6
e3205c6
ceb95cf
 
 
 
 
6564a3a
b842d10
 
cfd90f9
e3205c6
 
 
 
 
 
 
 
 
 
cfd90f9
 
 
 
e3205c6
 
 
cfd90f9
e3205c6
b3747be
 
5769b69
4ce9fc8
e3205c6
 
5769b69
 
 
 
 
 
e3205c6
b842d10
e14364d
aca9f11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image

# Load the DETR (ResNet-50 backbone) checkpoint once at import time: its
# configuration, the model weights built from that configuration, and the
# matching image processor. These are module-level names.
config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50",
    config=config,
)
image_processor = DetrImageProcessor.from_pretrained(
    "facebook/detr-resnet-50",
)

def load_model(threshold):
    """Build an object-detection pipeline around the module-level DETR model.

    Args:
        threshold: Default minimum confidence score (0-1) a detection needs in
            order to be returned. It is passed to the pipeline, which applies
            it during post-processing unless a call overrides it. It is also
            recorded as ``model.config.threshold``.

    Returns:
        A ``transformers`` object-detection pipeline.
    """
    # Reuse the model and image processor already loaded at import time.
    # Instantiating a fresh copy of the weights on every call is slow and
    # doubles memory use.
    # The object-detection pipeline takes its score cut-off from the
    # ``threshold`` argument, not from the config. The config attribute is kept
    # only for code that inspects it.
    model.config.threshold = threshold
    return pipeline(
        task='object-detection',
        model=model,
        image_processor=image_processor,
        threshold=threshold,
    )

# Shared detection pipeline, built once with the default 0.5 score threshold.
od_pipe = load_model(threshold=0.5)

def draw_detections(image, detections, *, box_color=(0, 255, 0),
                    text_color=(255, 255, 255), font_scale=1.5,
                    box_thickness=2, text_thickness=4):
    """Draw labelled bounding boxes onto a copy of ``image``.

    Args:
        image: RGB PIL image (or anything ``np.array`` converts to an RGB
            pixel array).
        detections: Iterable of dicts with ``"score"``, ``"label"`` and
            ``"box"`` keys. ``"box"`` holds ``xmin``/``ymin``/``xmax``/``ymax``
            pixel coordinates, as produced by the object-detection pipeline.
        box_color: RGB colour of the rectangle outline.
        text_color: RGB colour of the caption text.
        font_scale: OpenCV font scale for the caption.
        box_thickness: Line thickness of the rectangle, in pixels.
        text_thickness: Stroke thickness of the caption text, in pixels.

    Returns:
        A new PIL image with each box outlined and captioned with
        ``"<label> <score>"`` (score to two decimals).
    """
    # np.array copies the pixels, so the caller's image is never modified.
    # The canvas stays in RGB, so the colours above are interpreted as RGB.
    canvas = np.array(image)
    font = cv2.FONT_HERSHEY_SIMPLEX

    for detection in detections:
        box = detection['box']
        # OpenCV drawing functions require integer pixel coordinates.
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), box_color, box_thickness)

        label_text = f"{detection['label']} {detection['score']:.2f}"
        (_, text_height), _ = cv2.getTextSize(label_text, font, font_scale, text_thickness)
        # Put the caption just above the box. If the box touches the top edge,
        # clamp it down so the text stays inside the image.
        text_y = max(y_min - 10, text_height)
        cv2.putText(canvas, label_text, (x_min, text_y), font, font_scale,
                    text_color, text_thickness)

    return Image.fromarray(canvas)

def _to_rgb_pil(image):
    """Return ``image`` as an RGB PIL image, converting arrays and other modes."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image).astype('uint8'))
    # Grayscale or RGBA input would otherwise break the 3-channel drawing code.
    return image if image.mode == 'RGB' else image.convert('RGB')

def get_pipeline_prediction(threshold, pil_image):
    """Run object detection on an image and return an annotated copy.

    Args:
        threshold: Minimum confidence score (0-1) a detection needs in order to
            be returned. It is forwarded to the pipeline call.
        pil_image: Input image, either a PIL image or an array-like that
            ``np.asarray`` can turn into uint8 pixels.

    Returns:
        ``(annotated_image, detections)`` on success, where ``detections`` is
        the list returned by ``od_pipe``. On failure it returns
        ``(pil_image, {"error": message})``, and the message is also printed.
    """
    if pil_image is None:
        return None, {"error": "An error occurred: no input image was provided"}
    try:
        image = _to_rgb_pil(pil_image)
        # Apply the score cut-off per call instead of rebuilding the model
        # whenever the slider moves.
        pipeline_output = od_pipe(image, threshold=float(threshold))
        return draw_detections(image, pipeline_output), pipeline_output
    except Exception as e:
        # UI boundary: report the failure in the JSON panel rather than crash
        # the Gradio handler.
        error_message = f"An error occurred: {e}"
        print(error_message)
        return pil_image, {"error": error_message}

# Gradio interface: image input and score threshold on the left; annotated
# image and raw detection JSON (in tabs) on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # type="pil" hands get_pipeline_prediction a PIL image directly.
            inp_image = gr.Image(label="Input image", type="pil")
            # Higher values keep only more confident detections (fewer boxes).
            slider = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.05,
                value=0.5,
                label="Detection confidence threshold",
            )
            gr.Markdown(
                "Detections scoring below the threshold are hidden; "
                "lower it to see more objects."
            )
            btn_run = gr.Button('Run Detection')
        with gr.Column():
            with gr.Tab("Annotated Image"):
                out_image = gr.Image(label="Annotated image")
            with gr.Tab("Detection Results"):
                out_json = gr.JSON()

    btn_run.click(get_pipeline_prediction, inputs=[slider, inp_image], outputs=[out_image, out_json])

# Launch only when run as a script (`python app.py`), so importing the module
# does not start a server.
if __name__ == "__main__":
    demo.launch()