using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using System.IO;
using Lays = Unity.Sentis.Layers;

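/*
 *      Neural Cellular Automata Inference Code
 *      =======================================
 *      
 * Put this script on the Main Camera.
 * Create an image or quad in the scene.
 * Assign an unlit transparent material to the image/quad.
 * Drag the same material into the outputMaterial field.
 * The models are loaded from StreamingAssets (lizard.sentis, turtle.sentis, poop.sentis).
 * Press Space to seed a new dot at a random position; press Escape to quit.
 */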
public class RunAutomata : MonoBehaviour
{
    //Change this to load a different model:
    public AutomataNames automataName = AutomataNames.Poop;

    //Reduce this to make it run slower
    [Range(0f, 1f)]
    public float stepSize = 1.0f;

    const BackendType backend = BackendType.GPUCompute;

    //Drag your unlit transparent material here for drawing the output
    public Material outputMaterial;

    //Optional material that receives the alphaBlocks x alphaBlocks texture of block-averaged alpha
    public Material avgAlphaMaterial;

    public enum AutomataNames { Lizard, Turtle, Poop };

    //Model parameters
    const int trainedResolution = 40;   // side length of the visible (cropped) image
    const int trainedPool = 16;         // padding added on each side of the visible image
    const int alphaBlocks = 4;          // the visible alpha channel is averaged into alphaBlocks x alphaBlocks blocks
    int m_paddedImageSize;              // trainedResolution + 2 * trainedPool
    int m_trainedHiddenStates;          // channels per cell, read from the loaded model's input shape

    //Workers to run the networks
    private IWorker m_WorkerStateUpdate;   // runs the loaded automata model (fed a channels-last tensor)
    private IWorker m_WorkerClip;          // runs the graph built in CreateProcessingModel

    // Current automata state, laid out as (1, channels, height, width).
    private TensorFloat m_currentStateTensor;
    private RenderTexture m_currentStateTexture;
    private RenderTexture m_currentBlockAlphaStateTexture;

    Ops m_ops;
    ITensorAllocator m_allocator;

    void Start()
    {
        m_allocator = new TensorCachingAllocator();
        m_ops = WorkerFactory.CreateOps(backend, m_allocator);

        Application.targetFrameRate = 60;

        LoadAutomataModel();

        CreateProcessingModel();

        SetupState();

        SetupTextures();

        DrawDotAt(m_paddedImageSize / 2, m_paddedImageSize / 2);
    }

    /// <summary>
    /// Loads the model selected by <see cref="automataName"/> from StreamingAssets,
    /// reads its per-cell channel count, and creates the worker that runs it.
    /// </summary>
    /// <exception cref="System.ArgumentOutOfRangeException">
    /// <see cref="automataName"/> is not one of the known models.
    /// </exception>
    void LoadAutomataModel()
    {
        string fileName = automataName switch
        {
            AutomataNames.Lizard => "lizard.sentis",
            AutomataNames.Turtle => "turtle.sentis",
            AutomataNames.Poop => "poop.sentis",
            _ => throw new System.ArgumentOutOfRangeException(nameof(automataName), automataName, "Unknown automata model."),
        };

        Model model = ModelLoader.Load(Application.streamingAssetsPath + "/" + fileName);

        // The model is fed a channels-last tensor (see DoInference), so dimension 3 is the channel count.
        m_trainedHiddenStates = model.inputs[0].shape[3].value;

        m_paddedImageSize = trainedResolution + trainedPool * 2;

        m_WorkerStateUpdate = WorkerFactory.CreateWorker(backend, model, false);
    }

    /// <summary>
    /// Builds the post-processing graph that applies the model's update to the current state.
    /// Inputs (channels-first):
    ///   input0        - current state,
    ///   input1        - update produced by the automata model,
    ///   inputStepSize - scalar multiplier for the update.
    /// Outputs:
    ///   outputState   - next full padded state,
    ///   outputImage   - channels 0..3 cropped to the visible trainedResolution square,
    ///   avgPoolBlocks - channel 3 of outputImage averaged over alphaBlocks x alphaBlocks blocks.
    /// </summary>
    void CreateProcessingModel()
    {
        var model = new Model();

        var input0 = new Model.Input
        {
            name = "input0",
            shape = new SymbolicTensorShape(1, m_trainedHiddenStates, m_paddedImageSize, m_paddedImageSize),
            dataType = DataType.Float
        };

        var input1 = new Model.Input
        {
            name = "input1",
            shape = new SymbolicTensorShape(1, m_trainedHiddenStates, m_paddedImageSize, m_paddedImageSize),
            dataType = DataType.Float
        };

        var inputStepSize = new Model.Input
        {
            name = "inputStepSize",
            shape = new SymbolicTensorShape(1, 1, 1, 1),
            dataType = DataType.Float
        };

        model.inputs.Add(input0);
        model.inputs.Add(input1);
        model.inputs.Add(inputStepSize);

        // Threshold on channel 3 above which a cell (or its 3x3 neighbourhood) counts as alive.
        model.AddConstant(new Lays.Constant("aliveRate", new TensorFloat(new TensorShape(1, 1, 1, 1), new[] { 0.1f })));

        // Slice bounds selecting channel 3 over the whole padded grid.
        model.AddConstant(new Lays.Constant("sliceStarts", new int[] { 0, 3, 0, 0 }));
        model.AddConstant(new Lays.Constant("sliceEnds", new[] { 1, 4, m_paddedImageSize, m_paddedImageSize }));

        // Alive mask before the update: 3x3 max of channel 3 ("same" padding) compared to aliveRate.
        model.AddLayer(new Lays.Slice("sliceI0", "input0", "sliceStarts", "sliceEnds"));
        model.AddLayer(new Lays.MaxPool("maxpool0", "sliceI0", new[] { 3, 3 }, new[] { 1, 1 }, new[] { 1, 1, 1, 1 }));
        model.AddLayer(new Lays.Greater("pre_life_mask", "maxpool0", "aliveRate")); //INT

        model.AddLayer(new Lays.Mul("input1_stepsize", "input1", "inputStepSize"));

        // Stochastic update: one uniform sample per cell (broadcast over channels); cells whose sample
        // is <= fireRate receive the update this step.
        model.AddLayer(new Lays.RandomUniform("random", new int[] { 1, 1, m_paddedImageSize, m_paddedImageSize }, 0.0f, 1.0f, 0));
        model.AddConstant(new Lays.Constant("fireRate", new TensorFloat(new TensorShape(1, 1, 1, 1), new[] { 0.5f })));
        model.AddLayer(new Lays.LessOrEqual("lessEqualFireRateINT", "random", "fireRate"));

        model.AddLayer(new Lays.Cast("lessEqualFireRate", "lessEqualFireRateINT", DataType.Float));

        model.AddLayer(new Lays.Mul("mul", "input1_stepsize", "lessEqualFireRate"));

        model.AddLayer(new Lays.Add("add", "input0", "mul"));

        // Alive mask after the update, computed the same way as pre_life_mask.
        model.AddLayer(new Lays.Slice("sliceI1", "add", "sliceStarts", "sliceEnds"));
        model.AddLayer(new Lays.MaxPool("maxpool1", "sliceI1", new[] { 3, 3 }, new[] { 1, 1 }, new[] { 1, 1, 1, 1 }));
        model.AddLayer(new Lays.Greater("post_life_mask", "maxpool1", "aliveRate"));

        // Zero every cell that is not alive both before and after the update.
        model.AddLayer(new Lays.And("andINT", "pre_life_mask", "post_life_mask"));
        model.AddLayer(new Lays.Cast("and", "andINT", DataType.Float));

        model.AddLayer(new Lays.Mul("outputState", "add", "and"));

        // Crop channels 0..3 to the visible region inside the padding.
        model.AddConstant(new Lays.Constant("sliceStarts2", new[] { 0, 0, trainedPool, trainedPool }));
        model.AddConstant(new Lays.Constant("sliceEnds2", new[] { 1, 4, m_paddedImageSize - trainedPool, m_paddedImageSize - trainedPool }));

        model.AddLayer(new Lays.Slice("outputImage", "outputState", "sliceStarts2", "sliceEnds2"));

        // Channel 3 of the cropped image (the slice ends exceed the cropped size and are clamped).
        model.AddLayer(new Lays.Slice("outputIC", "outputImage", "sliceStarts", "sliceEnds"));

        // Non-overlapping block averages. Pads must be zero: with padding the windows would be shifted
        // by one pixel and the last row/column would never be averaged.
        int blockSize = trainedResolution / alphaBlocks;
        model.AddLayer(new Lays.AveragePool("avgPoolBlocks", "outputIC", new[] { blockSize, blockSize }, new[] { blockSize, blockSize }, new[] { 0, 0, 0, 0 }));

        model.outputs.Add("outputState");
        model.outputs.Add("outputImage");
        model.outputs.Add("avgPoolBlocks");

        m_WorkerClip = WorkerFactory.CreateWorker(backend, model);
    }

    /// <summary>Creates an all-zero state tensor of shape (1, channels, padded, padded).</summary>
    void SetupState()
    {
        float[] data = new float[1 * m_paddedImageSize * m_paddedImageSize * m_trainedHiddenStates];
        m_currentStateTensor = new TensorFloat(new TensorShape(1, m_trainedHiddenStates, m_paddedImageSize, m_paddedImageSize), data);
    }

    /// <summary>
    /// Creates the render textures the inference results are drawn into and attaches them to the materials:
    /// the visible image goes to <see cref="outputMaterial"/>, and (only when assigned) the block-averaged
    /// alpha goes to <see cref="avgAlphaMaterial"/>.
    /// </summary>
    void SetupTextures()
    {
        m_currentStateTexture = new RenderTexture(trainedResolution, trainedResolution, 0)
        {
            enableRandomWrite = true
        };
        if (outputMaterial)
        {
            outputMaterial.mainTexture = m_currentStateTexture;
        }
        else
        {
            Debug.LogWarning("RunAutomata: outputMaterial is not assigned; the output will not be displayed.");
        }

        if (avgAlphaMaterial)
        {
            m_currentBlockAlphaStateTexture = new RenderTexture(alphaBlocks, alphaBlocks, 0)
            {
                enableRandomWrite = true
            };
            // The block texture belongs on the optional avg-alpha material, not on the main output.
            avgAlphaMaterial.mainTexture = m_currentBlockAlphaStateTexture;
        }
    }

    /// <summary>
    /// Seeds the cell at (x, y) of the padded grid by setting channels 3 and above to 1.
    /// </summary>
    void DrawDotAt(int x, int y)
    {
        m_currentStateTensor.MakeReadable();

        float[] data = m_currentStateTensor.ToReadOnlyArray();
        int planeSize = m_paddedImageSize * m_paddedImageSize;
        // Channels-first layout: index = channel * H * W + y * W + x.
        for (int k = 3; k < m_trainedHiddenStates; k++)
        {
            data[planeSize * k + m_paddedImageSize * y + x] = 1f;
        }
        Replace(ref m_currentStateTensor, new TensorFloat(m_currentStateTensor.shape, data));
    }

    void Update()
    {
        DoInference();
        if (Input.GetKeyDown(KeyCode.Escape))
        {
            Application.Quit();
        }
        if (Input.GetKeyDown(KeyCode.Space))
        {
            DrawDotAt(UnityEngine.Random.Range(0, m_paddedImageSize), UnityEngine.Random.Range(0, m_paddedImageSize));
        }
    }

    /// <summary>Disposes <paramref name="A"/> (if any) and replaces it with <paramref name="B"/>.</summary>
    void Replace(ref TensorFloat A, TensorFloat B)
    {
        A?.Dispose();
        A = B;
    }

    /// <summary>
    /// Advances the automata by one step: runs the loaded model on the current state, applies the update
    /// through the processing graph, draws the results into the render textures, and keeps the new state.
    /// </summary>
    void DoInference()
    {
        using var stepSizeTensor = new TensorFloat(new TensorShape(1, 1, 1, 1), new float[] { stepSize });

        // Channels-first -> channels-last for the loaded model.
        using var currentStateTensorT = m_ops.Transpose(m_currentStateTensor, new int[] { 0, 2, 3, 1 });

        m_WorkerStateUpdate.Execute(currentStateTensorT);
        TensorFloat outputStateT = m_WorkerStateUpdate.PeekOutput() as TensorFloat;

        // Channels-last -> channels-first for the processing graph.
        using var outputState = m_ops.Transpose(outputStateT, new int[] { 0, 3, 1, 2 });

        var inputs = new Dictionary<string, Tensor>()
        {
            { "input0", m_currentStateTensor }, //float
            { "input1", outputState }, //float
            { "inputStepSize", stepSizeTensor }  //float
        };
        m_WorkerClip.Execute(inputs);

        TensorFloat clippedState = m_WorkerClip.PeekOutput("outputState") as TensorFloat;
        TensorFloat outputImage = m_WorkerClip.PeekOutput("outputImage") as TensorFloat;
        TensorFloat blockAvgAlphaState = m_WorkerClip.PeekOutput("avgPoolBlocks") as TensorFloat;

        if (m_currentStateTexture)
        {
            TextureConverter.RenderToTexture(outputImage, m_currentStateTexture);
        }

        if (m_currentBlockAlphaStateTexture)
        {
            TextureConverter.RenderToTexture(blockAvgAlphaState, m_currentBlockAlphaStateTexture);
        }

        // Keep the new state; take ownership so the worker does not reuse/dispose it on the next Execute.
        Replace(ref m_currentStateTensor, clippedState);
        m_currentStateTensor.TakeOwnership();
    }

    /// <summary>
    /// Releases the state tensor, workers, ops, allocator and script-created render textures.
    /// Null-safe so it also works when Start did not complete.
    /// </summary>
    void OnDestroy()
    {
        m_currentStateTensor?.Dispose();
        m_currentStateTensor = null;

        m_WorkerStateUpdate?.Dispose();
        m_WorkerClip?.Dispose();

        m_ops?.Dispose();
        m_allocator?.Dispose();

        ReleaseRenderTexture(ref m_currentStateTexture);
        ReleaseRenderTexture(ref m_currentBlockAlphaStateTexture);
    }

    /// <summary>Releases and destroys a render texture created by this script, then clears the reference.</summary>
    static void ReleaseRenderTexture(ref RenderTexture texture)
    {
        if (texture)
        {
            texture.Release();
            Destroy(texture);
        }
        texture = null;
    }
}