using System;
using System.Collections.Generic;
using Unity.InferenceEngine;
using UnityEngine;
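
// RunMiniLM computes sentence similarity with a MiniLM embedding model:
// both sentences are tokenized, run through the model with a mean-pooling
// head to produce normalized embeddings, and compared with a dot product
// (their cosine similarity).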
public class RunMiniLM : MonoBehaviour
{
    public ModelAsset modelAsset;
    public TextAsset vocabAsset;

    //Run inference on the GPU; BackendType.CPU also works.
    const BackendType backend = BackendType.GPUCompute;

    string string1 = "That is a happy person"; // similarity = 1
    //Choose a string to compare with string1:
    string string2 = "That is a happy dog"; // similarity = 0.695
    //string string2 = "That is a very happy person"; // similarity = 0.943
    //string string2 = "Today is a sunny day"; // similarity = 0.257

    //Special tokens ([CLS] = 101 and [SEP] = 102 in the BERT vocabulary)
    const int START_TOKEN = 101;
    const int END_TOKEN = 102;

    //Store the vocabulary
    string[] tokens;

    const int FEATURES = 384; //size of feature space

    Worker engine, dotScore;
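
    // Loads the vocabulary, builds the two workers, embeds both sentences,
    // and logs their similarity score.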
    void Start()
    {
        // Each line of vocab.txt is one token; the line index is its token id.
        // Split on both CRLF and LF so the vocabulary loads regardless of the
        // asset's line endings.
        tokens = vocabAsset.text.Split(new[] { "\r\n", "\n" }, StringSplitOptions.None);

        engine = CreateMLModel();
        dotScore = CreateDotScoreModel();

        var tokens1 = GetTokens(string1);
        var tokens2 = GetTokens(string2);

        using Tensor<float> embedding1 = GetEmbedding(tokens1);
        using Tensor<float> embedding2 = GetEmbedding(tokens2);

        float score = GetDotScore(embedding1, embedding2);
        Debug.Log("Similarity Score: " + score);
    }
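
    // Runs the dot-product worker on two embeddings and reads the scalar
    // result back to the CPU; the download waits for the scheduled work to
    // complete. PeekOutput returns a tensor owned by the worker, so it is
    // not disposed here.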
    float GetDotScore(Tensor<float> A, Tensor<float> B)
    {
        dotScore.Schedule(A, B);
        var output = (dotScore.PeekOutput() as Tensor<float>).DownloadToNativeArray();
        return output[0];
    }
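
    // Builds the three BERT-style inputs (token ids, token type ids, and an
    // all-ones attention mask), runs the embedding model, and returns a CPU
    // copy of the pooled, normalized embedding. The caller owns the returned
    // tensor and is responsible for disposing it.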
    Tensor<float> GetEmbedding(List<int> tokenList)
    {
        int N = tokenList.Count;
        using var input_ids = new Tensor<int>(new TensorShape(1, N), tokenList.ToArray());

        // Single sentence, so all token type ids are 0.
        using var token_type_ids = new Tensor<int>(new TensorShape(1, N), new int[N]);

        // No padding, so every position is attended to.
        int[] mask = new int[N];
        for (int i = 0; i < mask.Length; i++)
        {
            mask[i] = 1;
        }
        using var attention_mask = new Tensor<int>(new TensorShape(1, N), mask);

        engine.Schedule(input_ids, attention_mask, token_type_ids);

        // ReadbackAndClone returns a CPU tensor that outlives the worker's own output.
        var output = engine.PeekOutput().ReadbackAndClone() as Tensor<float>;
        return output;
    }
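
    // Loads the MiniLM model and appends a mean-pooling head to it with the
    // functional graph API, so the worker outputs a single normalized
    // sentence embedding instead of per-token embeddings.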
    Worker CreateMLModel()
    {
        var model = ModelLoader.Load(modelAsset);
        var graph = new FunctionalGraph();
        var inputs = graph.AddInputs(model);
        var tokenEmbeddings = Functional.Forward(model, inputs)[0];
        var attention_mask = inputs[1];
        var output = MeanPooling(tokenEmbeddings, attention_mask);
        var modelWithMeanPooling = graph.Compile(output);
        return new Worker(modelWithMeanPooling, backend);
    }
    //Average the token embeddings, weighted by the attention mask, then
    //L2-normalize so a dot product of two embeddings is their cosine similarity.
    FunctionalTensor MeanPooling(FunctionalTensor tokenEmbeddings, FunctionalTensor attentionMask)
    {
        var mask = attentionMask.Unsqueeze(-1).BroadcastTo(new[] { FEATURES }); //shape=(1,N,FEATURES)
        var A = Functional.ReduceSum(tokenEmbeddings * mask, 1); //shape=(1,FEATURES)
        var B = A / (Functional.ReduceSum(mask, 1) + 1e-9f); //shape=(1,FEATURES)
        var C = Functional.Sqrt(Functional.ReduceSum(Functional.Square(B), 1, true)); //shape=(1,1)
        return B / C; //shape=(1,FEATURES)
    }
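
    // Small functional model that computes the dot product of two embedding
    // vectors; because the embeddings are unit length, this is their cosine
    // similarity.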
    Worker CreateDotScoreModel()
    {
        var graph = new FunctionalGraph();
        var input1 = graph.AddInput<float>(new TensorShape(1, FEATURES));
        var input2 = graph.AddInput<float>(new TensorShape(1, FEATURES));
        var output = Functional.ReduceSum(input1 * input2, 1);
        var dotScoreModel = graph.Compile(output);
        return new Worker(dotScoreModel, backend);
    }
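
    // Greedy WordPiece tokenization: for each whitespace-separated word, find
    // the longest vocabulary entry matching a prefix of the remaining
    // characters, emit its id, and continue from where the match ended.
    // Continuation pieces carry the "##" prefix.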
    List<int> GetTokens(string text)
    {
        //split the lower-cased text over whitespace (Split(null) splits on whitespace)
        string[] words = text.ToLower().Split(null);

        var ids = new List<int>
        {
            START_TOKEN
        };

        string s = "";

        foreach (var word in words)
        {
            int start = 0;
            for (int i = word.Length; i >= 0; i--)
            {
                //try the longest remaining substring first; pieces that do not
                //start the word get the "##" continuation prefix
                string subword = start == 0 ? word.Substring(start, i) : "##" + word.Substring(start, i - start);
                int index = Array.IndexOf(tokens, subword);
                if (index >= 0)
                {
                    ids.Add(index);
                    s += subword + " ";
                    if (i == word.Length) break; //matched to the end of the word
                    //restart the search after the matched piece
                    start = i;
                    i = word.Length + 1;
                }
            }
        }

        ids.Add(END_TOKEN);
        Debug.Log("Tokenized sentence = " + s);
        return ids;
    }
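
    // Workers hold GPU resources, so dispose them when the component is destroyed.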
    void OnDestroy()
    {
        dotScore?.Dispose();
        engine?.Dispose();
    }
}