import os
import json
from PIL import Image
from skimage import io
import gradio as gr
from modelscope_studio import encode_image, decode_image, call_demo_service


# Labels used for the yes/no radio choices; option values are compared against `yes`.
yes = "是"
no = "否"

def get_size(h, w, max_size=720):
    """Return an ``(h, w)`` pair whose shorter side is capped at *max_size*.

    When both sides exceed *max_size*, the shorter side becomes exactly
    *max_size* and the longer side is scaled by the same ratio (truncated
    to an int). Otherwise the pair is returned unchanged.
    """
    # Nothing to do unless the shorter side is over the limit.
    if min(h, w) <= max_size:
        return h, w
    # Portrait orientation: width is the shorter side.
    if h > w:
        return int(max_size * h / w), max_size
    # Landscape or square: height is the shorter (or equal) side.
    return max_size, int(max_size * w / h)


def _build_request(task: str, url: str, validator: dict) -> dict:
    """Return the demo-service request payload that runs *task* on *url*."""
    return {
        "task": task,
        "inputs": [
            url
        ],
        "parameters": {},
        "urlPaths": {
            "inUrls": [
                {
                    "value": url,
                    "fileType": "png",
                    "type": "image",
                    "displayType": "ImgUploader",
                    "validator": validator,
                    "name": "",
                    "title": ""
                }
            ],
            "outUrls": [
                {
                    "outputKey": "output_img",
                    "type": "image"
                }
            ]
        }
    }


def _run_stage(task: str, model_name: str, url: str, validator: dict) -> str:
    """Call one hosted model through ``call_demo_service`` and return its ``output_img`` value."""
    result = call_demo_service(
        path='damo', name=model_name, data=json.dumps(_build_request(task, url, validator)))
    print(f"{task} result: {result}")
    try:
        return result['data']['output_img']
    except (KeyError, TypeError) as err:
        # A response without data/output_img would otherwise surface as a bare
        # KeyError/TypeError that does not say which stage failed.
        raise RuntimeError(
            f"{task} service returned no output image: {result!r}") from err


def inference(img: "Image.Image", colorization_option: str, image_denoise_option: str, color_enhance_option: str) -> "Image.Image":
    """Run the photo-restoration pipeline on *img* and return the processed image.

    The stages run in a fixed order, each one consuming the previous stage's
    output: denoising (optional), colorization (optional), portrait
    enhancement (always), color enhancement (optional).

    Args:
        img: Input PIL image, or ``None`` when nothing was uploaded.
        colorization_option: Runs the colorization stage when equal to ``yes``.
        image_denoise_option: Runs the denoising stage when equal to ``yes``.
        color_enhance_option: Runs the color-enhancement stage when equal to ``yes``.

    Returns:
        The value produced by ``decode_image`` for the final stage's output,
        or ``None`` when *img* is ``None``.

    Raises:
        RuntimeError: If a service response lacks ``['data']['output_img']``.
    """
    if img is None:
        return None

    # PIL reports size as (width, height); get_size takes (height, width).
    w, h = img.size
    h, w = get_size(h, w, 512)
    img = img.resize((w, h))

    res_url = encode_image(img)

    # (enabled, task, hosted model name, validator) in execution order.
    # Validators are kept exactly as each service request originally sent them.
    stages = (
        (
            image_denoise_option == yes,
            "image-denoising",
            'cv_nafnet_image-denoise_sidd',
            {
                "accept": "*.jpeg,*.jpg,*.png",
                "max_resolution": "5000*5000",
                "max_size": "10m",
            },
        ),
        (
            colorization_option == yes,
            "image-colorization",
            'cv_ddcolor_image-colorization',
            {
                "accept": "*.jpeg,*.jpg,*.png",
                "max_size": "10m",
                "max_resolution": "5000*5000",
            },
        ),
        (
            True,  # portrait enhancement is not optional
            "image-portrait-enhancement",
            'cv_gpen_image-portrait-enhancement',
            {
                "accept": "*.jpeg,*.jpg,*.png",
                "max_size": "10M",
                "max_resolution": "2000*2000",
            },
        ),
        (
            color_enhance_option == yes,
            "image-color-enhancement",
            'cv_csrnet_image-color-enhance-models',
            {
                "accept": "*.jpeg,*.jpg,*.png",
                "max_size": "10m",
                "max_resolution": "5000*5000",
            },
        ),
    )
    for enabled, task, model_name, validator in stages:
        if enabled:
            res_url = _run_stage(task, model_name, res_url, validator)

    return decode_image(res_url)


# Page title passed to gr.Blocks.
title = "AI老照片修复"
# Introductory text rendered as Markdown above the controls.
description = '''
输入一张老照片,点击一键修复,就能获得由AI完成画质增强、智能上色等处理后的彩色照片!还等什么呢?快让相册里的老照片坐上时光机吧~
'''
# Example inputs live in the "images" directory next to this script.
# The path is joined rather than concatenated: dirname() has no trailing
# separator, so "dir" + "./images/..." would produce "dir./images/...".
_images_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'images')
examples = [[os.path.join(_images_dir, f'input{i}.jpg')] for i in range(1, 6)]

# Sizing rule for the banner image with id="overview".
css_style = "#overview {margin: auto;max-width: 600px; max-height: 400px; width: 100%;}"

def _yes_no_radio(label, default):
    """Create a yes/no Radio with the given label and initial value."""
    return gr.components.Radio(label=label, choices=[yes, no], value=default)


# Build the page: banner and description first, then a row with two columns.
with gr.Blocks(title=title, css=css_style) as demo:
    gr.HTML('''
        <div style="text-align: center; max-width: 720px; margin: 0 auto;">
            <img id="overview" alt="overview" src="https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/public/ModelScope/studio_old_photo_restoration/overview_long.gif" />
        </div>
      ''')
    gr.Markdown(description)
    with gr.Row():
        # First column (scale=2): image upload, processing options, run button.
        with gr.Column(scale=2):
            img_input = gr.components.Image(
                label="图片",
                type="pil",
            )
            colorization_option = _yes_no_radio("重新上色", yes)
            image_denoise_option = _yes_no_radio("应用图像去噪(存在细节损失风险)", no)
            color_enhance_option = _yes_no_radio("应用色彩增强(存在罕见色调风险)", no)
            btn = gr.Button("一键修复")
        # Second column (scale=3): the processed image.
        with gr.Column(scale=3):
            img_output = gr.components.Image(
                label="图片",
                type="pil",
            ).style(height=600)
    # Order matches the positional parameters of inference().
    inputs = [
        img_input,
        colorization_option,
        image_denoise_option,
        color_enhance_option,
    ]
    btn.click(
        fn=inference,
        inputs=inputs,
        outputs=img_output,
    )
    gr.Examples(examples, inputs=img_input)

# Start the app at import time, as the script did originally.
demo.launch()