# Source: Hugging Face Space "gra", file run.py (3.15 kB).
# Last commit: "Update run.py" (3940ce7, verified) by noamholz.
import gradio as gr
import numpy as np
from time import sleep
import torch
import cv2
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# from torchvision import transforms
class Counter:
    """Running tally that starts at zero and grows by one per `increment` call."""

    def __init__(self):
        # Number of completed `increment` calls.
        self.count = 0

    def increment(self):
        """Advance the tally by one and return the new value."""
        updated = self.count + 1
        self.count = updated
        return updated
# Shared counter whose current value is returned by `segment` with each frame.
counter = Counter()
cnt = 0  # not read anywhere in the visible code; kept for compatibility

# Fine-tuned checkpoint; a relative path, so it is resolved against the
# process working directory.
weights2load = 'segformer_ep15_loss0.00.pth'

# NOTE(review): the model is built with num_labels=2, yet the second class is
# keyed as 255 rather than 1 — confirm this matches the training setup.
id2label = {0: 'seal', 255: 'bck'}
label2id = {'seal': 0, 'bck': 255}

# Build the architecture from the "nvidia/mit-b0" checkpoint with a 2-class
# head; every weight is then replaced by the fine-tuned state dict below.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)

# `reduce_labels` is the deprecated spelling of this option; the supported
# keyword is `do_reduce_labels`.
image_processor = SegformerImageProcessor(do_reduce_labels=True)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# weights_only=True restricts unpickling to tensors and plain containers;
# map_location places the tensors directly on the chosen device.
model.load_state_dict(torch.load(weights2load, weights_only=True, map_location=device))
model.to(device).eval()
# counter = Count()
def segment(im, interval_s=2):
    """Return a copy of the incoming frame together with the shared counter value.

    The model-based overlay is currently disabled (see the commented block
    below), so the frame is passed through unmodified.

    Args:
        im: Frame delivered by the Gradio image input; presumably an
            H x W x 3 numpy array (only ``.copy()`` is required here).
            May be ``None`` when no frame is available.
        interval_s: Unused; kept so existing callers keep working.

    Returns:
        A ``(image, count)`` tuple: a copy of ``im`` (or ``None`` when no
        frame was supplied) and the current value of the module-level
        ``counter``.
    """
    # Without this guard `im.copy()` raises AttributeError when the image
    # input is empty and the callback is invoked with None.
    if im is None:
        return None, counter.count

    # im = cv2.resize(im, (im.shape[1] // 2, im.shape[0] // 2))
    imout = im.copy()

    # Disabled overlay: run the model on every third frame and add the
    # normalised class-0 logits to the green channel.
    # if counter.increment() % 3 == 0:
    #     # if (counter.imout.sum() == 0) or ((cnt % 100) == 0):
    #     pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
    #     outputs = model(pixel_values=pixel_values)
    #     logits = outputs.logits.cpu().detach().numpy() ** 2
    #     logits_n = (logits[0, 0] - logits[0, 0].min()) / (logits[0, 0].max() - logits[0, 0].min())
    #     logits_n = cv2.resize(logits_n, (im.shape[1], im.shape[0]))
    #     imout[..., 1] = np.clip(imout[..., 1] + logits_n * 200, 0, 254)
    return imout, counter.count  # np.flipud(im)
# with gr.Blocks() as demo:
# inp = gr.Image(sources=["webcam"], streaming=True)
# inp.stream(segment, inputs=inp, outputs=[gr.Image()])
# Live interface: streams webcam frames into `segment` and shows the returned
# image next to the returned counter value.
demo = gr.Interface(
    fn=segment,
    inputs=[gr.Image(sources=["webcam"], streaming=True)],
    outputs=[gr.Image(), gr.Number()],
    live=True,
    css=".output-image, .input-image, .image-preview {height: 400px !important}",
)
# with gr.Blocks() as demo:
# inp = gr.Image(sources=["webcam"], streaming=True)
# out = gr.Image()
# inp.stream(segment, inputs=inp, outputs=out)
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()
# from gradio_webrtc import WebRTC
# css = """.my-group {max-width: 600px !important; max-height: 600px !important;}
# .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
# with gr.Blocks(css=css) as demo:
# gr.HTML(
# )
# with gr.Column(elem_classes=["my-column"]):
# with gr.Group(elem_classes=["my-group"]):
# image = WebRTC(label="Stream")
# image.stream(fn=segment, inputs=[image], outputs=[image])
# demo = gr.Interface(
# fn=segment,
# inputs=[gr.Image(sources=["webcam"], streaming=True)],
# outputs=["image"],
# title="Image Inference",
# cache_examples=False,
# live=True
# )
# if __name__ == "__main__":
# demo.queue().launch()