Spaces:
Runtime error
Runtime error
File size: 3,329 Bytes
80c08a8 5fdbdde 80c08a8 5fdbdde 80c08a8 62f156c 80c08a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import numpy as np
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from skimage import io
import torch, os
from PIL import Image
from briarmbg import BriaRMBG
import gradio as gr
import cv2
import numpy as np
import time
import random
from PIL import Image
# Load the RMBG-1.4 background-removal network once at import time and place it
# on the GPU when CUDA is available, otherwise on the CPU.
bgrm = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bgrm.to(device)
# Switch to inference mode explicitly (affects layers such as dropout or
# batch-norm, if the network has any) instead of relying on from_pretrained
# having done it.
bgrm.eval()
print("device:", device)
def resize_image(image, model_input_size=(1024, 1024)):
    """Return *image* converted to RGB and resized to the model input size.

    Args:
        image: a ``PIL.Image`` in any mode.
        model_input_size: target ``(width, height)`` passed to ``Image.resize``;
            defaults to the 1024x1024 input this file feeds the model.

    Returns:
        A new RGB ``PIL.Image`` of exactly *model_input_size*. The aspect
        ratio is not preserved.
    """
    rgb_image = image.convert('RGB')
    return rgb_image.resize(model_input_size, Image.BILINEAR)
def process(image):
    """Replace the background of one frame with solid green using the RMBG model.

    Args:
        image: the frame as an array accepted by ``PIL.Image.fromarray``,
            expected in RGB channel order (the caller in this file converts
            OpenCV's BGR frames before calling).

    Returns:
        An RGBA ``PIL.Image`` the same size as *image*. It is a green
        ``(0, 255, 0, 255)`` canvas onto which the original frame is pasted
        through the predicted mask, so background pixels stay green.
    """
    # prepare input: fixed-size RGB, CHW float tensor with a batch dimension,
    # scaled to [0, 1] and then shifted to [-0.5, 0.5].
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Use the same device the model was moved to at load time.
    im_tensor = im_tensor.to(device)

    # inference: no autograd graph is needed, and building one for every
    # frame only wastes memory.
    with torch.no_grad():
        result = bgrm(im_tensor)

    # post process: bilinear interpolation needs a 4-D (N, C, H, W) tensor, so
    # result[0][0] is presumably a (1, 1, H', W') mask map -- confirm against
    # BriaRMBG.forward. It is upsampled back to the original frame size.
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    # Min-max normalise to [0, 1]. A constant map would divide by zero and
    # produce NaNs (undefined after the uint8 cast), so treat it as all
    # background instead.
    if ma > mi:
        result = (result - mi) / (ma - mi)
    else:
        result = torch.zeros_like(result)

    # image to pil: 8-bit single-channel mask, squeezed to 2-D
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the original frame onto a green canvas through the mask
    new_im = Image.new("RGBA", pil_im.size, (0, 255, 0, 255))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
def process_video(video, progress=gr.Progress()):
    """Remove the background from every frame of *video* and return the result.

    Each frame is passed through ``process`` (green background) and written to
    ``output.mp4`` in the working directory with the ``mp4v`` codec at the
    source frame rate. Processing stops early once just under 20 minutes of
    wall-clock time have elapsed; the frames written so far are kept.

    Args:
        video: path of the input video, as handed over by the Gradio
            ``"video"`` input.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        The path of the written video (``'output.mp4'``).

    Raises:
        gr.Error: if the video cannot be opened or no frame could be processed.
    """
    tmpname = 'output.mp4'
    # Stop 5 s short of a 20-minute budget so a partial result can still be
    # returned before the GPU time limit hits.
    time_budget_s = 20 * 60 - 5
    cap = cv2.VideoCapture(video)
    writer = None
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")
        # Container metadata; may be 0 or inaccurate, so it only drives the
        # progress bar.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        processed_frames = 0
        start_time = time.time()
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if time.time() - start_time >= time_budget_s:
                print("GPU Timeout is coming")
                break
            # OpenCV decodes frames as BGR; process() expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if writer is None:
                height, width = frame.shape[:2]
                writer = cv2.VideoWriter(tmpname, cv2.VideoWriter_fourcc(*'mp4v'),
                                         cap.get(cv2.CAP_PROP_FPS), (width, height))
            processed_frames += 1
            print(f"Processing frame {processed_frames}")
            if total_frames > 0:
                progress(min(processed_frames / total_frames, 1.0),
                         desc=f"Processing frame {processed_frames}/{total_frames}")
            out = process(frame)
            # process() returns an RGBA image; drop alpha and reorder to the
            # BGR layout the writer expects.
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGBA2BGR))
    finally:
        cap.release()
        if writer is not None:
            writer.release()
    if writer is None:
        raise gr.Error("No frames could be read from the uploaded video.")
    return tmpname
title = "🎞️ Video Background Removal Tool 🎥"
# Shown under the title. process_video stops just under 20 minutes and returns
# the frames written so far, so long inputs may come back truncated; the text
# says so rather than pointing to a mode this app does not have.
description = """Please note that if your video file is long (has a high number of frames), processing may stop early because of the GPU time limit. In that case, the returned video contains only the frames processed before the limit was reached."""
examples = [['./input.mp4']]
iface = gr.Interface(
    fn=process_video,
    inputs=["video"],
    outputs="video",
    examples=examples,
    title=title,
    description=description,
)
# process_video reports progress through gr.Progress, whose updates require
# the request queue (opt-in on Gradio 3.x, already the default on 4.x).
iface.queue()
iface.launch()
|