UnityGiles commited on
Commit
b0ba900
·
1 Parent(s): 292a662

update to inference engine

Browse files
README.md CHANGED
@@ -2,42 +2,28 @@
2
  license: mit
3
  library_name: unity-sentis
4
  pipeline_tag: text-generation
 
 
5
  ---
6
 
7
- # Tiny Stories Model in Unity Sentis Format (Sentis 1.4.0-pre.2*)
8
- *Version 1.3.0 Sentis files are not compatible with Sentis 1.4.0 and would need to be recreated/downloaded
9
-
10
- This is the [Tiny Stories model](https://huggingface.co/roneneldan/TinyStories-33M) checked to run on Unity 2023. Tiny Stories is a Large Language Model that was trained on children's stories and can create stories based on the first couple of sentences.
11
 
 
12
 
13
  ## How to Use
14
- * Create a new scene in Unity 2023
15
- * Install `com.unity.sentis` and `com.unity.nuget.newtonsoft-json` packages
16
- * Add the RunTinyStories.cs file to the Main Camera
17
- * Put `tinystories.sentis`, `vocab.json` and `merges.txt` in the Assets/StreamingAssets folder
18
- * Adjust some of the variables such as the `outputText` string to set the prompt
19
- * Press run
20
- * The output will appear in the console window
21
-
22
- ## Example Input
23
- ```
24
- One day an alien came down from Mars. It saw a chicken
25
- ```
26
- ## Example Output
27
- ```
28
- One day an alien came down from Mars. It saw a chicken and said, "Hello, little chicken. What are you doing here?"
29
-
30
- The chicken replied, "I'm looking for a place to stay. I'm very tired."
31
-
32
- The alien said, "You can stay here. I have a nice place for you. It's very comfortable."
33
 
34
- The chicken was so happy. She thanked the alien and said, "Thank you. I'm very comfortable here."
 
 
 
 
 
35
 
36
- The alien smiled and said, "You're welcome
37
- ```
38
 
39
- ## Unity Sentis
40
- Unity Sentis is the inference engine which runs on Unity 2023. More can be found about it [here](https://unity.com/products/sentis)
41
 
42
  ## Disclaimer
43
  The model was trained on children's stories so very unlikely to produce undesirable text. As an extra precaution, we removed a few tokens from vocab.json that might not be suitable for younger audiences. The original json can be found on the Tiny Stories original page.
 
2
  license: mit
3
  library_name: unity-sentis
4
  pipeline_tag: text-generation
5
+ tags:
6
+ - unity-inference-engine
7
  ---
8
 
9
+ # Tiny Stories in Unity 6 with Inference Engine
 
 
 
10
 
11
+ This is the [Tiny Stories model](https://huggingface.co/roneneldan/TinyStories-33M) running in Unity 6 with Inference Engine. Tiny Stories is a Large Language Model that was trained on children's stories and can create stories based on the first couple of sentences.
12
 
13
  ## How to Use
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ * Create a new scene in Unity 6;
16
+ * Install `com.unity.ai.inference` from the package manager;
17
+ * Install `com.unity.nuget.newtonsoft-json` from the package manager;
18
+ * Drag the `tinystories.onnx` asset from the `models` folder into the `Model Asset` field;
19
+ * Drag the `vocab.json` asset from the `data` folder into the `Vocab Asset` field;
20
+ * Drag the `merges.txt` asset from the `data` folder into the `Merges Asset` field;
21
 
22
+ ## Preview
23
+ Enter play mode. If working correctly the predicted text will be logged to the console.
24
 
25
+ ## Inference Engine
26
+ Inference Engine is a neural network inference library for Unity. Find out more [here](https://docs.unity3d.com/Packages/com.unity.ai.inference@latest).
27
 
28
  ## Disclaimer
29
  The model was trained on children's stories so very unlikely to produce undesirable text. As an extra precaution, we removed a few tokens from vocab.json that might not be suitable for younger audiences. The original json can be found on the Tiny Stories original page.
RunTinyStories.cs CHANGED
@@ -1,33 +1,14 @@
1
- using System.Collections;
2
  using System.Collections.Generic;
3
  using UnityEngine;
4
- using Unity.Sentis;
5
- using System.IO;
6
  using System.Text;
7
- using FF = Unity.Sentis.Functional;
8
-
9
- /*
10
- * Tiny Stories Inference Code
11
- * ===========================
12
- *
13
- * Put this script on the Main Camera
14
- *
15
- * In Assets/StreamingAssets put:
16
- *
17
- * tinystories.sentis (or put in asset folder and drag onto field)
18
- * vocab.json
19
- * merges.txt
20
- *
21
- * Install package com.unity.nuget.newtonsoft-json from packagemanger
22
- * Install package com.unity.sentis
23
- *
24
- */
25
-
26
 
27
  public class RunTinyStories : MonoBehaviour
28
  {
29
- //Drop the tinystories.sentis or onnx file on here if using an asset:
30
- //public ModelAsset asset;
 
31
  const BackendType backend = BackendType.GPUCompute;
32
 
33
  //string outputString = "Once upon a time, there were three bears";
@@ -45,22 +26,21 @@ public class RunTinyStories : MonoBehaviour
45
  //Store the vocabulary
46
  string[] tokens;
47
 
48
- IWorker engine;
49
 
50
- int currentToken = 0;
51
  int[] outputTokens = new int[maxTokens];
52
 
53
  // Used for special character decoding
54
  int[] whiteSpaceCharacters = new int[256];
55
  int[] encodedCharacters = new int[256];
56
 
57
- bool runInference = false;
58
-
59
 
60
  //stop after this many tokens
61
  const int stopAfter = 100;
62
 
63
- int totalTokens = 0;
64
 
65
  string[] merges;
66
  Dictionary<string, int> vocab;
@@ -71,19 +51,17 @@ public class RunTinyStories : MonoBehaviour
71
 
72
  LoadVocabulary();
73
 
74
- var model1 = ModelLoader.Load(Path.Join(Application.streamingAssetsPath , "tinystories.sentis"));
75
- //var model1 = ModelLoader.Load(asset);
76
  //Create a new model to select the random token:
77
- var model2 = FF.Compile(
78
- (input, currentToken) =>
79
- {
80
- var row = FF.Select(model1.Forward(input)[8], 1, currentToken);
81
- return FF.Multinomial(predictability * row, 1);
82
- },
83
- (model1.inputs[0], InputDef.Int(new TensorShape()))
84
- );
85
 
86
- engine = WorkerFactory.CreateWorker(backend, model2);
 
 
 
 
 
 
 
87
 
88
  DecodePrompt(outputString);
89
 
@@ -101,16 +79,14 @@ public class RunTinyStories : MonoBehaviour
101
 
102
  void RunInference()
103
  {
104
- using var tokensSoFar = new TensorInt(new TensorShape(1, maxTokens), outputTokens);
105
- using var index = new TensorInt(currentToken);
106
 
107
- engine.Execute(new Dictionary<string, Tensor> { {"input_0", tokensSoFar }, { "input_1", index }});
108
 
109
- var probs = engine.PeekOutput() as TensorInt;
110
  Debug.Log(probs.shape);
111
 
112
- probs.CompleteOperationsAndDownload();
113
-
114
  int ID = probs[0];
115
 
116
  //shift window down if got to the end
@@ -130,23 +106,22 @@ public class RunTinyStories : MonoBehaviour
130
  else outputString += GetUnicodeText(tokens[ID]);
131
 
132
  Debug.Log(outputString);
133
-
134
  }
135
 
136
  void DecodePrompt(string text)
137
  {
138
  var inputTokens = GetTokens(text);
139
 
140
- for(int i = 0; i < inputTokens.Count; i++)
141
  {
142
  outputTokens[i] = inputTokens[i];
143
  }
144
  currentToken = inputTokens.Count - 1;
145
  }
146
-
147
  void LoadVocabulary()
148
  {
149
- var jsonText = File.ReadAllText(Path.Join(Application.streamingAssetsPath , "vocab.json"));
150
  vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
151
  tokens = new string[vocab.Count];
152
  foreach (var item in vocab)
@@ -154,7 +129,7 @@ public class RunTinyStories : MonoBehaviour
154
  tokens[item.Value] = item.Key;
155
  }
156
 
157
- merges = File.ReadAllLines(Path.Join(Application.streamingAssetsPath , "merges.txt"));
158
  }
159
 
160
  // Translates encoded special characters to Unicode
@@ -174,8 +149,7 @@ public class RunTinyStories : MonoBehaviour
174
  string outText = "";
175
  foreach (char letter in text)
176
  {
177
- outText += ((int)letter <= 256) ? letter :
178
- (char)whiteSpaceCharacters[(int)(letter - 256)];
179
  }
180
  return outText;
181
  }
@@ -185,7 +159,7 @@ public class RunTinyStories : MonoBehaviour
185
  string outText = "";
186
  foreach (char letter in text)
187
  {
188
- outText += (char)encodedCharacters[(int)letter];
189
  }
190
  return outText;
191
  }
@@ -215,7 +189,7 @@ public class RunTinyStories : MonoBehaviour
215
 
216
  // Start with a list of single characters
217
  var inputTokens = new List<string>();
218
- foreach(var letter in text)
219
  {
220
  inputTokens.Add(letter.ToString());
221
  }
@@ -224,7 +198,7 @@ public class RunTinyStories : MonoBehaviour
224
 
225
  //Find the ids of the words in the vocab
226
  var ids = new List<int>();
227
- foreach(var token in inputTokens)
228
  {
229
  if (vocab.TryGetValue(token, out int id))
230
  {
@@ -237,7 +211,7 @@ public class RunTinyStories : MonoBehaviour
237
 
238
  void ApplyMerges(List<string> inputTokens)
239
  {
240
- foreach(var merge in merges)
241
  {
242
  string[] pair = merge.Split(' ');
243
  int n = 0;
@@ -254,9 +228,8 @@ public class RunTinyStories : MonoBehaviour
254
  }
255
  }
256
 
257
- private void OnDestroy()
258
  {
259
  engine?.Dispose();
260
  }
261
-
262
  }
 
 
1
  using System.Collections.Generic;
2
  using UnityEngine;
3
+ using Unity.InferenceEngine;
 
4
  using System.Text;
5
+ using FF = Unity.InferenceEngine.Functional;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  public class RunTinyStories : MonoBehaviour
8
  {
9
+ public ModelAsset modelAsset;
10
+ public TextAsset vocabAsset;
11
+ public TextAsset mergesAsset;
12
  const BackendType backend = BackendType.GPUCompute;
13
 
14
  //string outputString = "Once upon a time, there were three bears";
 
26
  //Store the vocabulary
27
  string[] tokens;
28
 
29
+ Worker engine;
30
 
31
+ int currentToken;
32
  int[] outputTokens = new int[maxTokens];
33
 
34
  // Used for special character decoding
35
  int[] whiteSpaceCharacters = new int[256];
36
  int[] encodedCharacters = new int[256];
37
 
38
+ bool runInference;
 
39
 
40
  //stop after this many tokens
41
  const int stopAfter = 100;
42
 
43
+ int totalTokens;
44
 
45
  string[] merges;
46
  Dictionary<string, int> vocab;
 
51
 
52
  LoadVocabulary();
53
 
54
+ var model1 = ModelLoader.Load(modelAsset);
 
55
  //Create a new model to select the random token:
 
 
 
 
 
 
 
 
56
 
57
+ var graph = new FunctionalGraph();
58
+ var input = graph.AddInput(model1, 0);
59
+ var currentTokenInput = graph.AddInput<int>(new TensorShape(), "currentToken");
60
+ var row = FF.Select(Functional.Forward(model1, input)[0], 1, currentTokenInput);
61
+ var output = FF.Multinomial(predictability * row, 1);
62
+ var model2 = graph.Compile(output);
63
+
64
+ engine = new Worker(model2, backend);
65
 
66
  DecodePrompt(outputString);
67
 
 
79
 
80
  void RunInference()
81
  {
82
+ using var tokensSoFar = new Tensor<int>(new TensorShape(1, maxTokens), outputTokens);
83
+ using var index = new Tensor<int>(new TensorShape(), new[] { currentToken });
84
 
85
+ engine.Schedule(tokensSoFar, index);
86
 
87
+ using var probs = (engine.PeekOutput() as Tensor<int>).ReadbackAndClone();
88
  Debug.Log(probs.shape);
89
 
 
 
90
  int ID = probs[0];
91
 
92
  //shift window down if got to the end
 
106
  else outputString += GetUnicodeText(tokens[ID]);
107
 
108
  Debug.Log(outputString);
 
109
  }
110
 
111
  void DecodePrompt(string text)
112
  {
113
  var inputTokens = GetTokens(text);
114
 
115
+ for (int i = 0; i < inputTokens.Count; i++)
116
  {
117
  outputTokens[i] = inputTokens[i];
118
  }
119
  currentToken = inputTokens.Count - 1;
120
  }
121
+
122
  void LoadVocabulary()
123
  {
124
+ var jsonText = vocabAsset.text;
125
  vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
126
  tokens = new string[vocab.Count];
127
  foreach (var item in vocab)
 
129
  tokens[item.Value] = item.Key;
130
  }
131
 
132
+ merges = mergesAsset.text.Split("\r\n");
133
  }
134
 
135
  // Translates encoded special characters to Unicode
 
149
  string outText = "";
150
  foreach (char letter in text)
151
  {
152
+ outText += (letter <= 256) ? letter : (char)whiteSpaceCharacters[letter - 256];
 
153
  }
154
  return outText;
155
  }
 
159
  string outText = "";
160
  foreach (char letter in text)
161
  {
162
+ outText += (char)encodedCharacters[letter];
163
  }
164
  return outText;
165
  }
 
189
 
190
  // Start with a list of single characters
191
  var inputTokens = new List<string>();
192
+ foreach (var letter in text)
193
  {
194
  inputTokens.Add(letter.ToString());
195
  }
 
198
 
199
  //Find the ids of the words in the vocab
200
  var ids = new List<int>();
201
+ foreach (var token in inputTokens)
202
  {
203
  if (vocab.TryGetValue(token, out int id))
204
  {
 
211
 
212
  void ApplyMerges(List<string> inputTokens)
213
  {
214
+ foreach (var merge in merges)
215
  {
216
  string[] pair = merge.Split(' ');
217
  int n = 0;
 
228
  }
229
  }
230
 
231
+ void OnDestroy()
232
  {
233
  engine?.Dispose();
234
  }
 
235
  }
merges.txt → data/merges.txt RENAMED
File without changes
vocab.json → data/vocab.json RENAMED
File without changes
info.json CHANGED
@@ -3,13 +3,13 @@
3
  "RunTinyStories.cs"
4
  ],
5
  "models": [
6
- "tinystories.sentis"
7
  ],
8
  "data": [
9
- "vocab.json",
10
- "merges.txt"
11
  ],
12
  "version": [
13
- "1.4.0"
14
  ]
15
  }
 
3
  "RunTinyStories.cs"
4
  ],
5
  "models": [
6
+ "models/tinystories.onnx"
7
  ],
8
  "data": [
9
+ "data/vocab.json",
10
+ "data/merges.txt"
11
  ],
12
  "version": [
13
+ "2.2.0"
14
  ]
15
  }
tinystories.onnx → models/tinystories.onnx RENAMED
File without changes
tinystories.sentis DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c7962eb7db56b241cc19cd3f0cffcf5d76d3c35639917f07effa6b3c242c91e9
3
- size 478818076