File size: 1,780 Bytes
1a3844d
b41cba1
 
85f0e91
c596df4
86a4add
42159c7
405fa0d
 
06aac96
 
 
 
eec0220
06aac96
 
f08406b
405fa0d
06aac96
42159c7
c596df4
 
 
 
 
 
 
85f0e91
 
 
 
 
 
06aac96
85f0e91
06aac96
 
 
 
 
 
 
1a3844d
06aac96
 
 
1a3844d
06aac96
 
ff3bcea
06aac96
 
85f0e91
1a3844d
85f0e91
b41cba1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
import numpy as np
from time import sleep
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms



# class Count:
#     def __init__(self):
#         self.n = 0
#         self.imout = np.zeros((1000, 1000))
        
#     def step(self):
#         self.n += 1


# cnt = 0        
# State dict loaded over the pretrained "nvidia/mit-b0" backbone below.
# Relative path: resolved against the current working directory.
weights2load = 'segformer_ep15_loss0.00.pth'
# NOTE(review): num_labels=2 gives the head two output channels (0 and 1), but
# these ids are 0 and 255, so label 255 names no channel. This only affects the
# config's label metadata, not the logits; confirm against the training setup.
id2label = {0: 'seal', 255: 'bck'}
label2id = {'seal': 0, 'bck': 255}
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
# `do_reduce_labels` is the current name of the deprecated `reduce_labels`
# kwarg. It only transforms segmentation maps, so the images processed in
# `segment` are unaffected.
image_processor = SegformerImageProcessor(do_reduce_labels=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# weights_only=True restricts unpickling to tensors/primitive containers.
model.load_state_dict(torch.load(weights2load, weights_only=True, map_location=device))
model.to(device).eval()

# counter = Count()

def segment(im, interval_s=2):
    """Run the SegFormer model on one frame and return a normalized heat map.

    The map is channel 0 of the model's logits for the first (only) batch item,
    squared and then min-max scaled to [0, 1].

    Args:
        im: The frame delivered by the Gradio ``gr.Image`` input (a NumPy array
            under Gradio's default ``type``; presumably H x W x 3 RGB — confirm),
            or ``None`` when no frame is available yet.
        interval_s: Unused. Kept so existing callers passing it keep working.

    Returns:
        A 2-D float array in [0, 1] at the logits' spatial resolution (which
        may differ from the input frame's size). It is all zeros when the
        channel is constant, where the previous code produced NaNs from 0/0.
        Returns ``None`` when ``im`` is ``None``.
    """
    if im is None:
        return None
    pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
    # Inference only: skip building an autograd graph for every streamed frame.
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    # NOTE(review): squaring makes large negative logits bright as well as
    # large positive ones — confirm this is the intended visualization.
    heat = outputs.logits[0, 0].cpu().numpy() ** 2
    lo = heat.min()
    span = heat.max() - lo
    if span == 0:
        # Constant map: avoid 0/0 and return a well-defined blank image.
        return np.zeros_like(heat)
    return (heat - lo) / span

# with gr.Blocks() as demo:
#     inp = gr.Image(sources=["webcam"], streaming=True)
#     inp.stream(segment, inputs=inp, outputs=[gr.Image(), gr.Number()]) 

# Web UI: feed webcam frames to `segment` and display the returned map.
# Per the Gradio streaming guide, an Interface needs `live=True` to re-run its
# function on each streamed frame instead of waiting for a Submit click.
demo = gr.Interface(
    segment,
    [gr.Image(sources=["webcam"], streaming=True)],
    ["image"],
    live=True,
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()