import torch.hub

from transformers import (
    CLIPVisionModel,
    CLIPVisionConfig,
    CLIPModel,
    CLIPProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPTextConfig,
    CLIPVisionModelWithProjection,
    ResNetModel,
    ResNetConfig
)
from torch import nn

from PIL import Image
import requests


class CLIP(nn.Module):
    """Thin wrapper exposing the token-level outputs of a ``CLIPVisionModel``."""

    def __init__(self, path):
        """Initializes the CLIP model.

        Args:
            path (str): Hugging Face checkpoint name or local directory. The
                empty string builds a randomly initialised model from the
                default ``CLIPVisionConfig`` instead of loading weights.
        """
        super().__init__()
        self.clip = (
            CLIPVisionModel(CLIPVisionConfig())
            if path == ""
            else CLIPVisionModel.from_pretrained(path)
        )

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch; the
                tensor is fed to the vision model as ``pixel_values``.

        Returns:
            The ``"last_hidden_state"`` entry of the vision model output.
        """
        outputs = self.clip(pixel_values=x["img"])
        return outputs["last_hidden_state"]


class CLIPJZ(nn.Module):
    """``CLIPVisionModel`` wrapper; currently identical in behavior to ``CLIP``."""

    def __init__(self, path):
        """Initializes the CLIP model.

        Args:
            path (str): checkpoint name or directory for ``from_pretrained``;
                pass ``""`` to get a randomly initialised default-config model.
        """
        super().__init__()
        if path != "":
            self.clip = CLIPVisionModel.from_pretrained(path)
            return
        default_config = CLIPVisionConfig()
        self.clip = CLIPVisionModel(default_config)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch

        Returns:
            The vision model's ``"last_hidden_state"`` output.
        """
        images = x["img"]
        hidden = self.clip(pixel_values=images)["last_hidden_state"]
        return hidden


class StreetCLIP(nn.Module):
    """Image encoder built from a full ``CLIPModel`` and its ``CLIPProcessor``."""

    def __init__(self, path):
        """Initializes the CLIP model.

        Args:
            path (str): Hugging Face checkpoint name or local directory; both
                the model and the processor are loaded from it.
        """
        super().__init__()
        model = CLIPModel.from_pretrained(path)
        self.clip = model
        processor = CLIPProcessor.from_pretrained(path)
        self.transform = processor

    def forward(self, x):
        """Predicts CLIP image features from an image batch.

        Args:
            x (dict): batch holding ``"img"`` (passed to the processor as
                ``images``) and ``"gps"`` (a tensor whose device the processed
                inputs are moved to).

        Returns:
            ``get_image_features`` output with a singleton axis inserted at
            dim 1.
        """
        processed = self.transform(images=x["img"], return_tensors="pt")
        processed = processed.to(x["gps"].device)
        image_features = self.clip.get_image_features(**processed)
        return image_features.unsqueeze(1)


class CLIPText(nn.Module):
    """CLIP vision encoder with its projection head (``CLIPVisionModelWithProjection``)."""

    def __init__(self, path):
        """Initializes the CLIP vision model with projection.

        Args:
            path (str): Hugging Face checkpoint name or local directory. The
                empty string builds a randomly initialised model from the
                default ``CLIPVisionConfig`` instead of loading weights.
        """
        super().__init__()
        if path == "":
            # Must be the *WithProjection* variant in both branches: forward()
            # reads ``image_embeds``, which plain ``CLIPVisionModel`` outputs
            # do not provide.
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModelWithProjection(config_vision)
        else:
            self.clip = CLIPVisionModelWithProjection.from_pretrained(path)

    def forward(self, x):
        """Predicts projected CLIP embeddings and token features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch

        Returns:
            tuple: ``(image_embeds, last_hidden_state)`` taken from the vision
            model output.
        """
        features = self.clip(pixel_values=x["img"])
        return features.image_embeds, features.last_hidden_state


class TextEncoder(nn.Module):
    """Frozen CLIP text encoder mapping a batch of strings to ``text_embeds``."""

    def __init__(self, path):
        """Initializes the CLIP text model and its tokenizer.

        The text model's parameters are frozen and the model is put in eval
        mode.

        Args:
            path (str): Hugging Face checkpoint name or local directory used
                for both the text model and the tokenizer.

        Raises:
            ValueError: if ``path`` is empty. A tokenizer cannot be built
                without a pretrained checkpoint (``AutoTokenizer()`` always
                raises), so the former random-init branch could never work.
        """
        super().__init__()
        if path == "":
            raise ValueError(
                "TextEncoder requires a non-empty pretrained `path`: the "
                "tokenizer cannot be instantiated without a checkpoint."
            )
        self.clip = CLIPTextModelWithProjection.from_pretrained(path)
        self.transform = AutoTokenizer.from_pretrained(path)
        # Used as a fixed feature extractor: its weights receive no gradients.
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

    def forward(self, x):
        """Predicts CLIP features from text.

        Args:
            x (dict that contains "text": list): Input batch. ``x["gps"]`` is
                read only to choose the device the tokenized inputs go to.

        Returns:
            The ``text_embeds`` output of the projected text model.
        """
        tokens = self.transform(x["text"], padding=True, return_tensors="pt")
        tokens = tokens.to(x["gps"].device)
        return self.clip(**tokens).text_embeds


class DINOv2(nn.Module):
    """DINOv2 backbone loaded through ``torch.hub`` returning pre-norm tokens."""

    def __init__(self, tag) -> None:
        """Initializes the DINO model.

        Args:
            tag: entry-point name passed to ``torch.hub.load`` for the
                ``facebookresearch/dinov2`` repository.
        """
        super().__init__()
        self.dino = torch.hub.load("facebookresearch/dinov2", tag)
        # Hard-coded DINOv2 patch stride; inputs are cropped to a multiple of it.
        self.stride = 14

    def forward(self, x):
        """Predicts DINO features from an image.

        Args:
            x (dict): batch holding a 4-D image tensor under ``"img"``; its
                last two dimensions are treated as height and width.

        Returns:
            The ``"x_prenorm"`` entry of ``self.dino.forward_features``.
        """
        img = x["img"]
        _, _, height, width = img.shape
        # Drop trailing rows/columns so both spatial sizes divide the stride.
        keep_h = height - height % self.stride
        keep_w = width - width % self.stride
        cropped = img[..., :keep_h, :keep_w]
        return self.dino.forward_features(cropped)["x_prenorm"]
    
class ResNet(nn.Module):
    """Wrapper returning the pooled features of a Hugging Face ``ResNetModel``."""

    def __init__(self, path):
        """Initializes the ResNet model.

        Args:
            path (str): Hugging Face checkpoint name or local directory. The
                empty string builds a randomly initialised model from the
                default ``ResNetConfig`` instead of loading weights.
        """
        super().__init__()
        if path == "":
            config_vision = ResNetConfig()
            self.resnet = ResNetModel(config_vision)
        else:
            self.resnet = ResNetModel.from_pretrained(path)

    def forward(self, x):
        """Predicts ResNet50 features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch

        Returns:
            torch.Tensor: ``pooler_output`` flattened to ``(batch, features)``.
        """
        features = self.resnet(x["img"])["pooler_output"]
        # Flatten from dim 1 rather than squeeze(): squeeze() also removed the
        # batch dimension when the batch size was 1.
        return features.flatten(start_dim=1)