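"""Gradio webcam demo: seal segmentation with a fine-tuned SegFormer.

Streams frames from the browser webcam through a SegFormer (MiT-B0)
checkpoint and overlays the predicted seal mask on the red channel.
"""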
import cv2
import gradio as gr
import numpy as np
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
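# Fine-tuned checkpoint and the label/id mapping the model was trained with.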
weights2load = 'segformer_ep15_loss0.00.pth'
id2label = {0: 'seal', 255: 'bck'}
label2id = {'seal': 0, 'bck': 255}
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
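# reduce_labels only alters segmentation maps passed to the processor, not images;
# it is kept here to match the preprocessing used at training time.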
image_processor = SegformerImageProcessor(reduce_labels=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(weights2load, weights_only=True, map_location=device))
model.to(device).eval()
def segment(im):
    # Preprocess the webcam frame and run one forward pass (no gradients needed).
    pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    # Square the seal-channel logits, then min-max normalize them to [0, 1].
    logits = outputs.logits.cpu().numpy() ** 2
    logits_n = (logits[0, 0] - logits[0, 0].min()) / (logits[0, 0].max() - logits[0, 0].min())
    # cv2.resize takes (width, height); im is (H, W, 3), so pass shape[1::-1].
    logits_n = cv2.resize(logits_n, im.shape[1::-1])
    # Tint the red channel with the mask, scaled to pixel range and clipped to uint8.
    imout = im.copy()
    imout[..., 0] = np.clip(imout[..., 0] + logits_n * 255 / 10, 0, 255).astype(np.uint8)
    return imout
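# Minimal Blocks UI: stream webcam frames through segment() and show the overlay.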
with gr.Blocks() as demo:
    inp = gr.Image(sources=["webcam"], streaming=True)
    out = gr.Image(label="Segmentation overlay")
    inp.stream(segment, inputs=inp, outputs=out)
if __name__ == "__main__":
demo.queue().launch()
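# Run with `python run.py`; Gradio serves the app locally (default: http://127.0.0.1:7860).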