from dataclasses import dataclass

from PIL import Image
import pandas as pd

from geoguessr_bot.guessr import AbstractGuessr
from geoguessr_bot.interfaces import Coordinate
from geoguessr_bot.retriever import AbstractImageEmbedder
from geoguessr_bot.retriever import Retriever


@dataclass
class NearestNeighborEmbedderGuessr(AbstractGuessr):
    """Guess a coordinate by embedding an image and copying its nearest neighbor's location.

    The query image is embedded with ``embedder``, ``retriever`` looks up the
    closest indexed image, and that image's coordinate (read from the CSV at
    ``metadata_path``) is returned as the guess.
    """
    embedder: AbstractImageEmbedder
    retriever: Retriever
    # CSV file that must contain "path", "latitude" and "longitude" columns.
    metadata_path: str

    def __post_init__(self):
        """Build the file-name -> coordinate lookup table from the metadata CSV.

        Each row's ``path`` is reduced to its last "/"-separated component, so
        images are matched by file name only. If several rows share the same
        file name, the last one in the CSV wins.
        """
        metadata = pd.read_csv(self.metadata_path)
        self.image_to_coordinate = {
            image.split("/")[-1]: Coordinate(latitude=latitude, longitude=longitude)
            for image, latitude, longitude in zip(
                metadata["path"], metadata["latitude"], metadata["longitude"]
            )
        }

    def guess(self, image: object) -> Coordinate:
        """Guess the coordinate of ``image`` from its nearest indexed neighbor.

        Args:
            image: Either a ``PIL.Image.Image`` (used as-is) or any array-like
                object accepted by ``PIL.Image.fromarray``.

        Returns:
            The coordinate recorded in the metadata for the retrieved neighbor.

        Raises:
            KeyError: If the retrieved neighbor has no entry in the metadata.
        """
        # Only convert raw arrays: round-tripping an existing PIL image through
        # fromarray is wasteful and can drop information (e.g. a palette).
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # [None, :] prepends a batch axis; presumably embed() returns a 1-D
        # vector and retrieve() expects a (batch, dim) array — TODO confirm.
        image_embedding = self.embedder.embed(image)[None, :]

        # NOTE(review): [0][0][0] takes the first entry at each nesting level;
        # assumed to be the top-ranked neighbor's image name for the single
        # query — verify against Retriever.retrieve.
        nearest_neighbors = self.retriever.retrieve(image_embedding)
        nearest_neighbor = nearest_neighbors[0][0][0]

        try:
            return self.image_to_coordinate[nearest_neighbor]
        except KeyError:
            raise KeyError(
                f"Retrieved image {nearest_neighbor!r} has no entry in metadata "
                f"file {self.metadata_path!r}"
            ) from None