File size: 2,575 Bytes
1a3844d b41cba1 85f0e91 c596df4 86a4add 42159c7 405fa0d 06aac96 eec0220 06aac96 f08406b 405fa0d 06aac96 42159c7 c596df4 85f0e91 06aac96 85f0e91 06aac96 23b1a56 06aac96 1a3844d 5a09739 1bffe26 1a3844d a15fee9 85f0e91 a15fee9 98e85e1 a15fee9 98e85e1 2858e47 85f0e91 2858e47 98e85e1 2858e47 8262081 2858e47 5a09739 98e85e1 a15fee9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import gradio as gr
import numpy as np
from time import sleep
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms
# class Count:
# def __init__(self):
# self.n = 0
# self.imout = np.zeros((1000, 1000))
# def step(self):
# self.n += 1
# cnt = 0
# Fine-tuned SegFormer weights; a relative path, so it is resolved against the
# process's current working directory.
weights2load = 'segformer_ep15_loss0.00.pth'

# Label mapping handed to the model config; label2id is simply the inverse of
# id2label ({'seal': 0, 'bck': 255}).
id2label = {0: 'seal', 255: 'bck'}
label2id = {name: idx for idx, name in id2label.items()}

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a 2-class SegFormer on top of the "nvidia/mit-b0" checkpoint, then
# replace its parameters with the locally fine-tuned state dict.
# weights_only=True restricts torch.load's unpickling to tensors/plain types.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
model.load_state_dict(
    torch.load(weights2load, weights_only=True, map_location=device)
)
# Move to the chosen device and switch to inference behaviour (eval mode).
model.to(device).eval()

# Pre-processor used by segment() to turn incoming frames into pixel_values.
image_processor = SegformerImageProcessor(reduce_labels=True)
# counter = Count()
def segment(im, interval_s=2):
    """Return a normalised 'seal' score map for one webcam frame.

    The frame is pre-processed with the module-level ``image_processor`` and run
    through ``model`` on ``device``. Channel 0 of the logits (``id2label[0]`` is
    'seal') is squared and min-max scaled into [0, 1].

    Args:
        im: The incoming frame. Presumably an RGB numpy array, which is the
            ``gr.Image`` default type (TODO confirm). ``None`` is tolerated
            because no frame may be available yet.
        interval_s: Unused. Kept only so existing callers keep working.

    Returns:
        A 2-D float array in [0, 1], or ``None`` when ``im`` is ``None``. A
        constant score map returns all zeros instead of NaNs.
    """
    if im is None:
        return None

    # Pure inference: skip autograd bookkeeping, which would otherwise build a
    # graph (and hold its memory) on every streamed frame.
    with torch.no_grad():
        pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
        logits = model(pixel_values=pixel_values).logits

    # First image in the batch, channel 0, squared exactly as before.
    scores = logits[0, 0].cpu().numpy() ** 2

    lo = scores.min()
    span = scores.max() - lo
    if span == 0:
        # Avoid 0/0 -> NaN when the map is constant.
        return np.zeros_like(scores)
    return (scores - lo) / span
# Webcam streaming UI: each captured frame is passed to `segment` and the
# returned score map is shown as an image. `demo` stays a module-level name so
# Gradio tooling that looks it up keeps working.
demo = gr.Interface(
    segment,
    [gr.Image(sources=["webcam"], streaming=True)],
    ["image"],
    # A streaming input is paired with a live Interface so frames are processed
    # continuously instead of waiting for a Submit click (Gradio's documented
    # streaming pattern).
    live=True,
)

if __name__ == "__main__":
    # Enable the request queue, then start the web server.
    demo.queue().launch()
# from gradio_webrtc import WebRTC
# css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
# .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# with gr.Blocks(css=css) as demo:
# gr.HTML(
# )
# with gr.Column(elem_classes=["my-column"]):
# with gr.Group(elem_classes=["my-group"]):
# image = WebRTC(label="Stream")
# image.stream(fn=segment, inputs=[image], outputs=[image])
# demo = gr.Interface(
# fn=segment,
# inputs=[gr.Image(sources=["webcam"], streaming=True)],
# outputs=["image"],
# title="Image Inference",
# cache_examples=False,
# live=True
# )
# if __name__ == "__main__":
# demo.queue().launch() |