File size: 3,166 Bytes
d4c3acc
 
246f207
ceb95cf
 
 
c21a752
6ed6f8e
 
 
 
648d371
b842d10
6ed6f8e
e3205c6
6ed6f8e
b842d10
6ed6f8e
 
b842d10
6ed6f8e
 
648d371
ceb95cf
 
 
 
 
 
d5aac08
aca9f11
6ed6f8e
 
793cc29
ceb95cf
793cc29
 
 
ceb95cf
793cc29
ceb95cf
793cc29
 
6564a3a
b842d10
 
6ed6f8e
cfd90f9
e3205c6
 
6ed6f8e
 
 
cfd90f9
6ed6f8e
cfd90f9
b3747be
 
5769b69
6ed6f8e
 
 
 
5769b69
 
6ed6f8e
5769b69
6ed6f8e
e3205c6
6ed6f8e
e14364d
aca9f11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image

# Hub checkpoint shared by the configuration, the model weights and the image processor
_DETR_CHECKPOINT = "facebook/detr-resnet-50"

# Load everything once at start-up; no detection threshold is applied at this stage
base_config = DetrConfig.from_pretrained(_DETR_CHECKPOINT)
base_model = DetrForObjectDetection.from_pretrained(
    _DETR_CHECKPOINT,
    config=base_config,
)
base_processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)

def load_model(threshold):
    """Build an object-detection pipeline that filters detections at ``threshold``.

    Args:
        threshold: Minimum confidence score (0.0-1.0) a detection must reach
            to be included in the pipeline output.

    Returns:
        A ``transformers`` object-detection pipeline wrapping the pre-loaded
        DETR model (``base_model``) and image processor (``base_processor``).
    """
    # The score cut-off is a post-processing parameter of the object-detection
    # pipeline, not a model-configuration field, so it is handed to the
    # pipeline directly instead of being written into a new DetrConfig.
    # Reusing the module-level model also avoids re-instantiating the weights
    # every time the threshold changes.
    return pipeline(
        task='object-detection',
        model=base_model,
        image_processor=base_processor,
        threshold=threshold,
    )

# Module-level pipeline used by the prediction handler; created up front with
# an initial threshold of 0.25
od_pipe = load_model(threshold=0.25)

def draw_detections(image, detections):
    """Draw bounding boxes and ``"<label> <score>"`` captions onto a copy of ``image``.

    Args:
        image: RGB image (PIL image or array-like convertible by ``np.array``).
        detections: Iterable of dicts, each with a ``'score'`` (float), a
            ``'label'`` (str) and a ``'box'`` dict holding ``'xmin'``,
            ``'ymin'``, ``'xmax'`` and ``'ymax'`` pixel coordinates.

    Returns:
        A new ``PIL.Image.Image`` with the annotations drawn on it. The input
        image is not modified.
    """
    np_image = np.array(image)
    np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.5
    text_thickness = 4
    label_gap = 10  # vertical gap in pixels between the box edge and the caption

    for detection in detections:
        score = detection['score']
        label = detection['label']
        box = detection['box']
        # OpenCV drawing functions require integer pixel coordinates
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])

        cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        label_text = f'{label} {score:.2f}'
        (_, text_height), _ = cv2.getTextSize(label_text, font, font_scale, text_thickness)

        # putText anchors the text at its bottom-left corner. Prefer placing the
        # caption above the box, but when the box touches the top of the image
        # there is no room there, so draw it just inside the box instead.
        text_y = y_min - label_gap
        if text_y - text_height < 0:
            text_y = y_min + text_height + label_gap
        text_x = max(x_min, 0)

        cv2.putText(np_image, label_text, (text_x, text_y), font, font_scale, (255, 255, 255), text_thickness)

    # Convert BGR back to RGB for display
    final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(final_image)

def get_pipeline_prediction(threshold, pil_image):
    """Run object detection on an image and return the annotated image plus raw results.

    Args:
        threshold: Minimum confidence score (0.0-1.0) for a detection to be kept.
        pil_image: The input image, either a ``PIL.Image.Image`` or an
            array-like of pixel values (converted to an RGB PIL image here).

    Returns:
        A ``(image, data)`` tuple. On success ``image`` is the annotated PIL
        image and ``data`` is the list of detections returned by the pipeline.
        On failure ``image`` is the input as received (or ``None`` when no
        image was supplied) and ``data`` is ``{"error": <message>}``.
    """
    if pil_image is None:
        # Nothing was uploaded; report that instead of failing inside the conversion
        return None, {"error": "No image provided."}

    try:
        if not isinstance(pil_image, Image.Image):
            # Let Pillow infer the mode from the array, then normalise to RGB so
            # greyscale / RGBA inputs are accepted as well
            pil_image = Image.fromarray(np.array(pil_image).astype('uint8')).convert('RGB')
        # The threshold is a per-call post-processing parameter of the pipeline,
        # so the already-loaded pipeline is reused rather than rebuilding the
        # model for every request.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result)
        return processed_image, result
    except Exception as e:
        # UI boundary: surface the failure in the JSON panel instead of crashing the app
        return pil_image, {"error": str(e)}

# Two-column layout: inputs (image, threshold, button) on the left,
# outputs (annotated image and raw detection JSON, one tab each) on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Object Detection")
            # type="pil" makes Gradio hand the handler a PIL image directly
            inp_image = gr.Image(label="Upload your image here", type="pil")
            threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.25, label="Detection Threshold")
            run_button = gr.Button("Detect Objects")
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()

    # Input order must match the handler signature: (threshold, image)
    run_button.click(get_pipeline_prediction, inputs=[threshold_slider, inp_image], outputs=[output_image, output_data])

# Start the web server only when run as a script, so importing this module
# (e.g. for tests or reuse of `demo`) does not block on the server.
if __name__ == "__main__":
    demo.launch()