File size: 3,329 Bytes
80c08a8
 
 
 
 
 
 
5fdbdde
80c08a8
 
 
 
 
5fdbdde
80c08a8
 
 
 
 
62f156c
80c08a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from skimage import io
import torch, os
from PIL import Image
from briarmbg import BriaRMBG
import gradio as gr
import cv2
import numpy as np
import time
import random
from PIL import Image


# Instantiate the background-removal network once, at import time, from the
# "briaai/RMBG-1.4" pretrained identifier. The same instance is reused for
# every frame handled by process().
bgrm = BriaRMBG.from_pretrained("briaai/RMBG-1.4")

# Use CUDA when torch reports it as available, otherwise fall back to the CPU,
# and move the model there so inference inputs can be placed alongside it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bgrm.to(device)
print("device:", device)


def resize_image(image):
    """Return an RGB copy of *image* stretched to a 1024x1024 square.

    The image is first converted to RGB and then resampled with bilinear
    interpolation. The target size is fixed, so the aspect ratio of the
    input is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)


def process(image):
    """Replace the background of one frame with solid green.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``; ``process_video``
            passes an RGB frame.

    Returns:
        An RGBA ``PIL.Image`` with the same size as the input: an opaque
        green canvas onto which the original frame is pasted using the
        predicted foreground mask.
    """
    orig_image = Image.fromarray(image)
    w, h = orig_image.size

    # prepare input: HWC uint8 -> 1xCxHxW float, scaled to [0, 1] and then
    # shifted by the 0.5 mean used in the normalize() call below
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Follow the module-level device so the input always sits where the
    # model was moved to.
    im_tensor = im_tensor.to(device)

    # Pure inference: disabling autograd avoids building a graph per frame.
    with torch.no_grad():
        result = bgrm(im_tensor)
        # post process: scale the first output back to the frame size
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0
        )
        ma = torch.max(result)
        mi = torch.min(result)
        # Min-max normalise; the clamp keeps a constant mask from dividing
        # by zero (which would yield NaN before the uint8 cast).
        result = (result - mi) / torch.clamp(ma - mi, min=1e-8)

    # mask to single-channel PIL image
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the original frame over a green canvas through the mask
    new_im = Image.new("RGBA", pil_im.size, (0, 255, 0, 255))
    new_im.paste(orig_image, mask=pil_im)
    return new_im





def process_video(video, progress=gr.Progress()):
    """Run background replacement on every frame of a video file.

    Frames are read with OpenCV, passed through ``process`` and written to
    ``output.mp4`` at the source frame rate and frame size. Processing stops
    early, keeping the frames written so far, once the elapsed time gets
    within five seconds of twenty minutes.

    Args:
        video: Path of the input video, as handed over by Gradio.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            the Gradio idiom for requesting one and must stay in the
            signature.

    Returns:
        The path of the written video file.

    Raises:
        gr.Error: If no frame could be read from ``video``.
    """
    # NOTE(review): the output name is fixed, so concurrent requests would
    # write to the same file.
    tmpname = 'output.mp4'
    time_budget = 20 * 60 - 5  # seconds; stop just short of 20 minutes

    cap = cv2.VideoCapture(video)
    writer = None
    processed_frames = 0
    start_time = time.time()
    try:
        # OpenCV can report 0 (or less) when the container has no frame count.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if time.time() - start_time >= time_budget:
                print("GPU Timeout is coming")
                break

            if writer is None:
                # Created lazily because the frame size is only known once
                # the first frame has been decoded.
                height, width = frame.shape[:2]
                writer = cv2.VideoWriter(
                    tmpname,
                    cv2.VideoWriter_fourcc(*'mp4v'),
                    cap.get(cv2.CAP_PROP_FPS),
                    (width, height),
                )

            processed_frames += 1
            print(f"Processing frame {processed_frames}")
            if total_frames > 0:
                # The reported count may be an underestimate; cap at 100 %.
                progress(
                    min(processed_frames / total_frames, 1.0),
                    desc=f"Processing frame {processed_frames}/{total_frames}",
                )
            else:
                # Unknown total: report completed steps only.
                progress(
                    (processed_frames, None),
                    desc=f"Processing frame {processed_frames}",
                )

            # OpenCV decodes to BGR, process() works on RGB and returns RGBA.
            out = process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGBA2BGR))
    finally:
        # Release handles even if a frame fails, so the partial file is
        # finalised and the capture is not leaked.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        raise gr.Error("No frames could be read from the input video.")
    return tmpname

# Heading and help text passed to the Gradio interface below.
title = "🎞️ Video Background Removal Tool 🎥"
# NOTE(review): the text refers to a "Fast mode", which is not implemented in
# the code visible here — confirm it exists elsewhere or reword the message.
description = """Please note that if your video file is long (has a high number of frames), there is a chance that processing break due to GPU timeout. In this case, consider trying Fast mode."""

# Example input offered in the UI; the path is relative to the working directory.
examples = [['./input.mp4']]

# One video input wired to process_video, one video output.
iface = gr.Interface(
    fn=process_video,
    inputs=["video"],
    outputs="video",
    examples=examples,
    title=title,
    description=description
)
# Runs at module level, so the app is launched as soon as this file is executed.
iface.launch()