File size: 1,943 Bytes
1a3844d
b41cba1
 
85f0e91
c596df4
86a4add
42159c7
f08406b
 
 
507fb4f
eec0220
f08406b
 
 
 
42159c7
c596df4
 
 
 
 
 
 
85f0e91
 
 
 
 
 
f08406b
85f0e91
b90cf97
b41cba1
 
1a3844d
b41cba1
 
 
1a3844d
b41cba1
 
 
f08406b
507fb4f
f08406b
 
 
 
62dda50
1a3844d
b41cba1
 
62dda50
1a3844d
85f0e91
1a3844d
85f0e91
b41cba1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import numpy as np
from time import sleep
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms

class Count:
    """Mutable state carried across successive streaming callbacks.

    Attributes are plain and public:
      * ``n``     -- number of times ``step()`` has been called.
      * ``imout`` -- most recent output image; starts as a 1000x1000 array
                     of zeros until something overwrites it.
    """

    def __init__(self):
        # All-zero placeholder shown until a real output replaces it.
        self.imout = np.zeros(shape=(1000, 1000))
        self.n = 0

    def step(self):
        """Advance the call counter by one."""
        self.n = self.n + 1

        
# State-dict file that is loaded into the model below.
weights2load = 'segformer_ep15_loss0.00.pth'

# Name <-> index tables passed to the model configuration.
# NOTE(review): with num_labels=2 the index 255 presumably does not match an
# output channel of the model -- confirm this mapping is intended.
label2id = {'seal': 0, 'bck': 255}
id2label = {index: name for name, index in label2id.items()}

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the architecture from the "nvidia/mit-b0" checkpoint, then
# overwrite its parameters with the local state dict.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
image_processor = SegformerImageProcessor(reduce_labels=True)

model.load_state_dict(
    torch.load(weights2load, weights_only=True, map_location=device)
)
model.to(device)
model.eval()  # inference mode for layers such as dropout / batch norm

# Shared state used by the streaming callback.
counter = Count()

# The model is re-run once every this many processed frames; the cached
# output is returned for all frames in between.
_REFRESH_EVERY_N_FRAMES = 100


def _rescale_to_unit_range(values):
    """Linearly map *values* onto [0, 1]; a constant array maps to all zeros."""
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # Avoid 0/0, which would fill the output with NaN.
        return np.zeros_like(values)
    return (values - lowest) / span


def flip_periodically(im, interval_s=2):
    """
    Streaming callback: periodically segment the incoming frame.

    Despite its historical name, this function does not flip anything. It
    counts processed frames and, on the first frame and then on every
    100th frame, runs the segmentation model on ``im``. The squared logits
    of output channel 0 are rescaled to [0, 1] and cached in
    ``counter.imout``; all other calls return the cached image.

    Args:
        im: The current frame from the stream. ``None`` is tolerated (the
            cached result is returned and the frame is not counted).
        interval_s: Currently unused; kept so existing callers that pass
            it keep working.

    Returns:
        A tuple ``(image, n)`` with the most recent normalised output map
        and the number of frames processed so far.
    """
    if im is None:
        # Nothing to process; do not count it as a frame.
        return counter.imout, counter.n

    counter.step()
    # An all-zero image is the initial placeholder (or a flat prediction),
    # so it also triggers a fresh inference.
    no_prediction_yet = counter.imout.sum() == 0
    if no_prediction_yet or counter.n % _REFRESH_EVERY_N_FRAMES == 0:
        pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        # Only channel 0 of the first (only) batch item is displayed, so
        # only that slice is copied to the CPU before squaring.
        channel0 = outputs.logits[0, 0].cpu().numpy() ** 2
        counter.imout = _rescale_to_unit_range(channel0)
    return counter.imout, counter.n

# Build the UI: a streaming webcam input feeding an image view and a
# numeric readout.
with gr.Blocks() as demo:
    inp = gr.Image(sources=["webcam"], streaming=True)
    mask_view = gr.Image()
    frame_count_view = gr.Number()
    # Each streamed frame is passed to flip_periodically, whose two return
    # values populate the two output components in order.
    inp.stream(
        flip_periodically,
        inputs=inp,
        outputs=[mask_view, frame_count_view],
    )


if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()