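"""Gradio demo: stream webcam frames over WebRTC through a fine-tuned
Segformer (nvidia/mit-b0 backbone) that separates seals from background."""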
import gradio as gr
import torch
from gradio_webrtc import WebRTC
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# Fine-tuned checkpoint, expected next to this script.
weights2load = 'segformer_ep15_loss0.00.pth'

# Binary task: class 0 = seal, 255 = background.
id2label = {0: 'seal', 255: 'bck'}
label2id = {'seal': 0, 'bck': 255}

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
# reduce_labels presumably mirrors the training-time preprocessing.
image_processor = SegformerImageProcessor(reduce_labels=True)

# Load the fine-tuned weights and switch to inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(weights2load, weights_only=True, map_location=device))
model.to(device).eval()


def segment(im):
    """Run one frame through the model and return a [0, 1] seal-confidence map."""
    pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    # Logits come out at 1/4 of the input resolution; channel 0 is the 'seal' class.
    # Squaring sharpens the contrast of the map before normalization.
    logits = outputs.logits.cpu().numpy() ** 2
    seal = logits[0, 0]
    # Min-max normalize; the epsilon guards against division by zero on a flat map.
    return (seal - seal.min()) / (seal.max() - seal.min() + 1e-8)
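
# Quick local smoke test (a sketch, not used by the app; "frame.jpg" is a
# hypothetical sample image):
#   from PIL import Image
#   mask = segment(Image.open("frame.jpg"))
#   print(mask.shape, mask.min(), mask.max())  # values should land in [0, 1]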

css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
         .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
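
# Deployment note: behind restrictive NATs (e.g. on Hugging Face Spaces) WebRTC
# usually needs a TURN server. A sketch, with hypothetical credentials:
#   rtc_config = {"iceServers": [{"urls": "turn:turn.example.com:3478",
#                                 "username": "user", "credential": "pass"}]}
#   image = WebRTC(label="Stream", rtc_configuration=rtc_config)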

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            image = WebRTC(label="Stream")
        # time_limit caps each client's streaming session (seconds).
        image.stream(fn=segment, inputs=[image], outputs=[image], time_limit=10)

if __name__ == "__main__":
    demo.launch()