UnityGiles's picture
update to inference engine
d6bd9e7
raw
history blame
2.63 kB
using Unity.InferenceEngine;
using UnityEngine;
/// <summary>
/// Classifies a single texture with an image-classification model (MobileNet) using
/// Unity Inference Engine and logs the top prediction to the console.
/// </summary>
/// <remarks>
/// At startup the loaded model is wrapped in a functional graph that
/// (1) normalises the RGB input with per-channel constants and
/// (2) reduces the model's first output to its maximum value and the index of that
/// value along axis 1. Only those two outputs are read back after inference.
/// </remarks>
public class RunMobileNet : MonoBehaviour
{
    // The model asset to run (assign in the Inspector).
    public ModelAsset modelAsset;

    // The image to classify here:
    public Texture2D inputImage;

    // Link class_desc.txt here. One label per line; line i is the label for class index i.
    public TextAsset labelsAsset;

    // Backend used by the worker. Exposed so platforms without compute-shader support can
    // select a different backend; the default matches the previously hard-coded value.
    public BackendType backend = BackendType.GPUCompute;

    // The input tensor, shape (1, 3, 224, 224): batch, RGB channels, height, width.
    // Created in Start (not in a field initialiser) so it is only allocated for components
    // that actually start, and is therefore released by OnDestroy.
    Tensor<float> input;

    // Used to normalise the input RGB values: x' = (x - shiftRGB) * mulRGB, i.e.
    // (x - mean) / std with the commonly used ImageNet per-channel mean/std values.
    // Shape (1, 3, 1, 1) so they broadcast over the channel axis of the input.
    Tensor<float> mulRGB;
    Tensor<float> shiftRGB;

    Worker worker;
    string[] labels;

    void Start()
    {
        // Report missing Inspector references clearly instead of failing later with a
        // NullReferenceException. UnityEngine.Object must be compared with == null
        // (not 'is null') so destroyed/missing objects are detected.
        if (modelAsset == null || inputImage == null || labelsAsset == null)
        {
            Debug.LogError($"{nameof(RunMobileNet)}: assign {nameof(modelAsset)}, {nameof(inputImage)} and {nameof(labelsAsset)} in the Inspector.", this);
            return;
        }

        // Allocate tensors here rather than in field initialisers: Unity may run field
        // initialisers when it constructs the component for serialisation, and OnDestroy
        // is only called for objects that were active, so those tensors could leak.
        input = new Tensor<float>(new TensorShape(1, 3, 224, 224));
        mulRGB = new Tensor<float>(new TensorShape(1, 3, 1, 1), new[] { 1 / 0.229f, 1 / 0.224f, 1 / 0.225f });
        shiftRGB = new Tensor<float>(new TensorShape(1, 3, 1, 1), new[] { 0.485f, 0.456f, 0.406f });

        // Parse neural net labels. Strip a trailing '\r' so a labels file saved with
        // Windows (CRLF) line endings doesn't put carriage returns into the logged label.
        labels = labelsAsset.text.Split('\n');
        for (int i = 0; i < labels.Length; i++)
        {
            labels[i] = labels[i].TrimEnd('\r');
        }

        // Load model from asset
        var model = ModelLoader.Load(modelAsset);

        // Wrap the model: normalise the input RGB values, run the original model, then
        // keep only the highest value of its first output and that value's index (axis 1).
        var graph = new FunctionalGraph();
        var image = graph.AddInput(model, 0);
        var normalizedInput = (image - Functional.Constant(shiftRGB)) * Functional.Constant(mulRGB);
        var probability = Functional.Forward(model, normalizedInput)[0];
        var value = Functional.ReduceMax(probability, 1);
        var index = Functional.ArgMax(probability, 1);
        graph.AddOutput(value, "value");
        graph.AddOutput(index, "index");
        var model2 = graph.Compile();

        // Set up the worker to run the wrapped model
        worker = new Worker(model2, backend);

        // Execute inference once at startup
        ExecuteML();
    }

    /// <summary>
    /// Runs inference on <see cref="inputImage"/> and logs the best-scoring class label
    /// together with its score as a rounded percentage.
    /// </summary>
    public void ExecuteML()
    {
        // ExecuteML is public, so it may be called before Start finished setting up
        // (or after setup bailed out on missing references).
        if (worker == null)
        {
            Debug.LogError($"{nameof(RunMobileNet)}: {nameof(ExecuteML)} called without a worker; setup in Start did not complete.", this);
            return;
        }
        if (inputImage == null)
        {
            Debug.LogError($"{nameof(RunMobileNet)}: {nameof(inputImage)} is not assigned.", this);
            return;
        }

        // Preprocess image for input (written into the preallocated input tensor)
        TextureConverter.ToTensor(inputImage, input);

        // Schedule neural net
        worker.Schedule(input);

        // Read back the two outputs to CPU-readable clones; both are disposed when the
        // method returns. Direct casts fail with a clear InvalidCastException if the
        // output types are ever different, instead of a NullReferenceException.
        using var value = ((Tensor<float>)worker.PeekOutput("value")).ReadbackAndClone();
        using var index = ((Tensor<int>)worker.PeekOutput("index")).ReadbackAndClone();

        // Highest output value and its class index. The value is a probability only if the
        // model's last layer is a softmax — NOTE(review): confirm for the model asset in use.
        float confidence = value[0];
        int id = index[0];

        // Round to the nearest whole percent (halves round up).
        int percent = Mathf.FloorToInt(confidence * 100f + 0.5f);

        // Fall back to the raw index if the labels file has fewer lines than the model has classes.
        string label = id >= 0 && id < labels.Length ? labels[id] : $"class {id}";

        // The result is output to the console window
        Debug.Log($"Prediction: {label} {percent}﹪");

        // Clean memory. NOTE(review): this is a comparatively expensive call; consider
        // removing it if ExecuteML is invoked frequently.
        Resources.UnloadUnusedAssets();
    }

    void OnDestroy()
    {
        // Release everything this component allocated; fields may still be null if Start
        // never ran or returned early.
        input?.Dispose();
        mulRGB?.Dispose();
        shiftRGB?.Dispose();
        worker?.Dispose();
    }
}