# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for creating a tf.train.Example proto of image triplets."""

import io
import os
from typing import Any, List, Mapping, Optional

from absl import logging
import apache_beam as beam
import numpy as np
import PIL.Image
import six
from skimage import transform
import tensorflow as tf

# Exponent used to convert between gamma-encoded and linear pixel intensities.
_GAMMA = 2.2
# Largest value a uint8 can hold (255), as a Python float, used to map pixel
# values to and from the [0, 1] range.
_UINT8_MAX_F = float(np.iinfo(np.dtype(np.uint8)).max)


def _resample_image(image: np.ndarray, resample_image_width: int,
                    resample_image_height: int) -> np.ndarray:
  """Resizes `image` to the requested width and height in linear intensity.

  The pixel values are scaled to [0, 1] and raised to the power `_GAMMA`
  before resizing, so that averaging happens on linear values. The result is
  converted back with the inverse exponent and quantized to uint8.

  Args:
    image: Image array with values in the uint8 range [0..255].
    resample_image_width: Width of the returned image, in pixels.
    resample_image_height: Height of the returned image, in pixels.

  Returns:
    The resized image as a uint8 array.
  """
  # uint8 gamma [0..255] -> float linear [0..1].
  normalized = image.astype(np.float32) / _UINT8_MAX_F
  linear = np.power(np.clip(normalized, 0, 1), _GAMMA)

  # The resize helper expects the target shape as (rows, cols).
  resized = transform.resize_local_mean(
      linear, (resample_image_height, resample_image_width))

  # float linear [0..1] -> uint8 gamma [0..255]; the +0.5 rounds to nearest
  # before the truncating cast.
  gamma_encoded = np.power(np.clip(resized, 0, 1), 1.0 / _GAMMA)
  quantized = np.clip(gamma_encoded * _UINT8_MAX_F + 0.5, 0.0, _UINT8_MAX_F)
  return quantized.astype(np.uint8)


def _center_crop_bounds(size: int, center_crop_factor: int):
  """Returns the `(start, stop)` indices of a centered crop along one axis.

  The crop keeps roughly `size / center_crop_factor` elements. For a factor of
  2 the margin equals `size // 4`. The stop index is explicit (not a negative
  offset) so that a zero margin keeps the whole axis instead of producing an
  empty slice.

  Args:
    size: Length of the axis being cropped.
    center_crop_factor: Positive integer factor by which to shrink the axis.

  Returns:
    A `(start, stop)` tuple suitable for slicing `array[start:stop]`.
  """
  margin = (size * (center_crop_factor - 1)) // (2 * center_crop_factor)
  return margin, size - margin


def generate_image_triplet_example(
    triplet_dict: Mapping[str, str],
    scale_factor: int = 1,
    center_crop_factor: int = 1) -> Optional[tf.train.Example]:
  """Generates a tf.train.Example proto from an image triplet.

  Default setting creates a triplet Example with the input images unchanged:
  the original file bytes and the original image format are stored as-is.
  Otherwise images are processed in the order of center-crop then downscale,
  re-encoded once as PNG, and the stored format is 'png'.

  For each `image_key` in `triplet_dict` the Example contains the features
  `<image_key>/encoded`, `<image_key>/format`, `<image_key>/height` and
  `<image_key>/width`, plus a single `path` feature holding the directory of
  the 'frame_1' image.

  Args:
    triplet_dict: A dict of image key to filepath of the triplet images. Must
      contain the key 'frame_1'.
    scale_factor: An integer scale factor to isotropically downsample images.
    center_crop_factor: An integer cropping factor to center crop images with
      the original resolution but isotropically downsized by the factor.

  Returns:
    tf.train.Example proto, or None if an image file is missing, cannot be
    read or decoded, or cannot be re-encoded.

  Raises:
    ValueError: If triplet_dict length is different from three or the scale
      input arguments are non-positive.
    KeyError: If `triplet_dict` (as a plain dict) has no 'frame_1' key.
  """
  if len(triplet_dict) != 3:
    raise ValueError(
        f'Length of triplet_dict must be exactly 3, not {len(triplet_dict)}.')

  if scale_factor <= 0 or center_crop_factor <= 0:
    raise ValueError(f'(scale_factor, center_crop_factor) must be positive, '
                     f'Not ({scale_factor}, {center_crop_factor}).')

  feature = {}

  # Keep track of the path where the images came from for debugging purposes.
  mid_frame_path = os.path.dirname(triplet_dict['frame_1'])
  feature['path'] = tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[six.ensure_binary(mid_frame_path)]))

  for image_key, image_path in triplet_dict.items():
    if not tf.io.gfile.exists(image_path):
      logging.error('File not found: %s', image_path)
      return None

    # Note: we need both the raw bytes and the image size.
    # PIL.Image does not expose a method to grab the original bytes.
    # (Also it is not aware of non-local file systems.)
    # So we read with tf.io.gfile.GFile to get the bytes, and then wrap the
    # bytes in BytesIO to let PIL.Image open the image.
    try:
      # The context manager guarantees the file handle is closed.
      with tf.io.gfile.GFile(image_path, 'rb') as image_file:
        byte_array = image_file.read()
    except tf.errors.InvalidArgumentError:
      logging.exception('Cannot read image file: %s', image_path)
      return None
    try:
      pil_image = PIL.Image.open(io.BytesIO(byte_array))
    except PIL.UnidentifiedImageError:
      logging.exception('Cannot decode image file: %s', image_path)
      return None
    width, height = pil_image.size
    pil_image_format = pil_image.format

    if center_crop_factor > 1 or scale_factor > 1:
      image = np.array(pil_image)

      # Optionally center-crop images, keeping 1 / `center_crop_factor` of
      # each spatial dimension. Only the first two axes are sliced so that
      # both (H, W) and (H, W, C) arrays are supported.
      if center_crop_factor > 1:
        top, bottom = _center_crop_bounds(image.shape[0], center_crop_factor)
        left, right = _center_crop_bounds(image.shape[1], center_crop_factor)
        image = image[top:bottom, left:right]

      # Optionally downsample images by `scale_factor`.
      if scale_factor > 1:
        image = _resample_image(image, image.shape[1] // scale_factor,
                                image.shape[0] // scale_factor)

      # Update image properties and re-encode the modified pixels once.
      height, width = image.shape[:2]
      buffer = io.BytesIO()
      try:
        PIL.Image.fromarray(image).save(buffer, format='PNG')
      except OSError:
        logging.exception('Cannot encode image file: %s', image_path)
        return None
      byte_array = buffer.getvalue()
      # The stored bytes are now PNG regardless of the source file's format.
      pil_image_format = 'PNG'

    # Create tf Features.
    image_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[byte_array]))
    height_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[height]))
    width_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[width]))
    encoding = tf.train.Feature(
        bytes_list=tf.train.BytesList(
            value=[six.ensure_binary(pil_image_format.lower())]))

    # Update feature map.
    feature[f'{image_key}/encoded'] = image_feature
    feature[f'{image_key}/format'] = encoding
    feature[f'{image_key}/height'] = height_feature
    feature[f'{image_key}/width'] = width_feature

  # Create tf Example.
  features = tf.train.Features(feature=feature)
  example = tf.train.Example(features=features)
  return example


class ExampleGenerator(beam.DoFn):
  """Beam DoFn that emits one serialized tf.train.Example per image triplet."""

  def __init__(self,
               images_map: Mapping[str, Any],
               scale_factor: int = 1,
               center_crop_factor: int = 1):
    """Stores the settings applied to every processed triplet.

    Args:
      images_map: Map from image key to image filepath. It is kept on the
        instance; `process` does not read it.
      scale_factor: A scale factor to downsample frames.
      center_crop_factor: A factor to centercrop and downsize frames.
    """
    super().__init__()
    self._center_crop_factor = center_crop_factor
    self._scale_factor = scale_factor
    self._images_map = images_map

  def process(self, triplet_dict: Mapping[str, str]) -> List[bytes]:
    """Builds the serialized tf.train.Example for one triplet of images.

    Args:
      triplet_dict: A dict of image key to filepath of the triplet images.

    Returns:
      A single-element list holding the serialized tf.train.Example proto, or
      an empty list when no Example could be generated. No shuffling is
      applied.
    """
    example = generate_image_triplet_example(
        triplet_dict,
        scale_factor=self._scale_factor,
        center_crop_factor=self._center_crop_factor)
    if not example:
      return []
    return [example.SerializeToString()]