import os 
#os.system("pip install git+https://huggingface.co/spaces/Omnibus/real_esrgan_mod")
import gradio as gr 
import yt_dlp
import json
from rembg import remove as rm
import cv2
import uuid
import numpy as np 
import moviepy
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
import moviepy.video.io.ImageSequenceClip
from moviepy.editor import *
from PIL import Image
#import esr
# Per-run identifier: every artefact this app writes is placed under a path
# prefixed with it, so separate runs do not collide on disk.
uid = uuid.uuid4()
# Working directories: downloads, extracted frames, background-removed frames.
for _dir_suffix in ("", "-frames", "-rembg"):
    _dir_name = f'{uid}{_dir_suffix}'
    if not os.path.exists(_dir_name):
        os.makedirs(_dir_name)
del _dir_suffix, _dir_name

# Remote Real-ESRGAN Space, invoked as a callable by process_image_1.
esr = gr.Interface.load("spaces/Omnibus/Real_ESRGAN_mod")

# Client-side hook for app.load: returns the text input unchanged together
# with the page's URL query parameters as a plain object.
load_js = """
function(text_input, url_params) {
    console.log(text_input, url_params);
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    return [text_input, url_params]
}
"""
def rem_cv(inp):
    """Yield a cleaned foreground mask for each frame of the video at *inp*.

    Uses OpenCV's GMG background subtractor (``cv2.bgsegm``, shipped with the
    opencv-contrib build), then a 3x3 elliptical morphological opening to
    remove speckle noise from each mask.

    NOTE(review): not connected to any UI event in this file.

    Args:
        inp: Path to a video file (anything whose ``str`` is the path).

    Yields:
        The opened foreground mask for each decoded frame. The generator
        stops when the video has no more frames.
    """
    cap = cv2.VideoCapture(f'{inp}')
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    fgbg = cv2.bgsegm.createBackgroundSubtractorGMG()
    try:
        while True:
            ret, frame = cap.read()
            # read() reports (False, None) at end of stream; stop there
            # instead of handing None to the subtractor.
            if not ret:
                break
            fgmask = fgbg.apply(frame)
            fgmask = cv2.morphologyEx(fgmask, cv2.MORPH_OPEN, kernel)
            yield fgmask
    finally:
        # Release the decoder even if the consumer stops iterating early.
        cap.release()

def load_video(vid):
    """Probe a video and report its frame count, frame rate and duration.

    Args:
        vid: Path of the video delivered by the player component.

    Returns:
        ``(frame_count, fps, vid_len, vid)`` where ``vid_len`` is the
        duration rounded to whole seconds and formatted as ``H:MM:SS``
        (used to pre-fill the "End" trim box), and ``vid`` is echoed back
        unchanged.
    """
    capture = cv2.VideoCapture(str(vid))
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        capture.release()

    # Duration = frames / fps. An unreadable file or a container that
    # reports fps == 0 yields a zero duration instead of ZeroDivisionError.
    duration = frame_count / fps if fps > 0 and frame_count > 0 else 0.0
    vid_t = round(duration)

    # 3600 seconds per hour, 60 per minute.
    hours, remainder = divmod(vid_t, 3600)
    minutes, seconds = divmod(remainder, 60)
    vid_len = f'{hours}:{minutes:02d}:{seconds:02d}'
    return frame_count, fps, vid_len, vid

def capture_function(vid):
    """Write every frame of the trimmed clip to ``{uid}-frames/<n>.png``.

    Generator for the "Get Frames" button, so the gallery fills in while
    frames are being written.

    Args:
        vid: Value of the trimmed-clip video component. It is not read; the
            frames always come from ``{uid}-clip.mp4``, the file trim_vid
            writes. NOTE(review): confirm ignoring *vid* is intentional.

    Yields:
        ``(paths, paths, status)`` after each frame, where ``paths`` lists the
        PNG files written so far (1-based file names) and ``status`` is a
        progress string. The last state is yielded once more at the end.
    """
    capture = cv2.VideoCapture(f"{uid}-clip.mp4")
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    fbox = []
    frame_num = f'Working on 0 of {frame_count}'
    try:
        while True:
            # Plain sequential read: no per-frame seek is needed, and read()'s
            # optional argument is an output buffer, not a frame index.
            ret, frame_f = capture.read()
            if not ret:
                break
            path = f'{uid}-frames/{len(fbox) + 1}.png'
            cv2.imwrite(path, frame_f)
            fbox.append(path)
            frame_num = f'Working on {len(fbox)} of {frame_count}'
            yield fbox, fbox, frame_num
    finally:
        capture.release()
    yield (fbox, fbox, frame_num)

def im_2_vid(images, fps):
    """Encode image files into ``{uid}-rembg/bg_removed-{uid}.mp4``.

    Args:
        images: Ordered list of image file paths, one per frame.
        fps: Frame rate of the output video.

    Returns:
        ``(path, path)`` for the written MP4, or ``(None, None)`` when
        *images* is empty and there is nothing to encode.
    """
    if not images:
        return (None, None)
    out_path = f'{uid}-rembg/bg_removed-{uid}.mp4'
    movie_clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(images, fps)
    try:
        movie_clip.write_videofile(out_path)
    finally:
        movie_clip.close()
    return (out_path, out_path)
def rem_bg(vid):
    """Remove the background from every frame of *vid* and re-encode it.

    Each frame goes through ``rm`` (rembg's ``remove``) and is written to
    ``{uid}-rembg/<index>.png`` (0-based). The PNGs are then assembled into
    a video by im_2_vid at the source frame rate.

    Args:
        vid: Path to the input video.

    Yields:
        ``(paths, status, None)`` after each processed frame, then a final
        ``(paths, status, out_vid)``. ``out_vid`` is the encoded video path,
        or ``None`` when no frame could be read.
    """
    capture = cv2.VideoCapture(str(vid))
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = capture.get(cv2.CAP_PROP_FPS)
    fbox2 = []
    try:
        while True:
            # Sequential read to end of stream: covers the last frame too,
            # and read() takes no frame index.
            ret, frame_f = capture.read()
            if not ret:
                break
            index = len(fbox2)
            out = rm(frame_f)
            print(index)
            path = f'{uid}-rembg/{index}.png'
            cv2.imwrite(path, out)
            fbox2.append(path)
            yield fbox2, f'Working on {index + 1} of {frame_count}', None
    finally:
        capture.release()
    frame_num = f'Done: {len(fbox2)} of {frame_count}'
    # Only encode when at least one frame was produced.
    out_vid = im_2_vid(fbox2, fps)[0] if fbox2 else None
    yield (fbox2, frame_num, out_vid)


def predict(text, url_params):
    """Echo the hidden text value and extract the ``url`` query parameter.

    Runs on page load, where the client-side ``load_js`` supplies the page's
    URL query parameters as *url_params*.

    Args:
        text: Current value of the hidden text input (``None`` treated as "").
        url_params: Mapping of URL query parameters, or ``None``.

    Returns:
        ``[text, url]`` where ``url`` is the ``url`` query parameter, or
        ``None`` when it is absent.
    """
    mod_url = (url_params or {}).get('url')
    print(mod_url)
    return ["" if text is None else str(text), mod_url]

def dl(inp, img):
    """Obtain the source video, either by URL (via yt-dlp) or from the player.

    * *img* is None and *inp* is non-empty: download *inp* with yt-dlp into
      ``{uid}/<sanitised-url>.mp4`` (re-encoded to MP4).
    * *img* is set and *inp* is empty: use *img* as-is.

    Any other combination, or a failed download, returns ``None`` for every
    output instead of raising.

    Args:
        inp: Video URL typed by the user (may be empty).
        img: Path of the video already in the player, or None.

    Returns:
        ``(out, out, out, out, fps)``: the same path routed to four output
        components, plus the video's frame rate (``None`` when unavailable).
    """
    import subprocess

    out = None
    fps = None
    if img is None and inp != "":
        try:
            inp_out = inp.replace("https://", "")
            inp_out = inp_out.replace("/", "_").replace(".", "_").replace("=", "_").replace("?", "_")
            target = f"{uid}/{inp_out}.mp4"
            # Argument list rather than a shell string: the URL is user input
            # and must not be able to inject shell syntax. "--" stops yt-dlp
            # from parsing a URL that starts with "-" as an option.
            subprocess.run(
                ["yt-dlp", "--trim-filenames", "160", "-o", target,
                 "-S", "res,mp4", "--recode", "mp4", "--", inp],
                check=False,
            )
            if os.path.exists(target):
                out = target
                capture = cv2.VideoCapture(out)
                fps = capture.get(cv2.CAP_PROP_FPS)
                capture.release()
        except Exception as e:
            print(e)
            out = None
            fps = None
    elif img is not None and inp == "":
        capture = cv2.VideoCapture(img)
        fps = capture.get(cv2.CAP_PROP_FPS)
        capture.release()
        out = f"{img}"
    return out, out, out, out, fps
def dl_json(inp):
    """Fetch yt-dlp's metadata for *inp* without downloading the media.

    yt-dlp writes ``<sanitised-url>.info.json`` to the working directory;
    that file is read back and pretty-printed.

    Args:
        inp: Video URL.

    Returns:
        The metadata as an indented JSON string, or ``{}`` when running
        yt-dlp or reading/parsing its output fails (the error is printed).
    """
    import subprocess

    out_json = {}
    try:
        inp_out = inp.replace("https://", "")
        inp_out = inp_out.replace("/", "_").replace(".", "_").replace("=", "_").replace("?", "_")
        # Argument list, not a shell string: *inp* is user-supplied.
        subprocess.run(
            ["yt-dlp", "--write-info-json", "--skip-download", "-o", inp_out, "--", inp],
            check=False,
        )
        with open(f"{inp_out}.info.json", "r", encoding="utf-8") as f:
            json_object = json.load(f)
        out_json = json.dumps(json_object, indent=4)
    except Exception as e:
        print(e)
    return out_json
def _hms_to_seconds(text):
    """Convert an ``H:M:S``, ``M:S`` or ``S`` string to seconds (float).

    Every field may be fractional, e.g. ``"0:01:02.5"`` -> 62.5.
    """
    seconds = 0.0
    for part in str(text).strip().split(":"):
        seconds = seconds * 60 + float(part)
    return seconds


def trim_vid(vid, start_time, end_time):
    """Cut ``[start_time, end_time]`` out of *vid* into ``{uid}-clip.mp4``.

    Args:
        vid: Path of the source video.
        start_time: Start as an ``H:M:S`` string (fractional fields allowed).
        end_time: End as an ``H:M:S`` string; clamped to the clip duration.

    Returns:
        ``(out, out, frame_count)``: the trimmed clip path twice and the
        number of frames OpenCV reports for it.
    """
    print(vid)
    start = _hms_to_seconds(start_time)
    end = _hms_to_seconds(end_time)

    source = VideoFileClip(vid)
    try:
        mod_fps = source.fps
        # The "End" box is pre-filled with a rounded duration that can
        # overshoot the real one; never read past the end of the source.
        if source.duration is not None:
            end = min(end, source.duration)
        clip = source.subclip(start, end)
        clip.write_videofile(f"{uid}-clip.mp4", fps=mod_fps)
    finally:
        source.close()

    out = f"{uid}-clip.mp4"
    capture = cv2.VideoCapture(out)
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return out, out, frame_count
    
def other(vid=None, start_f=0):
    """Remove the background from *vid*'s frames, starting *start_f* seconds in.

    NOTE(review): not connected to any UI event in this file. The
    start-offset behaviour is inferred from the names used here; confirm the
    intent before wiring it up. Parameters have defaults, so a bare
    ``other()`` call is still valid; it yields nothing for a missing video.

    Args:
        vid: Path to the input video.
        start_f: Start offset in seconds (number or numeric string).

    Yields:
        ``(paths, status, None)`` after each processed frame. Frames are
        written to ``{uid}-rembg/<absolute frame index>.png``.
    """
    capture = cv2.VideoCapture(str(vid))
    try:
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = capture.get(cv2.CAP_PROP_FPS)
        start = max(int(fps * float(start_f)), 0)
        # Seek once, then read sequentially.
        capture.set(cv2.CAP_PROP_POS_FRAMES, start)
        fbox3 = []
        index = start
        while True:
            ret, frame_f = capture.read()
            if not ret:
                break
            out = rm(frame_f)
            print(index)
            path = f'{uid}-rembg/{index}.png'
            cv2.imwrite(path, out)
            fbox3.append(path)
            yield fbox3, f'Working on {index + 1} of {frame_count}', None
            index += 1
    finally:
        capture.release()

def make_gif(fps):
    """Convert the trimmed clip ``{uid}-clip.mp4`` into ``{uid}/clip_gif.gif``.

    Generator: yields ``(None, None, status)`` progress updates, then
    ``(gif_path, gif_path, "GIF Complete")``.

    Args:
        fps: GIF frame rate (string or number). If it is empty or not
            numeric, the clip's own frame rate is used.
    """
    _ = None
    yield _, _, "Gathering Video..."
    clip = VideoFileClip(f"{uid}-clip.mp4")
    try:
        try:
            gif_fps = float(fps)
        except (TypeError, ValueError):
            gif_fps = clip.fps
        yield _, _, "Writing GIF"
        clip.write_gif(f"{uid}/clip_gif.gif", program='ffmpeg', fps=gif_fps)
    finally:
        # Release the ffmpeg reader even if writing fails or the consumer
        # abandons the generator mid-way.
        clip.close()
    yield _, _, "Saving GIF"
    out = f"{uid}/clip_gif.gif"
    yield out, out, "GIF Complete"

def update_speed(inp, clip_speed, fps):
    """Re-time the video at *inp* by *clip_speed* and save it as ``*-split.mp4``.

    The output frame rate is ``fps * clip_speed``, so the re-timed clip is
    written with as many frames as the source while its duration scales by
    ``1 / clip_speed``.

    Args:
        inp: Path of the current video. If it already contains
            ``-split.mp4`` (an earlier speed change), that file is replaced.
        clip_speed: Speed factor (>1 faster, <1 slower).
        fps: Source frame rate. Falls back to the clip's fps when empty or
            not numeric.

    Returns:
        ``(out, out)``: the path of the re-timed video, twice.
    """
    if "-split.mp4" in inp:
        out_inp = f'{inp.split("-split.mp4", 1)[0]}-split.mp4'
    else:
        out_inp = f'{inp.split(".mp4", 1)[0]}-split.mp4'
    # out_inp can equal inp (second speed change). Write to a sibling temp
    # file and swap it in afterwards, so the source is not overwritten while
    # it is still being read.
    tmp_out = f'{out_inp[:-len(".mp4")]}-tmp.mp4'
    clip = VideoFileClip(inp)
    try:
        try:
            base_fps = float(fps)
        except (TypeError, ValueError):
            base_fps = clip.fps
        mod_fps = base_fps * float(clip_speed)
        final = moviepy.video.fx.all.speedx(clip, factor=clip_speed)
        final.write_videofile(tmp_out, fps=mod_fps)
    finally:
        clip.close()
    os.replace(tmp_out, out_inp)
    return out_inp, out_inp

def echo_fn(inp):
    """Identity passthrough: hand the received value straight back."""
    result = inp
    return result
def check_load(inp_url, outp_vid, hid_box, start_f, end_f):
    """Decide whether the loaded video should be trimmed.

    Trims with ``trim_vid(hid_box, start_f, end_f)`` in two cases:

    * no player value but a URL was given; or
    * a player value, no URL, and a path in *hid_box*.

    Otherwise *outp_vid* is passed through untrimmed.

    Returns:
        ``(out_trim, in_vid, trim_count)``.
    """
    from_url = outp_vid is None and inp_url != ""
    from_player_with_path = outp_vid is not None and inp_url == "" and hid_box != ""
    if from_url or from_player_with_path:
        return trim_vid(hid_box, start_f, end_f)
    # Player video with no URL and no path, plus every other combination
    # (e.g. neither a URL nor a player value): pass through untrimmed rather
    # than leaving the results unassigned.
    return None, outp_vid, ""
#outp_vid.change(echo_fn,outp_vid,[out_trim])
def process_image_1(image):
    """Enhance one video frame with the remote Real-ESRGAN Space.

    Used as the per-frame callback of ``fl_image`` in improve_quality. The
    frame is saved to a temporary PNG and exposed to the remote Space through
    this app's ``/file=`` URL. The Space's output image is then loaded back.

    Args:
        image: Frame as an RGB numpy array (moviepy's frame format).

    Returns:
        The processed frame as a 3-channel RGB numpy array.
    """
    rand_im = uuid.uuid4()
    tmp_path = os.path.abspath(f"{rand_im}-vid_tmp_proc.png")
    # cv2.imwrite expects BGR channel order while moviepy frames are RGB;
    # convert first, or red and blue end up swapped in the result.
    cv2.imwrite(tmp_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    try:
        out_url = f'https://omnibus-vid-url-dl-mod.hf.space/file={tmp_path}'
        out = esr(out_url, "realesr-general-x4v3", 0.5, True, 1)
        print(out)
        # Force RGB so the frame matches what moviepy expects even if the
        # Space returns an image with an alpha channel.
        with Image.open(out) as result:
            return np.array(result.convert("RGB"))
    finally:
        # The temp PNG is only needed during the remote call; remove it so
        # long clips do not fill the disk.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def improve_quality(model_name, denoise_strength, outscale, face_enhance):
    """Enhance every frame of ``{uid}-clip.mp4`` into ``{uid}-clip-high.mp4``.

    Args:
        model_name, denoise_strength, outscale, face_enhance: Values from the
            Quality tab. NOTE(review): currently unused. process_image_1
            calls the Space with fixed settings ("realesr-general-x4v3",
            0.5, True, 1).

    Returns:
        Path of the enhanced video.
    """
    out_path = f"{uid}-clip-high.mp4"
    source = VideoFileClip(f"{uid}-clip.mp4")
    try:
        # fl_image applies process_image_1 to each frame while writing.
        source.fl_image(process_image_1).write_videofile(out_path)
    finally:
        source.close()
    return out_path






# ---------------------------------------------------------------------------
# Gradio UI. Builds the tabbed interface (Load, Quality, Frames, GIF, Rem BG,
# Info), wires each button/change event to the handlers defined above, and
# launches the app with a request queue. All of this runs at import time.
# ---------------------------------------------------------------------------
with gr.Blocks() as app:
    # Load tab: fetch a video by URL or upload, inspect it, change its speed,
    # trim it.
    with gr.Tab("Load"):
        with gr.Row():
            with gr.Column():
                inp_url = gr.Textbox()
                go_btn = gr.Button("Run")
                outp_vid=gr.Video(format="mp4")
            with gr.Column():
                with gr.Row():
                    # Filled by load_video whenever outp_vid changes.
                    frame_count=gr.Textbox(label="Frame Count",interactive = False)
                    fps=gr.Textbox(label="FPS",interactive = False)
                outp_file=gr.Files()
                clip_speed = gr.Slider(label="Speed", minimum=0.01, maximum=2, value=1, step=0.01)
                speed_btn = gr.Button("Update Speed")
                with gr.Row():
                    # Trim bounds as H:M:S strings, parsed by trim_vid.
                    start_f = gr.Textbox(label = "Start", value = "0:00:00", placeholder = "0:00:23",interactive = True)
                    end_f = gr.Textbox(label = "End", value = "0:00:05", placeholder = "0:00:54",interactive = True)
                    trim_count = gr.Textbox(label="Trimmed Frames")
                trim_btn=gr.Button("Trim")
                out_trim=gr.Video(format="mp4")
                # hid_box carries the path of the current working video between
                # handlers (written by dl, load_video and update_speed).
                # hid_fps carries the fps reported by dl.
                # NOTE(review): named as hidden but created with visible=True.
                hid_box = gr.Textbox(visible=True)
                hid_fps = gr.Textbox(visible=True)
    # Quality tab: Real-ESRGAN enhancement of the trimmed clip.
    with gr.Tab("Quality"):
        with gr.Row():
            # NOTE(review): these four values are passed to improve_quality,
            # which does not currently use them.
            model_name = gr.Dropdown(label="Real-ESRGAN inference model to be used",
                                     choices=["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B",
                                              "RealESRGAN_x2plus", "realesr-general-x4v3"],
                                     value="realesr-general-x4v3", show_label=True)
            denoise_strength = gr.Slider(label="Denoise Strength (Used only with the realesr-general-x4v3 model)",
                                         minimum=0, maximum=1, step=0.1, value=0.5)
            outscale = gr.Slider(label="Image Upscaling Factor",
                                 minimum=1, maximum=10, step=1, value=2, show_label=True)
            face_enhance = gr.Checkbox(label="Face Enhancement using GFPGAN (Doesn't work for anime images)",
                                       value=False, show_label=True)        
        impro_btn = gr.Button("Run")
        with gr.Row():
            with gr.Column():
                # NOTE(review): clip_in is not connected to any event.
                clip_in = gr.Video()
            with gr.Column():
                clip_out = gr.Video()
                
    # Frames tab: dump every frame of the trimmed clip as PNGs.
    with gr.Tab("Frames"):
        with gr.Row():
            frame_btn = gr.Button("Get Frames")
            frame_stat=gr.Textbox(label="Status")
        with gr.Row():
            with gr.Column():
                frame_gal = gr.Gallery(columns=6)
            with gr.Column():
                frame_file = gr.Files()
   
    # GIF tab: convert the trimmed clip to a GIF.
    with gr.Tab("GIF"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # NOTE(review): check_qual is not connected to any event.
                    check_qual=gr.Checkbox(label="Improve Quality", value=False)
                    gif_btn = gr.Button("Make GIF")
            with gr.Column():
                gif_stat=gr.Textbox(label="Status")
        with gr.Row():
            with gr.Column():
                # NOTE(review): make_gif returns a .gif path; confirm gr.Video
                # renders it (gr.Image may be required).
                gif_show = gr.Video()
            with gr.Column():
                gif_file = gr.Files()                
        
    # Rem BG tab: per-frame background removal of the trimmed clip.
    with gr.Tab("Rem BG"):
        with gr.Row():    
            with gr.Column():
                # NOTE(review): button is created without a label.
                rem_btn=gr.Button()
                in_vid=gr.Video(format="mp4")
            with gr.Column():
                #rem_vid=gr.Video()
                frame_num=gr.Textbox(label="Progress")
                # Gallery of processed frame PNGs; rem_bg_vid receives the
                # re-encoded result at the end.
                rem_vid=gr.Gallery(columns=6)
        rem_bg_vid=gr.Video()        
    # Info tab: yt-dlp metadata for the URL in inp_url.
    with gr.Tab("Info"):
        with gr.Row():
            info_btn = gr.Button("Load")
            # NOTE(review): info_stat is never written by any handler.
            info_stat=gr.Textbox(label="Status")
        info_json = gr.JSON()    
    # Invisible plumbing for the page-load handler (predict + load_js).
    with gr.Row(visible=False):
        text_input=gr.Textbox()
        text_output=gr.Textbox()
        url_params=gr.JSON()


    
    # ---- Event wiring ----
    impro_btn.click(improve_quality,[model_name,denoise_strength,outscale,face_enhance],clip_out)

        
    info_btn.click(dl_json,inp_url,info_json)
    # Writing outp_vid here also fires the outp_vid.change chain below.
    speed_btn.click(update_speed,[hid_box,clip_speed,hid_fps],[outp_vid,hid_box])
    gif_btn.click(make_gif,fps,[gif_show,gif_file,gif_stat])
    trim_btn.click(trim_vid,[hid_box,start_f,end_f],[out_trim,in_vid,trim_count])
    # Whenever the player video changes: probe it, then auto-trim with the
    # current Start/End values.
    outp_vid.change(load_video,outp_vid,[frame_count,fps,end_f,hid_box]).then(trim_vid,[hid_box,start_f,end_f],[out_trim,in_vid,trim_count])
    frame_btn.click(capture_function,[out_trim],[frame_gal,frame_file,frame_stat])
    rem_btn.click(rem_bg,[out_trim],[rem_vid,frame_num,rem_bg_vid])
    go_btn.click(dl,[inp_url,outp_vid],[outp_vid,outp_file,out_trim,hid_box,hid_fps])
    # On page load, load_js runs in the browser first and supplies the URL
    # query parameters to predict.
    app.load(fn=predict, inputs=[text_input,url_params], outputs=[text_output,text_input],_js=load_js)
# concurrency_count is the Gradio 3.x queue API.
app.queue(concurrency_count=10).launch()