import os
import json
import tempfile
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import CommitScheduler
from huggingface_hub.hf_api import HfApi

###################################
# Parquet scheduler               #
# Uploads data in parquet format  #
###################################


class ParquetScheduler(CommitScheduler):
    """
    Commit scheduler that buffers rows in memory and uploads them as parquet files.

    Usage: configure the scheduler with a repo id. Once started, you can add data to be uploaded to the Hub. 1 `.append`
    call will result in 1 row in your final dataset.

    ```py
    # Start scheduler
    >>> scheduler = ParquetScheduler(repo_id="my-parquet-dataset")

    # Append some data to be uploaded
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    ```

    The scheduler will automatically infer the schema from the data it pushes.
    Optionally, you can manually set the schema yourself:

    ```py
    >>> scheduler = ParquetScheduler(
    ...     repo_id="my-parquet-dataset",
    ...     schema={
    ...         "prompt": {"_type": "Value", "dtype": "string"},
    ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
    ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
    ...         "image": {"_type": "Image"},
    ...     },
    ... )
    ```

    See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
    possible values.
    """

    def __init__(
        self,
        *,
        repo_id: str,
        schema: Optional[Dict[str, Dict[str, str]]] = None,
        every: Union[int, float] = 5,
        path_in_repo: Optional[str] = "data",
        repo_type: Optional[str] = "dataset",
        revision: Optional[str] = None,
        private: bool = False,
        token: Optional[str] = None,
        allow_patterns: Union[List[str], str, None] = None,
        ignore_patterns: Union[List[str], str, None] = None,
        hf_api: Optional[HfApi] = None,
    ) -> None:
        """
        Create the scheduler and an empty row buffer.

        Args:
            repo_id: Target repository, forwarded to `CommitScheduler`.
            schema: Optional mapping of column name to a `datasets` feature description
                (e.g. `{"_type": "Value", "dtype": "string"}`). Columns missing from it are
                inferred at push time. A non-empty mapping is extended in place with the
                inferred columns.
            every, path_in_repo, repo_type, revision, private, token, allow_patterns,
            ignore_patterns, hf_api: Forwarded unchanged to `CommitScheduler`.
        """
        super().__init__(
            repo_id=repo_id,
            folder_path="dummy",  # not used by the scheduler
            every=every,
            path_in_repo=path_in_repo,
            repo_type=repo_type,
            revision=revision,
            private=private,
            token=token,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            hf_api=hf_api,
        )

        # Rows waiting for the next push; always accessed under `self.lock`.
        self._rows: List[Dict[str, Any]] = []
        self._schema = schema

    def append(self, row: Dict[str, Any]) -> None:
        """Add a new item to be uploaded."""
        with self.lock:
            self._rows.append(row)

    def push_to_hub(self) -> None:
        """
        Upload all buffered rows as a single parquet file.

        Does nothing when no row is pending. Values of `Image`/`Audio` columns that point to
        an existing file are replaced by a `{"path", "bytes"}` dict, and those source files
        are deleted once the upload has succeeded. The row dicts passed to `append` are
        modified in place.
        """
        # Check for new rows to push. The buffer is swapped under the lock so that `append`
        # calls made during the upload land in the next batch.
        with self.lock:
            rows = self._rows
            self._rows = []
        if not rows:
            return
        print(f"Got {len(rows)} item(s) to commit.")

        # Load images + create 'features' config for datasets library
        schema: Dict[str, Dict] = self._schema or {}
        path_to_cleanup: List[Path] = []
        for row in rows:
            for key, value in row.items():
                # Infer schema (for `datasets` library)
                if key not in schema:
                    schema[key] = _infer_schema(key, value)

                # Load binary files if necessary
                if schema[key]["_type"] not in ("Image", "Audio"):
                    continue
                # Only path-like values can reference a file. Anything else (None, raw
                # bytes, an already-loaded dict, ...) is left untouched instead of making
                # `Path(value)` raise and abort the whole batch.
                if not isinstance(value, (str, os.PathLike)):
                    continue
                # It's an image or audio: we load the bytes and remember to cleanup the file
                file_path = Path(value)
                if file_path.is_file():
                    row[key] = {
                        "path": file_path.name,
                        "bytes": file_path.read_bytes(),
                    }
                    path_to_cleanup.append(file_path)

        # Complete rows if needed so that every row exposes every column
        for row in rows:
            for feature in schema:
                if feature not in row:
                    row[feature] = None

        # Export items to Arrow format
        table = pa.Table.from_pylist(rows)

        # Add metadata (used by datasets library)
        table = table.replace_schema_metadata(
            {"huggingface": json.dumps({"info": {"features": schema}})}
        )

        # Reserve a temporary file name. The handle is closed right away because the file
        # is written and read by name, which is not portable while a handle is still open.
        archive_file = tempfile.NamedTemporaryFile(delete=False)
        archive_file.close()
        try:
            # Write to parquet file
            pq.write_table(table, archive_file.name)

            # Upload
            # NOTE(review): the file is uploaded under a random name without applying the
            # `path_in_repo` constructor argument - confirm this is intended.
            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                path_in_repo=f"{uuid.uuid4()}.parquet",
                path_or_fileobj=archive_file.name,
            )
            print("Commit completed.")
        finally:
            # Always remove the temporary archive, even when writing or uploading failed.
            os.unlink(archive_file.name)

        # Cleanup source files only once their content has been uploaded
        for path in path_to_cleanup:
            path.unlink(missing_ok=True)


def _infer_schema(key: str, value: Any) -> Dict[str, str]:
    """
    Infer schema for the `datasets` library.

    See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value.
    """
    # In short any column_name in the dataset has any of these keywords
    # the column will be inferred into the correct column type accordingly
    if "image" in key:
        return {"_type": "Image"}
    if "audio" in key:
        return {"_type": "Audio"}
    if isinstance(value, int):
        return {"_type": "Value", "dtype": "int64"}
    if isinstance(value, float):
        return {"_type": "Value", "dtype": "float64"}
    if isinstance(value, bool):
        return {"_type": "Value", "dtype": "bool"}
    if isinstance(value, bytes):
        return {"_type": "Value", "dtype": "binary"}
    # Otherwise in last resort => convert it to a string
    return {"_type": "Value", "dtype": "string"}