import base64
from io import BytesIO
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from .custom_st_2 import OtherClass
import requests
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoImageProcessor
from PIL import Image

# NOTE(review): instantiated at import time and the result is discarded, so this
# line runs only for the side effects of OtherClass's constructor (defined in
# .custom_st_2, not visible here) — presumably to exercise the relative import
# above; confirm before removing.
OtherClass()

class Transformer(nn.Module):
    """Sentence-Transformers module wrapping a Hugging Face CLIP-style model
    that embeds either texts or images.

    The model is loaded with ``AutoModel`` into ``self.jina_clip``. Text
    batches are encoded with its ``get_text_features`` method and image
    batches with ``get_image_features``. Texts are tokenized with an
    ``AutoTokenizer`` and images are preprocessed with an
    ``AutoImageProcessor``; both are loaded from the same path.

    Args:
        model_name_or_path: Hugging Face model name
            (https://huggingface.co/models) or local path
        max_seq_length: Truncate any text inputs longer than max_seq_length
        model_args: Keyword arguments passed to the Hugging Face
            Transformers model
        tokenizer_args: Keyword arguments passed to the Hugging Face
            Transformers tokenizer (and to the image processor)
        config_args: Keyword arguments passed to the Hugging Face
            Transformers config
        cache_dir: Cache dir for Hugging Face Transformers to store/load
            models
        do_lower_case: If true, lowercases text inputs (independent of
            whether the model is cased or not)
        tokenizer_name_or_path: Name or path of the tokenizer and image
            processor. When None, model_name_or_path is used
    """

    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length: Optional[int] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        config_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        do_lower_case: bool = False,
        tokenizer_name_or_path: Optional[str] = None,
    ) -> None:
        super().__init__()
        # Attributes persisted to sentence_bert_config.json by save().
        self.config_keys = ["max_seq_length", "do_lower_case"]
        self.do_lower_case = do_lower_case
        model_args = {} if model_args is None else model_args
        config_args = {} if config_args is None else config_args
        # Copy so that adding "model_max_length" below never mutates the caller's dict.
        tokenizer_args = {} if tokenizer_args is None else dict(tokenizer_args)

        config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self.jina_clip = AutoModel.from_pretrained(
            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
        )

        if max_seq_length is not None and "model_max_length" not in tokenizer_args:
            tokenizer_args["model_max_length"] = max_seq_length
        # Tokenizer and image processor are loaded from the same location.
        processor_path = tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(
            processor_path,
            cache_dir=cache_dir,
            **tokenizer_args,
        )
        self.preprocessor = AutoImageProcessor.from_pretrained(
            processor_path,
            cache_dir=cache_dir,
            **tokenizer_args,
        )

        # No max_seq_length set. Try to infer from model
        if max_seq_length is None:
            if (
                hasattr(self.jina_clip, "config")
                and hasattr(self.jina_clip.config, "max_position_embeddings")
                and hasattr(self.tokenizer, "model_max_length")
            ):
                max_seq_length = min(self.jina_clip.config.max_position_embeddings, self.tokenizer.model_max_length)

        self.max_seq_length = max_seq_length

        if tokenizer_name_or_path is not None:
            self.jina_clip.config.tokenizer_class = self.tokenizer.__class__.__name__

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed a text batch or an image batch.

        If ``features`` contains ``input_ids`` the batch is treated as text and
        encoded with ``get_text_features``; otherwise ``features["pixel_values"]``
        is encoded with ``get_image_features``.

        Returns:
            ``{"sentence_embedding": embedding}``
        """
        if "input_ids" in features:
            embedding = self.jina_clip.get_text_features(input_ids=features["input_ids"])
        else:
            embedding = self.jina_clip.get_image_features(pixel_values=features["pixel_values"])
        return {"sentence_embedding": embedding}

    def get_word_embedding_dimension(self) -> int:
        """Return the text embedding size from the wrapped model's config.

        The value is read from ``self.jina_clip.config.text_config.embed_dim``;
        this module itself has no ``config`` attribute.
        """
        return self.jina_clip.config.text_config.embed_dim

    @staticmethod
    def decode_data_image(data_image_str: str) -> "Image.Image":
        """Decode a ``data:image/...;base64,<payload>`` URI into a PIL image.

        Only the part after the first comma is base64-decoded; the header is
        ignored.

        Raises:
            ValueError: if the string contains no comma separator.
        """
        _header, separator, data = data_image_str.partition(",")
        if not separator:
            raise ValueError("Malformed data URI: expected '<header>,<base64 data>'")
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    def tokenize(
        self, batch: List[Union[str, "Image.Image"]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Turn a batch of texts or images into model inputs.

        Each ``str`` sample is interpreted, in order, as: an ``http://`` or
        ``https://`` URL to download an image from; a ``data:image/...`` base64
        data URI; a local image file path; and otherwise plain text.
        ``PIL.Image.Image`` samples are used as-is. All images are converted to RGB.

        Args:
            batch: Texts, image URLs/paths/data URIs, or PIL images. A batch
                must be all texts or all images.
            padding: Padding strategy forwarded to the tokenizer.

        Returns:
            The tokenizer output for a text batch, the image processor output
            for an image batch (``forward`` expects it to contain
            ``pixel_values``), or ``{}`` for an empty batch.

        Raises:
            ValueError: if the batch mixes texts and images.
            TypeError: if a sample is neither a ``str`` nor a PIL image.
            requests.HTTPError: if downloading an image URL fails.
        """
        images = []
        texts = []
        for sample in batch:
            if isinstance(sample, str):
                if sample.startswith(("http://", "https://")):
                    response = requests.get(sample)
                    # Fail loudly on 4xx/5xx instead of trying to decode an error page.
                    response.raise_for_status()
                    images.append(Image.open(BytesIO(response.content)).convert("RGB"))
                elif sample.startswith("data:image/"):
                    images.append(self.decode_data_image(sample).convert("RGB"))
                else:
                    # A string that cannot be opened as an image file is treated as text.
                    # Image.open raises OSError (FileNotFoundError, UnidentifiedImageError, ...)
                    # for missing/non-image paths and ValueError for paths with NUL bytes.
                    try:
                        with Image.open(sample) as image:
                            images.append(image.convert("RGB"))
                    except (OSError, ValueError):
                        texts.append(sample)
            elif isinstance(sample, Image.Image):
                images.append(sample.convert("RGB"))
            else:
                # Silently dropping a sample would misalign embeddings with inputs.
                raise TypeError(
                    f"Unsupported input of type {type(sample).__name__}; expected str or PIL.Image.Image"
                )

        if images and texts:
            raise ValueError('Batch must contain either images or texts, not both')

        if texts:
            if self.do_lower_case:
                texts = [text.lower() for text in texts]
            return self.tokenizer(
                texts,
                padding=padding,
                truncation="longest_first",
                return_tensors="pt",
                max_length=self.max_seq_length,
            )
        elif images:
            return self.preprocessor(images)
        return {}

    def get_config_dict(self) -> Dict[str, Any]:
        """Return the attributes named in ``config_keys`` (persisted by ``save``)."""
        return {key: getattr(self, key) for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Save model, tokenizer, image processor and module config to ``output_path``."""
        self.jina_clip.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)
        self.preprocessor.save_pretrained(output_path)
        # load() reads this file; without it a saved module cannot be reloaded.
        with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as f_out:
            json.dump(self.get_config_dict(), f_out, indent=2)

    @staticmethod
    def load(input_path: str) -> "Transformer":
        """Rebuild a module from a directory written by ``save``.

        Raises:
            FileNotFoundError: if no sentence_*_config.json file is present.
        """
        # Old classes used other config names than 'sentence_bert_config.json'
        for config_name in [
            "sentence_bert_config.json",
            "sentence_roberta_config.json",
            "sentence_distilbert_config.json",
            "sentence_camembert_config.json",
            "sentence_albert_config.json",
            "sentence_xlm-roberta_config.json",
            "sentence_xlnet_config.json",
        ]:
            sbert_config_path = os.path.join(input_path, config_name)
            if os.path.exists(sbert_config_path):
                break
        else:
            raise FileNotFoundError(
                f"No sentence_bert_config.json (or legacy equivalent) found in {input_path!r}"
            )

        with open(sbert_config_path) as fIn:
            config = json.load(fIn)
        # Don't allow configs to set trust_remote_code
        for args_key in ("model_args", "tokenizer_args", "config_args"):
            if args_key in config and "trust_remote_code" in config[args_key]:
                config[args_key].pop("trust_remote_code")
        return Transformer(model_name_or_path=input_path, **config)