# gra / run.py
# Last commit: "Update run.py" by noamholz (457380a, verified), 1.62 kB.
import gradio as gr
import numpy as np
from time import sleep
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from torchvision import transforms
# Run on the GPU when one is available; the checkpoint below is mapped onto
# the same device it will be used on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Saved state_dict that replaces the model's parameters after construction.
weights2load = 'segformer_ep15_loss0.00.pth'

# Class index <-> name mapping handed to the model config.
# NOTE(review): with num_labels=2 the head emits channels 0 and 1, yet 'bck'
# is keyed as 255 here -- confirm which output channel 'bck' really is.
id2label = {0: 'seal', 255: 'bck'}
label2id = {name: idx for idx, name in id2label.items()}

# Build a SegFormer (MiT-b0) model with a 2-class head; every parameter is
# then overwritten by the checkpoint (load_state_dict is strict by default).
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
model.load_state_dict(
    torch.load(weights2load, weights_only=True, map_location=device)
)
model.to(device)
model.eval()  # inference mode (e.g. disables dropout)

# Preprocessor for incoming frames. reduce_labels only affects segmentation
# maps, not the pixel values produced at inference time.
image_processor = SegformerImageProcessor(reduce_labels=True)
def flip_periodically(im, interval_ms=0):
    """Segment one webcam frame and return the 'seal' probability map.

    The name is historical: the function used to flip the frame. It now runs
    the module-level SegFormer ``model`` on each streamed frame.

    Args:
        im: The incoming frame from the Gradio webcam stream. Presumably an
            HxWxC numpy array (``gr.Image``'s default ``type="numpy"``) --
            confirm if the component type changes. May be ``None`` before the
            webcam has delivered a frame.
        interval_ms: Milliseconds to sleep after inference, throttling how
            quickly frames are processed.

    Returns:
        A 2-D numpy array with the same height and width as ``im``. It holds
        the softmax probability (in [0, 1]) of output channel 0 ('seal' in
        ``id2label``). Returns ``None`` when ``im`` is ``None``.
    """
    if im is None:
        # No frame yet: leave the output image empty instead of crashing.
        return None

    frame_hw = np.asarray(im).shape[:2]
    pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits

    # SegFormer predicts at reduced resolution (1/4 of the processed input,
    # per the Transformers docs). Resize back to the frame so the map lines
    # up with the webcam image.
    logits = torch.nn.functional.interpolate(
        logits, size=frame_hw, mode="bilinear", align_corners=False
    )
    # Raw logits are unbounded and not displayable as a float image;
    # a softmax over the class axis yields probabilities in [0, 1].
    probs = logits.softmax(dim=1)

    sleep(interval_ms / 1000)  # milliseconds -> seconds
    # Return numpy on the CPU: gr.Image does not accept torch tensors.
    return probs[0, 0].cpu().numpy()
with gr.Blocks() as demo:
    # Streaming webcam input. Every frame is passed to flip_periodically and
    # its return value is rendered in the second image component.
    webcam_feed = gr.Image(sources=["webcam"], streaming=True)
    segmentation_view = gr.Image()
    webcam_feed.stream(
        flip_periodically,
        inputs=webcam_feed,
        outputs=segmentation_view,
    )

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()