# gra / run.py -- revision 7c3d23b ("Update run.py", by noamholz), 2.98 kB
import gradio as gr
import numpy as np
from time import sleep
import torch
import cv2
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms
class Counter:
    """Simple call counter that starts at zero and only ever goes up."""

    def __init__(self) -> None:
        # Number of times increment() has been called on this instance.
        self.count: int = 0

    def increment(self) -> int:
        """Add one to ``count`` and return the updated value."""
        self.count += 1
        return self.count
# Shared state read by `segment`: `counter` is bumped once per call, while
# `cnt` is only ever read there (nothing in this file reassigns it).
counter = Counter()
cnt = 0

# State-dict file that is loaded into the model below.
weights2load = 'segformer_ep15_loss0.00.pth'

# Class-id <-> class-name mappings handed to the model configuration.
id2label = {0: 'seal', 255: 'bck'}
label2id = {name: index for index, name in id2label.items()}

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
image_processor = SegformerImageProcessor(reduce_labels=True)

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Replace the pretrained weights with the ones stored in `weights2load`,
# then move the model to the chosen device and switch it to eval mode.
model.load_state_dict(
    torch.load(weights2load, weights_only=True, map_location=device)
)
model.to(device)
model.eval()
def _normalize_unit(values):
    """Min-max scale *values* into [0, 1]; a constant array maps to zeros."""
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # A flat input would otherwise give 0/0 = NaN, which turns into
        # arbitrary values once written into a uint8 image.
        return np.zeros_like(values)
    return (values - lowest) / span


def segment(im, interval_s=2):
    """Overlay the model's response on a half-size copy of a frame.

    Args:
        im: Input frame as an array indexed ``[row, col, channel]``, or
            ``None`` when no frame is available.
        interval_s: Currently unused; kept so existing callers keep working.

    Returns:
        A ``(image, count)`` tuple. ``image`` is the downscaled frame with
        the normalised, squared logits of output channel 0 added (scaled by
        200) to channel index 1, clipped to ``[0, 254]``. ``count`` is the
        number of calls made so far. ``image`` is ``None`` when ``im`` is
        ``None``.
    """
    counter.increment()
    if im is None:
        # Nothing to process; still report the call count.
        return None, counter.count

    # Halve both dimensions; cv2.resize takes the target as (width, height).
    im = cv2.resize(im, (im.shape[1] // 2, im.shape[0] // 2))
    imout = im.copy()

    # NOTE(review): `cnt` is a module-level value that nothing in this file
    # updates, so this condition is always true and inference runs on every
    # frame. If "every 10th frame" was intended, the last overlay would need
    # to be cached for the frames in between -- confirm the intent.
    if cnt % 10 == 0:
        pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        logits = outputs.logits.cpu().detach().numpy() ** 2
        # First batch item, output channel 0 (presumably the 'seal' class
        # given id2label -- confirm).
        logits_n = _normalize_unit(logits[0, 0])
        # Bring the response map back to the size of the downscaled frame.
        logits_n = cv2.resize(logits_n, (im.shape[1], im.shape[0]))
        # Channel index 1 is green if frames arrive as RGB -- TODO confirm.
        imout[..., 1] = np.clip(imout[..., 1] + logits_n * 200, 0, 254)
    return imout, counter.count
# with gr.Blocks() as demo:
# inp = gr.Image(sources=["webcam"], streaming=True)
# inp.stream(segment, inputs=inp, outputs=[gr.Image()])
# UI definition: streamed webcam frames go into `segment`; its two return
# values are shown as an image and a number.
demo = gr.Interface(
    fn=segment,
    inputs=[gr.Image(sources=["webcam"], streaming=True)],
    outputs=[gr.Image(), gr.Number()],
    css=".output-image, .input-image, .image-preview {height: 400px !important}",
)
if __name__ == "__main__":
    # Launch the interface only when executed as a script, not on import.
    demo.launch()
# from gradio_webrtc import WebRTC
# css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
# .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# with gr.Blocks(css=css) as demo:
# gr.HTML(
# )
# with gr.Column(elem_classes=["my-column"]):
# with gr.Group(elem_classes=["my-group"]):
# image = WebRTC(label="Stream")
# image.stream(fn=segment, inputs=[image], outputs=[image])
# demo = gr.Interface(
# fn=segment,
# inputs=[gr.Image(sources=["webcam"], streaming=True)],
# outputs=["image"],
# title="Image Inference",
# cache_examples=False,
# live=True
# )
# if __name__ == "__main__":
# demo.queue().launch()