File size: 2,966 Bytes
1a3844d b41cba1 85f0e91 36f4824 c596df4 86a4add 42159c7 405fa0d 06aac96 eec0220 06aac96 f08406b 405fa0d bc295c9 42159c7 c596df4 85f0e91 06aac96 85f0e91 06aac96 bc295c9 4e9033c 828c539 bc295c9 c82d217 1a3844d eac766a e772d1d eac766a 85f0e91 a15fee9 98e85e1 a15fee9 98e85e1 2858e47 85f0e91 2858e47 98e85e1 2858e47 8262081 2858e47 5a09739 98e85e1 a15fee9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import gradio as gr
import numpy as np
from time import sleep
import torch
import cv2
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms
# class Count:
# def __init__(self):
# self.n = 0
# self.imout = np.zeros((1000, 1000))
# def step(self):
# self.n += 1
# Global call counter; segment() increments it on every invocation.
cnt = 0

# Fine-tuned checkpoint and the label <-> id mapping used for the model config.
weights2load = 'segformer_ep15_loss0.00.pth'
label2id = {'seal': 0, 'bck': 255}
id2label = {idx: name for name, idx in label2id.items()}

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a 2-class SegFormer on the nvidia/mit-b0 backbone, then overwrite its
# parameters with the local checkpoint (loaded straight onto `device`).
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
model.load_state_dict(
    torch.load(weights2load, map_location=device, weights_only=True)
)
model.to(device).eval()

# Preprocessor that turns raw frames into the model's `pixel_values`.
image_processor = SegformerImageProcessor(reduce_labels=True)
# counter = Count()
def segment(im, interval_s=2):
    """Downscale a frame and, on every 10th call, overlay the SegFormer output.

    Every call halves the frame's width and height. On calls where the global
    counter ``cnt`` is a multiple of 10, the model is run on the downscaled
    frame. The label-0 logit map is squared, min-max normalised to [0, 1],
    resized to the frame, scaled by 200, added to channel 1 and clipped to
    [0, 254]. Channel 1 is green when the input is RGB. Other calls return
    the downscaled frame unchanged.

    Args:
        im: Frame from the Gradio image component, or ``None`` if no frame
            is available. Presumably an (H, W, 3) uint8 array; TODO confirm
            against the Gradio version in use.
        interval_s: Unused; kept only for signature compatibility.

    Returns:
        Tuple ``(imout, cnt)``: the half-size frame (possibly overlaid), or
        ``None`` when ``im`` is ``None``, plus the updated call counter.
    """
    global cnt
    cnt += 1
    if im is None:
        # cv2.resize would raise on a missing frame; report it as "no image".
        return None, cnt
    # cv2.resize takes dsize as (width, height); clamp so a 1-pixel side
    # does not collapse to 0, which cv2 rejects.
    im = cv2.resize(im, (max(1, im.shape[1] // 2), max(1, im.shape[0] // 2)))
    imout = im.copy()
    if cnt % 10 == 0:
        pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
        # Display-only inference: skip autograd bookkeeping on every processed frame.
        with torch.inference_mode():
            outputs = model(pixel_values=pixel_values)
        # NOTE(review): squaring makes strongly negative logits as bright as
        # strongly positive ones; confirm this is intended.
        logits = outputs.logits.cpu().numpy() ** 2
        seal = logits[0, 0]  # first batch item, label channel 0
        lo, hi = seal.min(), seal.max()
        if hi > lo:
            logits_n = (seal - lo) / (hi - lo)
        else:
            # A constant map would divide by zero and push NaNs into the frame.
            logits_n = np.zeros_like(seal)
        logits_n = cv2.resize(logits_n, (im.shape[1], im.shape[0]))
        imout[..., 1] = np.clip(imout[..., 1] + logits_n * 200, 0, 254)
    return imout, cnt
# with gr.Blocks() as demo:
# inp = gr.Image(sources=["webcam"], streaming=True)
# inp.stream(segment, inputs=inp, outputs=[gr.Image()])
# Gradio UI: each streamed webcam frame is passed to segment(); its two return
# values are shown as an image and a number (the call counter).
demo = gr.Interface(
    fn=segment,
    inputs=[gr.Image(sources=["webcam"], streaming=True)],
    outputs=[gr.Image(), gr.Number()],
    css=".output-image, .input-image, .image-preview {height: 400px !important}",
)

if __name__ == "__main__":
    # Enable the request queue, then start the web server.
    demo.queue().launch()
# from gradio_webrtc import WebRTC
# css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
# .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# with gr.Blocks(css=css) as demo:
# gr.HTML(
# )
# with gr.Column(elem_classes=["my-column"]):
# with gr.Group(elem_classes=["my-group"]):
# image = WebRTC(label="Stream")
# image.stream(fn=segment, inputs=[image], outputs=[image])
# demo = gr.Interface(
# fn=segment,
# inputs=[gr.Image(sources=["webcam"], streaming=True)],
# outputs=["image"],
# title="Image Inference",
# cache_examples=False,
# live=True
# )
# if __name__ == "__main__":
# demo.queue().launch() |