# gra / run.py
# Author: noamholz
# Commit 42159c7 (verified): "added command to load weights"
# Size: 1.57 kB
import gradio as gr
import numpy as np
from time import sleep
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
# Fine-tuned checkpoint whose state dict is loaded into the model below.
weights2load = 'segformer_ep15_loss0.00.pth'
# NOTE(review): the model is built with num_labels=2, so its logits have
# channels 0 and 1, yet these maps use ids 0 and 255. Id 255 can never be an
# output channel index; confirm whether 'bck' should be id 1.
id2label = {0: 'seal', 255: 'bck'}
label2id = {'seal': 0, 'bck': 255}
# Start from the pretrained MiT-b0 encoder with a 2-class decode head; the
# fine-tuned weights are loaded over it further down.
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
# `do_reduce_labels` is the current name of the deprecated `reduce_labels`
# keyword. Per the transformers docs it is applied to segmentation-map labels,
# so it is expected to have no effect on the image-only inference in this file.
image_processor = SegformerImageProcessor(do_reduce_labels=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# weights_only=True restricts unpickling to tensors/primitive containers;
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load(weights2load, weights_only=True, map_location=device))
# Move to the selected device and switch off dropout etc. for inference.
model.to(device).eval()
def flip_periodically(im, interval_ms=0):
    """Segment one webcam frame and return class-0 probability as an image.

    The name is kept because the Gradio stream below is wired to it, but the
    function no longer flips the frame. It runs the SegFormer model on the
    frame, upsamples the logits to the frame's size, and returns the per-pixel
    softmax probability of output channel 0 (``id2label[0]``) as a grayscale
    ``uint8`` image.

    Args:
        im: The input frame as delivered by ``gr.Image`` (a numpy array with
            the component's default settings), or ``None`` when no frame is
            available.
        interval_ms: Delay in milliseconds to sleep after inference, used to
            throttle how often frames are processed.

    Returns:
        A ``uint8`` numpy array of shape (H, W), where (H, W) is the frame
        size and values are in [0, 255]. Returns ``None`` if ``im`` is ``None``.
    """
    if im is None:
        # Nothing to segment (e.g. the stream has not produced a frame yet).
        return None
    height, width = np.asarray(im).shape[:2]
    # Fixed: the original passed the undefined name ``image`` here.
    pixel_values = image_processor(im, return_tensors="pt").pixel_values.to(device)
    # Inference only: skip autograd bookkeeping on every frame.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
        # The model predicts at reduced resolution; resize back to the frame.
        logits = torch.nn.functional.interpolate(
            logits, size=(height, width), mode="bilinear", align_corners=False
        )
        # Raw logits are unbounded floats, which gr.Image cannot display, so
        # convert channel 0 to a probability and scale it to 8-bit gray.
        prob = logits.softmax(dim=1)[0, 0]
        mask = (prob * 255).round().to(torch.uint8).cpu().numpy()
    sleep(interval_ms / 1000)  # Convert milliseconds to seconds
    return mask
# Assemble the UI: a streaming webcam Image feeds each frame to
# flip_periodically, and the returned image is shown in a second Image.
with gr.Blocks() as demo:
    inp = gr.Image(streaming=True, sources=["webcam"])  # live webcam source
    out = gr.Image()  # displays the per-frame result
    inp.stream(fn=flip_periodically, inputs=inp, outputs=out)

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()