import os
import torch
import numpy as np
import imgui
import dnnlib
from gui_utils import imgui_utils

#----------------------------------------------------------------------------

class DragWidget:
    """ImGui panel that collects drag points, target points and an edit mask.

    The widget owns the interaction state (handle points, target points, the
    mask tensor and the drag flag). Every call to the widget copies that state
    into ``viz.args``. All coordinates are stored as ``[y, x]`` pairs.
    """

    def __init__(self, viz):
        """Create the widget with empty point lists and an all-ones mask.

        Args:
            viz: Host visualizer object. This class reads ``viz.result``,
                ``viz.label_w``, ``viz.button_w``, ``viz.font_size`` and
                ``viz.frame_delta`` and writes to ``viz.args``.
        """
        self.viz            = viz
        self.point          = [-1, -1]  # Last clicked position as [y, x]; not yet committed.
        self.points         = []        # Committed handle points, each [y, x].
        self.targets        = []        # Committed target points, each [y, x].
        self.is_point       = True      # True: next commit is a handle point; False: a target.
        self.last_click     = False     # `click` value seen on the previous add_point() call.
        self.is_drag        = False     # True while dragging is running.
        self.iteration      = 0         # Calls made while is_drag was set; reset by stop_drag().
        self.mode           = 'point'   # One of 'point', 'flexible', 'fixed'.
        self.r_mask         = 50        # Brush radius used by draw_mask().
        self.show_mask      = False
        self.width          = 256       # Kept in sync with the mask by init_mask().
        self.height         = 256
        self.mask           = torch.ones(self.height, self.width)  # 1 = fixed, 0 = flexible.
        self.lambda_mask    = 20        # Forwarded to viz.args; consumer not visible here.
        self.feature_idx    = 5         # Forwarded to viz.args; consumer not visible here.
        self.r1             = 3         # Forwarded to viz.args; consumer not visible here.
        self.r2             = 12        # Forwarded to viz.args; consumer not visible here.
        # Base path for point files; load_points() appends '_<suffix>.txt'.
        self.path           = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots'))
        self.defer_frames   = 0         # Counted down once per call while positive.
        self.disabled_time  = 0         # Counted down by viz.frame_delta; panel is grayed while non-zero.

    def action(self, click, down, x, y):
        """Route a mouse event to point placement or mask painting.

        In 'point' mode the event goes to add_point(). In any other mode the
        mask is painted at (x, y) whenever `down` is truthy.
        """
        if self.mode == 'point':
            self.add_point(click, x, y)
        elif down:
            self.draw_mask(x, y)

    def add_point(self, click, x, y):
        """Track a click and commit it once the click ends.

        While `click` is truthy the position is remembered as ``[y, x]``. On
        the first call after `click` turns falsy, that position is appended
        alternately to `points` and `targets`. A running drag is stopped
        first.
        """
        if click:
            self.point = [y, x]
        elif self.last_click:
            if self.is_drag:
                self.stop_drag()
            if self.is_point:
                self.points.append(self.point)
                self.is_point = False
            else:
                self.targets.append(self.point)
                self.is_point = True
        self.last_click = click

    def init_mask(self, w, h):
        """Record the image size and reset the mask to all ones of shape (h, w)."""
        self.width, self.height = w, h
        self.mask = torch.ones(h, w)

    def draw_mask(self, x, y):
        """Paint a filled circle of radius `r_mask` centred at (x, y) into the mask.

        Pixels strictly inside the circle are set to 0 in 'flexible' mode and
        to 1 in 'fixed' mode. Any other mode leaves the mask untouched.
        """
        # Take the size from the mask itself so the boolean index always matches.
        height, width = self.mask.shape
        # Integer pixel coordinates. linspace(0, W, W) would step by W/(W-1)
        # and shift the circle by up to one pixel near the far edges.
        # Broadcasting a column against a row replaces torch.meshgrid.
        rows = torch.arange(height, dtype=torch.float32, device=self.mask.device).view(-1, 1)
        cols = torch.arange(width, dtype=torch.float32, device=self.mask.device).view(1, -1)
        circle = (cols - x) ** 2 + (rows - y) ** 2 < self.r_mask ** 2
        if self.mode == 'flexible':
            self.mask[circle] = 0
        elif self.mode == 'fixed':
            self.mask[circle] = 1

    def stop_drag(self):
        """Stop dragging and reset the step counter."""
        self.is_drag = False
        self.iteration = 0

    def set_points(self, points):
        """Replace the handle-point list with `points`."""
        self.points = points

    def reset_point(self):
        """Clear all handle and target points; the next click is a handle point."""
        self.points = []
        self.targets = []
        self.is_point = True

    def load_points(self, suffix):
        """Read ``[y, x]`` integer pairs from ``<self.path>_<suffix>.txt``.

        Each non-blank line must hold two whitespace-separated integers. If
        the file cannot be opened, or a line is malformed, a message is
        printed and the points parsed so far are returned.

        Returns:
            list of ``[y, x]`` integer pairs, possibly empty.
        """
        points = []
        point_path = self.path + f'_{suffix}.txt'
        try:
            with open(point_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue  # Tolerate blank lines.
                    y, x = line.split()
                    points.append([int(y), int(x)])
        except OSError:
            print(f'Wrong point file path: {point_path}')
        except ValueError:
            # Wrong field count or non-integer text; unrelated to the path.
            print(f'Malformed point file: {point_path}')
        return points

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the panel (if `show`) and publish the widget state to ``viz.args``.

        Args:
            show: When falsy no UI is drawn, but the counters are still
                updated and the state is still written to ``viz.args``.
        """
        viz = self.viz
        reset = False
        if show:
            # Every button is enabled only when a result image is present.
            has_image = 'image' in viz.result
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Drag')
                imgui.same_line(viz.label_w)

                if imgui_utils.button('Add point', width=viz.button_w, enabled=has_image):
                    self.mode = 'point'

                imgui.same_line()
                if imgui_utils.button('Reset point', width=viz.button_w, enabled=has_image):
                    self.reset_point()
                    reset = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Start', width=viz.button_w, enabled=has_image):
                    self.is_drag = True
                    if len(self.points) > len(self.targets):
                        # Drop the handle point that has no target yet. The
                        # next click must then be a handle point again, or
                        # the two lists would fall out of step.
                        self.points = self.points[:len(self.targets)]
                        self.is_point = True

                imgui.same_line()
                if imgui_utils.button('Stop', width=viz.button_w, enabled=has_image):
                    self.stop_drag()

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                imgui.text(f'Steps: {self.iteration}')

                imgui.text('Mask')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Flexible area', width=viz.button_w, enabled=has_image):
                    self.mode = 'flexible'
                    self.show_mask = True

                imgui.same_line()
                if imgui_utils.button('Fixed area', width=viz.button_w, enabled=has_image):
                    self.mode = 'fixed'
                    self.show_mask = True

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Reset mask', width=viz.button_w, enabled=has_image):
                    self.mask = torch.ones(self.height, self.width)
                imgui.same_line()
                _clicked, self.show_mask = imgui.checkbox('Show mask', self.show_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.r_mask = imgui.input_int('Radius', self.r_mask)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 6):
                    _changed, self.lambda_mask = imgui.input_int('Lambda', self.lambda_mask)

        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1
        viz.args.is_drag = self.is_drag
        if self.is_drag:
            self.iteration += 1
        viz.args.iteration = self.iteration
        # Shallow copies: later appends here do not change the published
        # lists, but the inner [y, x] pairs are shared.
        viz.args.points = list(self.points)
        viz.args.targets = list(self.targets)
        viz.args.mask = self.mask
        viz.args.lambda_mask = self.lambda_mask
        viz.args.feature_idx = self.feature_idx
        viz.args.r1 = self.r1
        viz.args.r2 = self.r2
        viz.args.reset = reset


#----------------------------------------------------------------------------