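# Gradio demo: stream webcam frames through a fine-tuned SegFormer (MiT-B0) model and overlay the predicted seal mask.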
import gradio as gr
import numpy as np
import torch
import cv2
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms
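# Frame counter used to throttle how often segmentation inference runs on the stream.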
class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count
counter = Counter()
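# Fine-tuned SegFormer checkpoint and the binary seal/background label mapping.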
weights2load = 'segformer_ep15_loss0.00.pth'
id2label = {0: 'seal', 255: 'bck'}
label2id = {'seal': 0, 'bck': 255}
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
image_processor = SegformerImageProcessor(reduce_labels=True)
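# Load the fine-tuned weights onto the available device (GPU if present) and switch to inference mode.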
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(weights2load, weights_only=True, map_location=device))
model.to(device).eval()
def segment(im, interval_s=2):
    """Overlay the model's seal prediction on a webcam frame and report the frame count."""
    # Optionally downscale first: im = cv2.resize(im, (im.shape[1] // 2, im.shape[0] // 2))
    imout = im.copy()
    # Run the model only on every third frame to keep the live stream responsive.
    if counter.increment() % 3 == 0:
        pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        logits = outputs.logits.cpu().detach().numpy() ** 2
        # Normalise the seal-channel logits to [0, 1] and resize them back to the frame size.
        logits_n = (logits[0, 0] - logits[0, 0].min()) / (logits[0, 0].max() - logits[0, 0].min())
        logits_n = cv2.resize(logits_n, (im.shape[1], im.shape[0]))
        # Brighten the green channel where the model predicts "seal".
        imout[..., 1] = np.clip(imout[..., 1] + logits_n * 200, 0, 254)
    return imout, counter.count
# with gr.Blocks() as demo:
# inp = gr.Image(sources=["webcam"], streaming=True)
# inp.stream(segment, inputs=inp, outputs=[gr.Image()])
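# Live interface: stream webcam frames through segment() and show the overlay plus the frame count.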
demo = gr.Interface(
    segment,
    [gr.Image(sources=["webcam"], streaming=True)],
    [gr.Image(), gr.Number()],
    css=".output-image, .input-image, .image-preview {height: 400px !important}",
    live=True,
)
if __name__ == "__main__":
    demo.launch()