import gradio as gr
import numpy as np
from time import sleep
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms
# class Count:
#     def __init__(self):
#         self.n = 0
#         self.imout = np.zeros((1000, 1000))
#     def step(self):
#         self.n += 1
# cnt = 0
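# Two-class (seal vs. background) SegFormer head on the "nvidia/mit-b0" backbone;
# the fine-tuned weights are restored from the local checkpoint below and moved
# to GPU if available, otherwise CPU.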
weights2load = 'segformer_ep15_loss0.00.pth'
id2label = {0: 'seal', 255: 'bck'}
label2id = {'seal': 0, 'bck': 255}
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
image_processor = SegformerImageProcessor(reduce_labels=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(weights2load, weights_only=True, map_location=device))
model.to(device).eval()
# counter = Count()
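# Segment a single webcam frame: preprocess it, run the SegFormer model, and
# return the squared class-0 ("seal") logit map, min-max normalized to [0, 1].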
def segment(im, interval_s=2):
    # if (counter.imout.sum() == 0) or ((cnt % 100) == 0):
    pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
    outputs = model(pixel_values=pixel_values)
    logits = outputs.logits.cpu().detach().numpy() ** 2
    imout = (logits[0, 0] - logits[0, 0].min()) / (logits[0, 0].max() - logits[0, 0].min())
    return imout  # , cnt  # np.flipud(im)
# with gr.Blocks() as demo:
#     inp = gr.Image(sources=["webcam"], streaming=True)
#     inp.stream(segment, inputs=inp, outputs=[gr.Image(), gr.Number()])
# demo = gr.Interface(
#     segment,
#     [gr.Image(sources=["webcam"], streaming=True)],
#     ["image"],
# )
# if __name__ == "__main__":
#     demo.launch()
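# Gradio UI: stream webcam frames over WebRTC, run each frame through segment(),
# and display the returned map in the same Stream component (10-second time limit
# per streaming session).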
from gradio_webrtc import WebRTC
css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
.my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
with gr.Blocks(css=css) as demo:
    gr.HTML()
    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            image = WebRTC(label="Stream")
        image.stream(
            fn=segment, inputs=[image], outputs=[image], time_limit=10
        )

if __name__ == "__main__":
    demo.launch()