import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class Experience:
    """
    Simple container for one transition: (state, action, reward, next_state, done).
    """
    def __init__(self):
        self.state: list[int] | None = None
        self.action: int | None = None
        self.reward: float | None = None
        self.next_state: list[int] | None = None
        self.done: bool = False

class TetrisNet(nn.Module):
    """
    PyTorch network equivalent of the original Keras model:
    Input: 16-dimensional board
    Hidden layers: 64 -> 64 -> 32, ReLU activation
    Output: 4-dimensional, linear
    """
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(16, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, 32)
        self.output = nn.Linear(32, 4)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        x = self.output(x)
        return x
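
# Illustrative shape check (not part of the original file): a batch of one
# 16-value board maps to four Q-values, matching the layer sizes above.
#   >>> net = TetrisNet()
#   >>> net(torch.zeros(1, 16)).shape
#   torch.Size([1, 4])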

class TetrisAI:
    """
    PyTorch implementation of the TetrisAI class.
    - Loads a saved model if save_file_path is provided.
    - Otherwise, constructs a fresh model.
    - Has methods to save, predict, and train the model.
    """

    def __init__(self, save_file_path: str | None = None):
        # Create the model
        self.model = TetrisNet()

        # Define the optimizer and loss function
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.003)
        self.criterion = nn.MSELoss()

        # Load from file if path is provided
        if save_file_path is not None:
            checkpoint = torch.load(save_file_path, map_location=torch.device('cpu'))
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.model.eval()

    def save(self, path: str) -> None:
        """
        Saves the PyTorch model and optimizer state to a file.
        """
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)

    def predict(self, board: list[int]) -> list[float]:
        """
        Performs a forward pass to predict the Q-values for each possible move.
        Returns these Q-values as a list of floats.
        """
        # Convert board to a float tensor with shape [1, 16]
        x = torch.tensor([board], dtype=torch.float32)

        # Put model in evaluation mode and disable gradient tracking
        self.model.eval()
        with torch.no_grad():
            prediction = self.model(x)

        # Convert the single batch output (shape [1, 4]) to a Python list of floats
        return prediction[0].tolist()

    def train(self, board: list[int], qvalues: list[float]) -> None:
        """
        Trains the model on one step using the given board as input and qvalues as the desired output.
        """
        # Put model in training mode
        self.model.train()

        # Convert data to tensors
        x = torch.tensor([board], dtype=torch.float32)
        y = torch.tensor([qvalues], dtype=torch.float32)

        # Zero the parameter gradients
        self.optimizer.zero_grad()

        # Forward + Backward + Optimize
        predictions = self.model(x)
        loss = self.criterion(predictions, y)
        loss.backward()
        self.optimizer.step()
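

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original code: the flat 16-value
    # board encoding, action index, reward, and discount factor below are
    # placeholder assumptions chosen only to demonstrate the API.
    ai = TetrisAI()

    # One hypothetical transition stored in an Experience object.
    exp = Experience()
    exp.state = [0] * 16        # assumed flattened 4x4 board
    exp.action = 2              # assumed chosen move index (0-3)
    exp.reward = 1.0            # assumed reward from the environment
    exp.next_state = [0] * 16
    exp.done = False

    gamma = 0.99  # assumed discount factor

    # One-step Q-learning target for the action that was taken.
    qvalues = ai.predict(exp.state)
    next_q = ai.predict(exp.next_state)
    qvalues[exp.action] = exp.reward if exp.done else exp.reward + gamma * max(next_q)

    ai.train(exp.state, qvalues)
    print("Q-values after one update:", ai.predict(exp.state))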