using UnityEngine;
using Unity.Sentis;
using UnityEngine.Video;
using UnityEngine.UI;
using Lays = Unity.Sentis.Layers;
using FF = Unity.Sentis.Functional;

/*
 *  Blaze Palm Inference
 *  ====================
 *
 *  Basic inference script for BlazePalm
 *
 *  Put this script on the Main Camera
 *  Put palm_detection_lite.sentis in the Assets/StreamingAssets folder
 *  Create a RawImage in the scene
 *  Put a link to that image in previewUI
 *  Put a video in the Assets/StreamingAssets folder and put its name in videoName
 *  Or put a test image in inputImage
 *  Set inputType to the appropriate input
 */

public class RunBlazePalm : MonoBehaviour
{
    public ModelAsset asset;

    // Drag a link to a RawImage here:
    public RawImage previewUI = null;

    // Put your bounding box sprite image here
    public Texture2D boundingBoxTexture;
    public Sprite boundingBoxSprite;

    // Optional images for palm markers
    public Sprite[] markerTextures;

    public string videoName = "chatting.mp4";

    public Texture2D inputImage;

    public InputType inputType = InputType.Video;

    // Resolution of preview image/video
    Vector2Int resolution = new Vector2Int(640, 640);
    WebCamTexture webcam;
    VideoPlayer video;

    const BackendType backend = BackendType.GPUCompute;

    RenderTexture targetTexture;

    public enum InputType { Image, Video, Webcam };

    // Some adjustable parameters for the model
    [SerializeField, Range(0, 1)] float iouThreshold = 0.5f;
    [SerializeField, Range(0, 1)] float scoreThreshold = 0.5f;
    const int maxOutputBoxes = 64;

    IWorker worker;

    // Holds the model input size (width == height)
    int size;

    Model model;

    // Webcam device name (empty string selects the default device):
    const string deviceName = "";

    bool closing = false;

    TensorFloat anchors, centersToCorners;

    public struct BoundingBox
    {
        public float centerX;
        public float centerY;
        public float width;
        public float height;
    }

    void Start()
    {
        // (Note: if using a webcam on mobile, get permissions here first)
        targetTexture = new RenderTexture(resolution.x, resolution.y, 0);
        previewUI.texture = targetTexture;

        SetupInput();
        SetupModel();
        SetupEngine();

        // Fall back to building a sprite from the bounding box texture:
        if (boundingBoxSprite == null)
        {
            boundingBoxSprite = Sprite.Create(boundingBoxTexture,
                new Rect(0, 0, boundingBoxTexture.width, boundingBoxTexture.height),
                new Vector2(boundingBoxTexture.width / 2, boundingBoxTexture.height / 2));
        }
    }

    void SetupInput()
    {
        switch (inputType)
        {
            case InputType.Webcam:
            {
                webcam = new WebCamTexture(deviceName, resolution.x, resolution.y);
                webcam.requestedFPS = 30;
                webcam.Play();
                break;
            }
            case InputType.Video:
            {
                video = gameObject.AddComponent<VideoPlayer>();
                video.renderMode = VideoRenderMode.APIOnly;
                video.source = VideoSource.Url;
                video.url = Application.streamingAssetsPath + "/" + videoName;
                video.isLooping = true;
                video.Play();
                break;
            }
            default:
            {
                Graphics.Blit(inputImage, targetTexture);
                break;
            }
        }
    }
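    // A minimal sketch of the mobile permission request mentioned in Start()
    // (not part of the original sample; the method name is ours). Assumes it is
    // run as a coroutine, e.g. StartCoroutine(RequestWebcamPermission()),
    // before SetupInput() selects the webcam:
    System.Collections.IEnumerator RequestWebcamPermission()
    {
        yield return Application.RequestUserAuthorization(UserAuthorization.WebCam);
        if (!Application.HasUserAuthorization(UserAuthorization.WebCam))
            Debug.LogWarning("Webcam permission was denied.");
    }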
    void Update()
    {
        if (inputType == InputType.Webcam)
        {
            // Format video input
            if (!webcam.didUpdateThisFrame) return;
            var aspect1 = (float)webcam.width / webcam.height;
            var aspect2 = (float)resolution.x / resolution.y;
            var gap = aspect2 / aspect1;
            var vflip = webcam.videoVerticallyMirrored;
            var scale = new Vector2(gap, vflip ? -1 : 1);
            var offset = new Vector2((1 - gap) / 2, vflip ? 1 : 0);
            Graphics.Blit(webcam, targetTexture, scale, offset);
        }
        if (inputType == InputType.Video)
        {
            var aspect1 = (float)video.width / video.height;
            var aspect2 = (float)resolution.x / resolution.y;
            var gap = aspect2 / aspect1;
            var vflip = false;
            var scale = new Vector2(gap, vflip ? -1 : 1);
            var offset = new Vector2((1 - gap) / 2, vflip ? 1 : 0);
            Graphics.Blit(video.texture, targetTexture, scale, offset);
        }
        if (inputType == InputType.Image)
        {
            Graphics.Blit(inputImage, targetTexture);
        }

        if (Input.GetKeyDown(KeyCode.Escape))
        {
            closing = true;
            Application.Quit();
        }

        if (Input.GetKeyDown(KeyCode.P))
        {
            previewUI.enabled = !previewUI.enabled;
        }
    }

    void LateUpdate()
    {
        if (!closing)
        {
            RunInference(targetTexture);
        }
    }

    float[] GetGridBoxCoords()
    {
        var offsets = new float[2016 * 4];
        int n = 0;
        AddGrid(offsets, 24, 2, 8, ref n);
        AddGrid(offsets, 12, 6, 16, ref n);
        return offsets;
    }

    void AddGrid(float[] offsets, int rows, int repeats, int cellWidth, ref int n)
    {
        for (int j = 0; j < repeats * rows * rows; j++)
        {
            offsets[n++] = cellWidth * ((j / repeats) % rows - (rows - 1) * 0.5f);
            offsets[n++] = cellWidth * ((j / repeats / rows) - (rows - 1) * 0.5f);
            n += 2;
        }
    }
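    // Sanity check on the grid above (derived from the two AddGrid calls,
    // assuming the usual BlazePalm anchor layout for a 192x192 input):
    //   24x24 grid, 2 anchors per cell, cell width  8 -> 1152 anchors
    //   12x12 grid, 6 anchors per cell, cell width 16 ->  864 anchors
    //   1152 + 864 = 2016 anchors, matching the 2016 * 4 array and the
    //   (2016, 18) regressor output used in SetupModel() below.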
    void SetupModel()
    {
        float[] offsets = GetGridBoxCoords();
        model = ModelLoader.Load(asset);
        //model = ModelLoader.Load(Application.streamingAssetsPath + "/palm_detection_lite.sentis");

        size = model.inputs[0].shape.ToTensorShape()[2]; // Input tensor width (192)

        anchors = new TensorFloat(new TensorShape(offsets.Length / 4, 4), offsets);

        // Maps (centerX, centerY, width, height) to (x1, y1, x2, y2):
        centersToCorners = new TensorFloat(new TensorShape(4, 4),
            new float[]
            {
                 1,     0,     1,     0,
                 0,     1,     0,     1,
                -0.5f,  0,     0.5f,  0,
                 0,    -0.5f,  0,     0.5f
            });

        // We need to add extra layers to the model in order to aggregate the box predictions:
        var model2 = Functional.Compile(
            input =>
            {
                var outputs = model.Forward(input);
                var regressors = outputs[1][0];                                                       //shape=(2016,18)
                var scores = outputs[0][0].Transpose(0, 1) - scoreThreshold;                          //shape=(1,2016)
                var boxCoords = regressors[.., 0..4] + FunctionalTensor.FromTensor(anchors);          //shape=(2016,4)
                var boxCorners = FF.MatMul(boxCoords, FunctionalTensor.FromTensor(centersToCorners)); //shape=(2016,4)
                var indices = FF.NMS(boxCorners, scores, iouThreshold);                               //shape=(N)
                var indices2 = indices.Unsqueeze(-1).BroadcastTo(new int[] { 4 });                    //shape=(N,4)
                var output = FF.Gather(boxCoords, 0, indices2);                                       //shape=(N,4)
                var indices3 = indices.Unsqueeze(-1).BroadcastTo(new int[] { 18 });                   //shape=(N,18)
                var markersOutput = FF.Gather(regressors, 0, indices3);                               //shape=(N,18)
                return (output, markersOutput);
            },
            InputDef.FromModel(model)[0]
        );

        worker = WorkerFactory.CreateWorker(backend, model2);
    }

    public void SetupEngine()
    {
        // The worker is already created at the end of SetupModel()
    }

    void DrawPalms(TensorFloat index3, TensorFloat regressors, int NMAX, Vector2 scale)
    {
        for (int n = 0; n < NMAX; n++)
        {
            // Draw the bounding box of the palm
            var box = new BoundingBox
            {
                centerX = index3[n, 0] * scale.x,
                centerY = index3[n, 1] * scale.y,
                width = index3[n, 2] * scale.x,
                height = index3[n, 3] * scale.y
            };
            DrawBox(box, boundingBoxSprite);

            if (regressors == null) continue;

            // Draw markers at the starts of the fingers
            for (int j = 0; j < 7; j++)
            {
                var marker = new BoundingBox
                {
                    centerX = box.centerX + (regressors[n, 4 + j * 2] - regressors[n, 0]) * scale.x,
                    centerY = box.centerY + (regressors[n, 4 + j * 2 + 1] - regressors[n, 1]) * scale.y,
                    width = 4f * scale.x,
                    height = 4f * scale.y,
                };
                DrawBox(marker, j < markerTextures.Length ? markerTextures[j] : boundingBoxSprite);
            }
        }
    }

    void RunInference(Texture source)
    {
        var transform = new TextureTransform().SetDimensions(size, size, 3).SetTensorLayout(0, 3, 1, 2);
        using var image = TextureConverter.ToTensor(source, transform);

        worker.Execute(image);

        var output = worker.PeekOutput("output_0") as TensorFloat;
        var markersOutput = worker.PeekOutput("output_1") as TensorFloat;
        output.CompleteOperationsAndDownload();
        markersOutput.CompleteOperationsAndDownload();

        ClearAnnotations();
        Vector2 markerScale = previewUI.rectTransform.rect.size / size;
        DrawPalms(output, markersOutput, output.shape[0], markerScale);
    }

    public void DrawBox(BoundingBox box, Sprite sprite)
    {
        var panel = new GameObject("ObjectBox");
        panel.AddComponent<CanvasRenderer>();
        panel.AddComponent<Image>();
        panel.transform.SetParent(previewUI.transform, false);

        var img = panel.GetComponent<Image>();
        img.color = Color.white;
        img.sprite = sprite;
        img.type = Image.Type.Sliced;

        panel.transform.localPosition = new Vector3(box.centerX, -box.centerY);
        RectTransform rt = panel.GetComponent<RectTransform>();
        rt.sizeDelta = new Vector2(box.width, box.height);
    }

    public void ClearAnnotations()
    {
        foreach (Transform child in previewUI.transform)
        {
            Destroy(child.gameObject);
        }
    }

    void CleanUp()
    {
        anchors?.Dispose();
        centersToCorners?.Dispose();
        closing = true;
        if (webcam) Destroy(webcam);
        if (video) Destroy(video);
        RenderTexture.active = null;
        targetTexture.Release();
        worker?.Dispose();
        worker = null;
    }

    void OnDestroy()
    {
        CleanUp();
    }
}
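/*
 * Output layout notes (inferred from the code above; the channel meanings are
 * an assumption based on the published BlazePalm model, not stated in this file):
 *   outputs[0]: (1, 2016, 1)  one raw detection score per anchor
 *   outputs[1]: (1, 2016, 18) per anchor: 4 box values (center x, center y,
 *               width, height, relative to the anchor) followed by 7 palm
 *               keypoints as (x, y) pairs; hence DrawPalms reads channels
 *               4 + j*2 and 4 + j*2 + 1 for marker j.
 */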