'''
M-LSD
Copyright 2021-present NAVER Corp.
Apache License v2.0
'''
# for demo
import os
from flask import Flask, request, session, json, Response, render_template, abort, send_from_directory
import requests
from urllib.request import urlopen
from io import BytesIO
import uuid
import cv2
import time
import argparse

# for tflite
import numpy as np
from PIL import Image
import tensorflow as tf

# for square detector
from utils import pred_squares

os.environ['CUDA_VISIBLE_DEVICES'] = '' # CPU mode

# flask
app = Flask(__name__)
logger = app.logger
logger.info('init demo app')

# config
parser = argparse.ArgumentParser()

## model parameters
parser.add_argument('--tflite_path', default='./tflite_models/M-LSD_512_large_fp16.tflite', type=str)
parser.add_argument('--input_size', default=512, type=int,
                    help='The size of input images.')

## LSD parameter
parser.add_argument('--score_thr', default=0.10, type=float,
                    help='Discard center points when the score < score_thr.')

## intersection point parameters
parser.add_argument('--outside_ratio', default=0.10, type=float,
                    help='''Discard an intersection point 
                    when it is located outside a line segment farther than line_length * outside_ratio.''')
parser.add_argument('--inside_ratio', default=0.50, type=float,
                    help='''Discard an intersection point
                    when it is located inside a line segment farther than line_length * inside_ratio.''')

## ranking boxes parameters
parser.add_argument('--w_overlap', default=0.0, type=float,
                    help='''When increasing w_overlap, the final box tends to overlap with
                    the detected line segments as much as possible.''')
parser.add_argument('--w_degree', default=1.14, type=float,
                    help='''When increasing w_degree, the final box tends to be
                    a parallel quadrilateral with reference to the angle of the box.''')
parser.add_argument('--w_length', default=0.03, type=float,
                    help='''When increasing w_length, the final box tends to be
                    a parallel quadrilateral with reference to the length of the box.''')
parser.add_argument('--w_area', default=1.84, type=float,
                    help='When increasing w_area, the final box tends to be the largest one out of candidates.')
parser.add_argument('--w_center', default=1.46, type=float,
                    help='When increasing w_center, the final box tends to be located in the center of input image.')

## flask demo parameter
parser.add_argument('--port', default=5000, type=int,
                    help='flask demo will be running on http://0.0.0.0:port/')


class model_graph:
    def __init__(self, args):
        self.interpreter, self.input_details, self.output_details = self.load_tflite(args.tflite_path)
        self.params = {'score': args.score_thr,'outside_ratio': args.outside_ratio,'inside_ratio': args.inside_ratio, 
                       'w_overlap': args.w_overlap,'w_degree': args.w_degree,'w_length': args.w_length,
                       'w_area': args.w_area,'w_center': args.w_center}
        self.args = args


    def load_tflite(self, tflite_path):
        interpreter = tf.lite.Interpreter(model_path=tflite_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        return interpreter, input_details, output_details


    def pred_tflite(self, image):
        segments, squares, score_array, inter_points = pred_squares(image, self.interpreter, self.input_details, self.output_details, [self.args.input_size, self.args.input_size], params=self.params)

        output = {}
        output['segments'] = segments
        output['squares'] = squares
        output['scores'] = score_array
        output['inter_points'] = inter_points

        return output


    def read_image(self, image_url):
        response = requests.get(image_url, stream=True)
        image = np.asarray(Image.open(BytesIO(response.content)).convert('RGB'))
        
        max_len = 1024
        h, w, _ = image.shape
        org_shape = [h, w]
        max_idx = np.argmax(org_shape)

        max_val = org_shape[max_idx]
        if max_val  > max_len:
            min_idx = (max_idx + 1) % 2
            ratio = max_len / max_val
            new_min = org_shape[min_idx] * ratio
            new_shape = [0, 0]
            new_shape[max_idx] = 1024
            new_shape[min_idx] = new_min

            image = cv2.resize(image, (int(new_shape[1]), int(new_shape[0])), interpolation=cv2.INTER_AREA)

        return image
    

    def init_resize_image(self, im, maximum_size=1024):
        h, w, _ = im.shape
        size = [h, w]
        max_arg = np.argmax(size)
        max_len = size[max_arg]
        min_arg = max_arg - 1
        min_len = size[min_arg]
        if max_len < maximum_size:
            return im
        else:
            ratio = maximum_size / max_len
            max_len = max_len * ratio
            min_len = min_len * ratio
            size[max_arg] = int(max_len)
            size[min_arg] = int(min_len)

            im = cv2.resize(im, (size[1], size[0]), interpolation = cv2.INTER_AREA)

            return im


    def decode_image(self, session_id, rawimg):
        dirpath = os.path.join('static/results', session_id)

        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
        save_path = os.path.join(dirpath, 'input.png')
        input_image_url = os.path.join(dirpath, 'input.png')
    
        img = cv2.imdecode(np.frombuffer(rawimg, dtype='uint8'), 1)[:,:,::-1]
        img = self.init_resize_image(img)
        cv2.imwrite(save_path, img[:,:,::-1])

        return img, input_image_url


    def draw_output(self, image, output, save_path='test.png'):
        color_dict = {'red': [255, 0, 0],
                      'green': [0, 255, 0],
                      'blue': [0, 0, 255],
                      'cyan': [0, 255, 255],
                      'black': [0, 0, 0],
                      'yellow': [255, 255, 0],
                      'dark_yellow': [200, 200, 0]}
        
        line_image = image.copy()
        square_image = image.copy()
        square_candidate_image = image.copy()
        
        line_thick = 5

        # output > line array
        for line in output['segments']:
            x_start, y_start, x_end, y_end = [int(val) for val in line]
            cv2.line(line_image, (x_start, y_start), (x_end, y_end), color_dict['red'], line_thick)
        
        inter_image = line_image.copy()
               
        for pt in output['inter_points']:
            x, y = [int(val) for val in pt]
            cv2.circle(inter_image, (x, y), 10, color_dict['blue'], -1)

        for square in output['squares']:
            cv2.polylines(square_candidate_image, [square.reshape([-1, 1, 2])], True, color_dict['dark_yellow'], line_thick)

        for square in output['squares'][0:1]:
            cv2.polylines(square_image, [square.reshape([-1, 1, 2])], True, color_dict['yellow'], line_thick)
            for pt in square:
                cv2.circle(square_image, (int(pt[0]), int(pt[1])), 10, color_dict['cyan'], -1)

        '''
        square image | square candidates image
        inter image  | line image
        '''
        output_image = self.init_resize_image(square_image, 512)
        output_image = np.concatenate([output_image, self.init_resize_image(square_candidate_image, 512)], axis=1)
        output_image_tmp = np.concatenate([self.init_resize_image(inter_image, 512), self.init_resize_image(line_image, 512)], axis=1)
        output_image = np.concatenate([output_image, output_image_tmp], axis=0)

        cv2.imwrite(save_path, output_image[:,:,::-1])

        return output_image


    def save_output(self, session_id, input_image_url, image, output):
        dirpath = os.path.join('static/results', session_id)

        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
        
        save_path = os.path.join(dirpath, 'output.png')
        self.draw_output(image, output, save_path=save_path)

        output_image_url = os.path.join(dirpath, 'output.png')

        rst = {}
        rst['input_image_url'] = input_image_url
        rst['session_id'] = session_id
        rst['output_image_url'] = output_image_url

        with open(os.path.join(dirpath, 'results.json'), 'w') as f:
            json.dump(rst, f)


def init_worker(args):
    global model
    
    model = model_graph(args)

   
@app.route('/')
def index():
    return render_template('index_scan.html', session_id='dummy_session_id')


@app.route('/', methods=['POST'])
def index_post():
    request_start = time.time()
    configs = request.form
            
    session_id = str(uuid.uuid1())

    image_url = configs['image_url'] # image_url
    
    if len(image_url) == 0:
        bio = BytesIO()
        request.files['image'].save(bio)
        rawimg = bio.getvalue()
        image, image_url = model.decode_image(session_id, rawimg)
    else:
        image = model.read_image(image_url)
    
    output = model.pred_tflite(image)
    
    model.save_output(session_id, image_url, image, output)

    return render_template('index_scan.html', session_id=session_id)


@app.route('/favicon.ico')
def favicon():
    return send_from_directory(os.path.join(app.root_path, 'static'),
                               'favicon.ico', mimetype='image/vnd.microsoft.icon')


if __name__ == '__main__':
    args = parser.parse_args()

    init_worker(args)

    app.run(host='0.0.0.0', port=args.port)