import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.framework.formats import landmark_pb2
from mediapipe import solutions
import numpy as np

# NOTE: heavily modified for the Gradio app

def xywh_to_xyxy(box):
  """Convert an ``[x, y, w, h]`` box into ``[x1, y1, x2, y2]``.

  The result is the top-left corner followed by the bottom-right corner,
  i.e. ``x2 = x + w`` and ``y2 = y + h``. Only the first four items of
  *box* are read; a new list is returned.
  """
  left, top, width, height = box[0], box[1], box[2], box[3]
  return [left, top, left + width, top + height]

def to_int_box(box):
  """Return the first four coordinates of *box* as ints (truncated toward 0)."""
  return [int(box[i]) for i in range(4)]

def convert_to_box(face_landmarks_list,indices,w=1024,h=1024):
  """Pixel bounding box of selected landmarks of the first face.

  Args:
    face_landmarks_list: sequence of per-face landmark lists; only
      ``face_landmarks_list[0]`` is used. Each landmark must expose ``.x`` and
      ``.y`` (presumably normalized to [0, 1], since they are scaled by
      ``w``/``h`` -- confirm against the caller).
    indices: landmark indices to include in the box.
    w: image width used to scale x and as the upper clamp for x.
    h: image height used to scale y and as the upper clamp for y.

  Returns:
    ``[x, y, width, height]`` as ints (truncated). Each scaled coordinate is
    clamped to ``[0, w]`` / ``[0, h]`` before the extent is taken. With no
    indices the result is ``[w, h, -w, -h]``.
  """
  # Landmarks are fetched per index, so an empty ``indices`` never touches
  # face_landmarks_list at all.
  points = [face_landmarks_list[0][i] for i in indices]
  xs = [min(w, max(0, p.x * w)) for p in points]
  ys = [min(h, max(0, p.y * h)) for p in points]

  # Seed the minimum with w/h and the maximum with 0, like a running
  # min/max that starts from the opposite image corners.
  left, right = min([w] + xs), max([0] + xs)
  top, bottom = min([h] + ys), max([0] + ys)

  return [int(left), int(top), int(right - left), int(bottom - top)]
       
  
def box_to_square(bbox):
  """Grow the shorter side of an ``[x, y, w, h]`` box so the box is square.

  The box is padded symmetrically: the origin on the grown axis moves back by
  half the size difference, so the shifted coordinate may become a float.
  Nothing is clamped, so the result may extend beyond the image. *bbox* is
  not modified; a new list is returned, and any items after index 3 are
  copied through unchanged.
  """
  square = list(bbox)
  width, height = square[2], square[3]

  # Pick which size to grow, which origin to shift, and by how much.
  if width > height:
    grow, shift, pad = 3, 1, width - height
  elif height > width:
    grow, shift, pad = 2, 0, height - width
  else:
    return square

  square[grow] += pad
  square[shift] -= pad / 2
  return square


def face_landmark_result_to_box(face_landmarker_result,width=1024,height=1024):
  """Compute candidate face bounding boxes from a FaceLandmarker result.

  Only the first detected face (``face_landmarks[0]``) is used.

  Args:
    face_landmarker_result: object exposing ``face_landmarks``, a sequence of
      per-face landmark lists.
    width: image width in pixels, used to scale landmark x coordinates.
    height: image height in pixels, used to scale landmark y coordinates.

  Returns:
    Six ``[x, y, w, h]`` boxes, in this order:
    ``[inner, contour, full, square(inner), square(contour), square(full)]``
      * inner   - central landmarks only (chin-to-forehead line,
                  eyebrow/cheek line and mouth bottom),
      * contour - ``inner`` plus the named outer chin/mouth/eye/head points,
                  the mid-forehead and the bottom of the chin,
      * full    - every landmark of the first face.
    The square variants are produced by ``box_to_square``.

  Raises:
    IndexError: if no face was detected.
  """
  face_landmarks_list = face_landmarker_result.face_landmarks
  if not face_landmarks_list:
    # Same exception type that indexing the empty list raised before, but
    # with a message that says what actually went wrong.
    raise IndexError("no face landmarks detected")

  # Every landmark the model returned for this face. This used to be a
  # hard-coded range(456), which silently left the highest-numbered
  # landmarks out of the "full" box.
  full_indices = list(range(len(face_landmarks_list[0])))

  MIDDLE_FOREHEAD = 151
  BOTTOM_CHIN = 175
  CHIN_TO_MIDDLE_FOREHEAD = [200,14,1,6,18,9]
  MOUTH_BOTTOM = [202,200,422]
  EYEBROW_CHEEK_LEFT_RIGHT = [46,226,50,1,280,446,276]

  LEFT_HEAD_OUTER = 301
  LEFT_EYE_OUTER = 264
  LEFT_MOUTH_OUTER = 288
  LEFT_CHIN_OUTER = 435
  RIGHT_HEAD_OUTER = 71
  RIGHT_EYE_OUTER = 34
  RIGHT_MOUTH_OUTER = 215
  RIGHT_CHIN_OUTER = 150
  # Alternative ("_EX") points, currently unused: bottom chin 152,
  # left head 251 (on a side face almost the same as full), left eye 356,
  # left mouth 288, right head 21, right eye 127, right mouth 58.

  # Central landmarks only; they give the tightest box.
  min_indices = CHIN_TO_MIDDLE_FOREHEAD + EYEBROW_CHEEK_LEFT_RIGHT + MOUTH_BOTTOM

  # Left outer points, mid-forehead, right outer points, bottom of the chin,
  # plus the central landmarks.
  chin_to_brow_indices = [
    LEFT_CHIN_OUTER, LEFT_MOUTH_OUTER, LEFT_EYE_OUTER, LEFT_HEAD_OUTER,
    MIDDLE_FOREHEAD,
    RIGHT_HEAD_OUTER, RIGHT_EYE_OUTER, RIGHT_MOUTH_OUTER, RIGHT_CHIN_OUTER,
    BOTTOM_CHIN,
  ] + min_indices

  box1 = convert_to_box(face_landmarks_list,min_indices,width,height)
  box2 = convert_to_box(face_landmarks_list,chin_to_brow_indices,width,height)
  box3 = convert_to_box(face_landmarks_list,full_indices,width,height)

  return [box1,box2,box3,box_to_square(box1),box_to_square(box2),box_to_square(box3)]


def draw_landmarks_on_image(detection_result,rgb_image):
  """Return a copy of *rgb_image* with the face-mesh tessellation drawn on it.

  Args:
    detection_result: object exposing ``face_landmarks``, a sequence of
      per-face landmark lists whose items have ``x``, ``y`` and ``z``.
    rgb_image: image array to draw on; it is copied with ``np.copy`` and the
      original is left untouched.

  Returns:
    The annotated copy.
  """
  annotated_image = np.copy(rgb_image)

  # Draw every detected face.
  for face_landmarks in detection_result.face_landmarks:
    # draw_landmarks is given a NormalizedLandmarkList proto, so repack the
    # landmark objects of this face into one.
    face_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
    face_landmarks_proto.landmark.extend([
      landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
      for landmark in face_landmarks
    ])

    solutions.drawing_utils.draw_landmarks(
        image=annotated_image,
        landmark_list=face_landmarks_proto,
        connections=solutions.face_mesh.FACEMESH_TESSELATION,
        landmark_drawing_spec=None,
        connection_drawing_spec=solutions.drawing_styles
        .get_default_face_mesh_tesselation_style())

  return annotated_image

def mediapipe_to_box(image_data,model_path="face_landmarker.task"):
  """Run the MediaPipe FaceLandmarker on one image and derive face boxes.

  Args:
    image_data: either a path to an image file (``str`` or ``os.PathLike``),
      or anything ``np.asarray`` turns into an image array. Arrays are wrapped
      as ``mp.ImageFormat.SRGB`` (presumably HxWx3 uint8 RGB -- confirm for
      your callers).
    model_path: path to the ``.task`` model bundle.

  Returns:
    ``(boxes, mp_image, face_landmarker_result)``, where ``boxes`` is what
    ``face_landmark_result_to_box`` returns for this image's width/height.

  Note:
    A new landmarker is created and closed on every call. Both the face
    detection and face presence confidence thresholds are set to 0, so faces
    are not filtered out for low confidence.
  """
  import os  # local import: only needed for path handling in this function

  BaseOptions = mp.tasks.BaseOptions
  FaceLandmarker = mp.tasks.vision.FaceLandmarker
  FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
  VisionRunningMode = mp.tasks.vision.RunningMode

  options = FaceLandmarkerOptions(
      base_options=BaseOptions(model_asset_path=model_path),
      running_mode=VisionRunningMode.IMAGE,
      min_face_detection_confidence=0,
      min_face_presence_confidence=0,
      )

  with FaceLandmarker.create_from_options(options) as landmarker:
    if isinstance(image_data, (str, os.PathLike)):
      # Accept pathlib.Path and friends too; hand MediaPipe a plain path
      # string.
      mp_image = mp.Image.create_from_file(os.fspath(image_data))
    else:
      mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(image_data))
    face_landmarker_result = landmarker.detect(mp_image)
    boxes = face_landmark_result_to_box(face_landmarker_result,mp_image.width,mp_image.height)
    return boxes,mp_image,face_landmarker_result