# gra / run.py
# Hugging Face file-viewer header, kept for provenance:
#   author: noamholz · commit e772d1d (verified) "Update run.py" · size 2.97 kB
import gradio as gr
import numpy as np
from time import sleep
import torch
import cv2
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms
# class Count:
# def __init__(self):
# self.n = 0
# self.imout = np.zeros((1000, 1000))
# def step(self):
# self.n += 1
# Number of frames handed to segment() so far; segment() updates it via `global`.
cnt = 0

# Fine-tuned checkpoint, loaded on top of the pretrained "nvidia/mit-b0" backbone.
weights2load = 'segformer_ep15_loss0.00.pth'

# Class-id <-> class-name tables passed to the model configuration.
id2label = {0: 'seal', 255: 'bck'}
label2id = {name: class_id for class_id, name in id2label.items()}

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
image_processor = SegformerImageProcessor(reduce_labels=True)

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Overwrite the pretrained parameters with the fine-tuned ones, then move the
# model to the chosen device and switch it to evaluation mode.
model.load_state_dict(
    torch.load(weights2load, weights_only=True, map_location=device)
)
model.to(device)
model.eval()
# counter = Count()
def segment(im, interval_s=2):
    """Halve a frame and, on every 10th frame, tint it with the model's heat map.

    Args:
        im: Frame as a NumPy array indexed ``[row, col, channel]``, or ``None``
            when no frame is available.
        interval_s: Currently unused; kept so existing callers keep working.

    Returns:
        A ``(image, count)`` tuple. ``image`` is the half-resolution frame; on
        every 10th counted frame, channel index 1 (green if the input is RGB --
        assumed, confirm against the caller) is brightened by the normalised,
        squared logits of output channel 0. ``count`` is the number of frames
        counted so far. For a ``None`` frame the result is ``(None, count)`` and
        the counter is left unchanged.
    """
    global cnt

    # Without a frame there is nothing to resize; bail out instead of raising
    # AttributeError on `im.shape`, and do not count the empty call.
    if im is None:
        return None, cnt

    cnt += 1
    im = cv2.resize(im, (im.shape[1] // 2, im.shape[0] // 2))
    imout = im.copy()

    # Run the model only on every 10th frame; other frames pass through as-is.
    if cnt % 10 == 0:
        pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)

        # Inference only: disable autograd so no graph is built or kept alive.
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)

        # Squared logits of output channel 0 for the single image in the batch.
        scores = outputs.logits[0, 0].detach().cpu().numpy() ** 2

        # Min-max normalise to [0, 1]. A perfectly flat map would divide by
        # zero and turn into NaN, so it maps to "no highlight" instead.
        lowest = scores.min()
        spread = scores.max() - lowest
        if spread > 0:
            heat = (scores - lowest) / spread
        else:
            heat = np.zeros_like(scores)

        # Bring the heat map back to the (halved) frame size and add it to
        # channel 1, clipping so the result stays a valid 8-bit intensity.
        heat = cv2.resize(heat, (im.shape[1], im.shape[0]))
        imout[..., 1] = np.clip(imout[..., 1] + heat * 200, 0, 254)

    return imout, cnt
# with gr.Blocks() as demo:
# inp = gr.Image(sources=["webcam"], streaming=True)
# inp.stream(segment, inputs=inp, outputs=[gr.Image()])
# Gradio UI: a streaming webcam image goes into segment(); its two return
# values are shown as an image and a number.
demo = gr.Interface(
    fn=segment,
    inputs=[gr.Image(sources=["webcam"], streaming=True)],
    outputs=[gr.Image(), gr.Number()],
    css=".output-image, .input-image, .image-preview {height: 400px !important}",
)

if __name__ == "__main__":
    # Enable the request queue, then start the web server.
    demo.queue()
    demo.launch()
# from gradio_webrtc import WebRTC
# css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
# .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# with gr.Blocks(css=css) as demo:
# gr.HTML(
# )
# with gr.Column(elem_classes=["my-column"]):
# with gr.Group(elem_classes=["my-group"]):
# image = WebRTC(label="Stream")
# image.stream(fn=segment, inputs=[image], outputs=[image])
# demo = gr.Interface(
# fn=segment,
# inputs=[gr.Image(sources=["webcam"], streaming=True)],
# outputs=["image"],
# title="Image Inference",
# cache_examples=False,
# live=True
# )
# if __name__ == "__main__":
# demo.queue().launch()