import cv2
import gradio as gr
import numpy as np
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation



class Counter:
    """Frame counter shared across all streaming callbacks (module-level state)."""

    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count


counter = Counter()



# Fine-tuned SegFormer (MiT-B0 backbone) with two classes; 255 is the
# ignore index used by reduce_labels for the background.
weights2load = 'segformer_ep15_loss0.00.pth'
id2label = {0: 'seal', 255: 'bck'}
label2id = {'seal': 0, 'bck': 255}
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
image_processor = SegformerImageProcessor(reduce_labels=True)

# Load the fine-tuned weights and run inference on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(weights2load, weights_only=True, map_location=device))
model.to(device).eval()
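
# Optional sanity check (editor's sketch, not part of the original app): push
# one dummy frame through the processor/model to confirm the logits shape
# before going live; SegFormer logits come out at 1/4 of the input resolution.
# with torch.no_grad():
#     _pv = image_processor(np.zeros((512, 512, 3), dtype=np.uint8),
#                           return_tensors="pt").pixel_values.to(device)
#     print(model(pixel_values=_pv).logits.shape)  # expected: torch.Size([1, 2, 128, 128])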


def segment(im):
    """Overlay the seal-class response on an incoming webcam frame.

    To keep the stream responsive, inference runs on every third frame;
    the remaining frames pass through unchanged.
    """
    imout = im.copy()
    if counter.increment() % 3 == 0:
        with torch.no_grad():
            pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
            logits = model(pixel_values=pixel_values).logits.cpu().numpy() ** 2
        # Normalise the 'seal' channel to [0, 1] and upsample it back to the
        # frame resolution (the logits are at 1/4 of the input size).
        logits_n = (logits[0, 0] - logits[0, 0].min()) / (logits[0, 0].max() - logits[0, 0].min())
        logits_n = cv2.resize(logits_n, (im.shape[1], im.shape[0]))
        # Brighten the green channel where the model responds.
        imout[..., 1] = np.clip(imout[..., 1] + logits_n * 200, 0, 254)
    return imout, counter.count
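
# Offline smoke test (editor's sketch, not in the original app): run segment()
# on a synthetic frame to check the output shape and counter without a webcam.
# _frame = np.zeros((480, 640, 3), dtype=np.uint8)
# _out, _n = segment(_frame)
# print(_out.shape, _n)  # -> (480, 640, 3) 1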


# Live webcam demo: stream frames through segment() and display the overlay
# alongside the running frame count.
demo = gr.Interface(
    segment,
    [gr.Image(sources=["webcam"], streaming=True)],
    [gr.Image(), gr.Number()],
    css=".output-image, .input-image, .image-preview {height: 400px !important}",
    live=True,
)
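
# If inference falls behind the frame rate, enabling Gradio's queue can smooth
# the stream (an optional tweak, not used in the original script):
# demo.queue()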
# Alternative wiring with gr.Blocks, commented out (segment() returns two
# values, so the stream needs two outputs):
# with gr.Blocks() as demo:
#     inp = gr.Image(sources=["webcam"], streaming=True)
#     out = gr.Image()
#     num = gr.Number()
#     inp.stream(segment, inputs=inp, outputs=[out, num])
    
if __name__ == "__main__":
    demo.launch()

# Alternative front end using the gradio_webrtc package (lower-latency
# streaming). Note: this sketch streams only the frame back, so it assumes an
# image-only variant of segment().
# from gradio_webrtc import WebRTC
#
# css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
#          .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
#
# with gr.Blocks(css=css) as demo:
#     with gr.Column(elem_classes=["my-column"]):
#         with gr.Group(elem_classes=["my-group"]):
#             image = WebRTC(label="Stream")
#         image.stream(fn=segment, inputs=[image], outputs=[image])
