import spaces
import gradio as gr
import cv2
import numpy as np
import time
import random
from PIL import Image
import torch

def _identity_script(fn):
    """Return *fn* unchanged; used as a stand-in for ``torch.jit.script``."""
    return fn


# Disable TorchScript compilation: whatever is later handed to
# torch.jit.script is returned as-is and therefore runs as eager Python.
# This is placed before the transparent_background import, presumably so that
# library picks up the patched version.
# NOTE(review): presumably a workaround for TorchScript failing in this
# environment — confirm. The stand-in takes exactly one positional argument,
# so a call passing extra arguments to torch.jit.script would raise TypeError.
torch.jit.script = _identity_script

from transparent_background import Remover

@spaces.GPU()
def doo(video, mode, progress=gr.Progress()):
    """Run every frame of *video* through the background remover and save the result.

    Each decoded frame is passed to ``Remover.process(..., type='green')``
    from ``transparent_background``. The returned images are written to an
    MP4 file (``mp4v`` fourcc) in the current working directory, using the
    source video's reported frame rate and the first frame's size.

    Args:
        video: Path of the uploaded video, as passed by the Gradio ``"video"``
            input. It may be ``None`` if the user submitted without uploading.
        mode: ``'Fast'`` selects ``Remover(mode='fast')``. Any other value
            uses the default ``Remover()``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the written video, of the form ``"<random 9-digit int>.mp4"``.
        If the time budget below is reached, the frames processed so far are
        finalized and that partial file is returned.

    Raises:
        gr.Error: If no video was given, the output file could not be
            created, or no frame could be decoded.
    """
    if not video:
        raise gr.Error("Please upload a video first.")

    remover = Remover(mode='fast') if mode == 'Fast' else Remover()

    # Stop 5 s short of 20 minutes so the partial result can still be
    # finalized and returned.
    # NOTE(review): @spaces.GPU() is called without a duration; confirm the
    # GPU time actually granted matches this 20-minute budget.
    time_budget = 20 * 60 - 5
    output_path = str(random.randint(111111111, 999999999)) + '.mp4'

    cap = cv2.VideoCapture(video)
    writer = None
    try:
        # Frame count comes from container metadata. It can be 0, negative or
        # approximate, so it is only used for progress reporting.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        processed_frames = 0
        start_time = time.time()

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if time.time() - start_time >= time_budget:
                print("GPU Timeout is coming")
                break

            # OpenCV decodes frames as BGR; convert to RGB for PIL.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(frame).convert('RGB')

            if writer is None:
                # Created lazily so its size comes from the first decoded
                # frame (img.size is (width, height), as VideoWriter expects).
                # NOTE(review): assumes remover output matches the input size.
                writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, img.size)
                if not writer.isOpened():
                    # Writes to an unopened writer are silently dropped.
                    raise gr.Error("Could not create the output video.")

            processed_frames += 1
            print(f"Processing frame {processed_frames}")
            if total_frames > 0:
                progress(min(processed_frames / total_frames, 1.0), desc=f"Processing frame {processed_frames}/{total_frames}")
            out = remover.process(img, type='green')
            # The remover returns an RGB image; VideoWriter expects BGR.
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
    finally:
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        raise gr.Error("Could not read any frames from the uploaded video.")
    return output_path

# Clickable sample input; only the video column is supplied.
examples = [['./mp4.mp4']]

# Hide the page footer (every `footer` element).
css = """
footer {
    visibility: hidden;
}
"""

# Interface wiring: the uploaded video and the selected mode are passed
# positionally to `doo`, and the returned file path is shown as a video.
iface = gr.Interface(
    fn=doo,
    inputs=[
        "video",
        gr.Radio(
            ['Normal', 'Fast'],
            value='Normal',
            label='Select mode',
            info=(
                'Normal is more accurate, but takes longer. | '
                'Fast has lower accuracy so the process will be faster.'
            ),
        ),
    ],
    outputs="video",
    examples=examples,
    css=css,
    theme="Yntec/HaleyCH_Theme_Orange",
)

# Starts the server when this module is executed/imported.
iface.launch()