carlesonielfa committed on
Commit
05cf642
·
verified ·
1 Parent(s): 1920f6a

Update to Sentis 2.1.1

Browse files

Hi, thank you for uploading the model and code!

I was messing around with Sentis and updated the code to make it run on the newer version.

- Update code to be compatible with Sentis 2.1.1 (Tested in Unity 6000.0.29f1)
- Replaced `TokenizerUtils` and `Phi3InputFormatter`, which seem to be missing from the original code, with inline equivalents (hard-coded special/EOS token ids and an inline prompt template)

Files changed (1) hide show
  1. Phi3Claude.cs +161 -134
Phi3Claude.cs CHANGED
@@ -1,134 +1,161 @@
1
- using UnityEngine;
2
- using Microsoft.ML.Tokenizers;
3
- using Unity.Sentis;
4
- using System.IO;
5
- using System.Linq;
6
- using System.Collections.Generic;
7
- using System.Collections;
8
-
9
// Review note: everything below is the REMOVED ("-") side of this diff, i.e. the
// previous Phi3Claude implementation that this commit replaces. It uses
// IWorker / WorkerFactory / IBackend / TensorInt / TensorFloat, and it calls
// TokenizerUtils and Phi3InputFormatter, neither of which is defined in this file
// (the commit message says they appear to be missing from the original code).
- public class Phi3Claude : MonoBehaviour
10
- {
11
- IWorker worker;
12
- LlamaTokenizer tokenizer;
13
-
14
// Running token list: the encoded prompt followed by every generated token.
- List<int> tokens = new();
15
- TensorInt inputTensor, attentionMaskTensor, positionIdsTensor;
16
- TensorFloat outputLogits;
17
-
18
- int maxTokens = 100; // Maximum number of tokens to generate
19
- List<int> eosTokens; // End of sequence tokens
20
-
21
- private IBackend backend;
22
-
23
// Loads the model and tokenizer from StreamingAssets/Phi35, creates the worker and
// backend on GPUCompute, then immediately starts one generation with a fixed prompt.
- private void Start()
24
- {
25
- var tokenizerModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/tokenizer.model");
26
- var sentisModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/model_Uint8.sentis");
27
- var configPath = Path.Combine(Application.streamingAssetsPath, "Phi35/generation_config.json");
28
-
29
- var model = ModelLoader.Load(sentisModelPath);
30
-
31
- worker = WorkerFactory.CreateWorker(BackendType.GPUCompute, model);
32
- Dictionary<string, int> specialTokens = TokenizerUtils.LoadSpecialTokens(Path.Combine(Application.streamingAssetsPath, "Phi35/added_tokens.json"));
33
-
34
- using (Stream tokenizerModelStream = new FileStream(tokenizerModelPath, FileMode.Open, FileAccess.Read))
35
- {
36
- tokenizer = LlamaTokenizer.Create(
37
- tokenizerModelStream,
38
- addBeginOfSentence: true,
39
- addEndOfSentence: false,
40
- specialTokens: specialTokens
41
- );
42
- }
43
-
44
- eosTokens = TokenizerUtils.IdentifyEOSTokens(configPath);
45
- backend = WorkerFactory.CreateBackend(BackendType.GPUCompute);
46
-
47
- Generate("Hello, how is your day?");
48
- }
49
-
50
// Formats the chat prompt, encodes it, resets the token list to the encoded prompt
// and starts the GenerateSequence coroutine.
- public void Generate(string userPrompt, string systemPrompt = "You are a helpful assistant.")
51
- {
52
- string completePrompt = Phi3InputFormatter.FormatChatInput(systemPrompt, userPrompt);
53
- Debug.Log("Complete prompt : " + completePrompt);
54
-
55
- int[] inputIds = tokenizer.EncodeToIds(completePrompt).ToArray();
56
- Debug.Log($"Tokenized input: [{string.Join(", ", inputIds)}]");
57
- Debug.Log($"Decoded tokens: [{string.Join(", ", tokenizer.Decode(inputIds, true))}]");
58
-
59
- tokens.Clear();
60
- tokens.AddRange(inputIds);
61
-
62
- StartCoroutine(GenerateSequence());
63
- }
64
-
65
// Generates up to maxTokens tokens. Each iteration rebuilds the input tensors from
// the whole token list, runs the model, appends the token chosen by ProcessLogits,
// disposes the tensors, and stops early when that token is in eosTokens.
- private IEnumerator GenerateSequence()
66
- {
67
- for (int i = 0; i < maxTokens; i++)
68
- {
69
- RefreshTensors(tokens.ToArray());
70
-
71
- worker.Execute(new Dictionary<string, Tensor>()
72
- {
73
- {"input_ids", inputTensor},
74
- {"attention_mask", attentionMaskTensor},
75
- {"position_ids", positionIdsTensor}
76
- }); // > 15ms (/!\ should be async)
77
-
78
- outputLogits = worker.PeekOutput("logits") as TensorFloat; // Async
79
- outputLogits.ReadbackRequest(); // Async
80
-
// NOTE(review): this yields the value returned by IsReadbackRequestDone() a single
// time; there is no loop that re-checks it. Presumably that only waits one frame
// rather than until the readback completes - confirm.
81
- yield return outputLogits.IsReadbackRequestDone(); // 236 ms
82
-
83
- tokens.Add(ProcessLogits()); // > 200ms
84
-
85
- int nextToken = tokens[tokens.Count - 1];
86
-
87
- CleanupTensors();
88
-
89
- if (eosTokens.Contains(nextToken))
90
- break;
91
- }
92
-
93
- string generatedText = tokenizer.Decode(tokens.ToArray(), true); // 0 ms
94
- Debug.Log($"Generated sequence: {generatedText}");
95
- }
96
-
97
-
98
// Picks the next token greedily: ArgMax over axis 2 of outputLogits, then takes the
// entry at index shape[1] - 1 (presumably the last sequence position - confirm).
- private int ProcessLogits()
99
- {
100
- // Greedy sampling for simplicity
101
- using var argMaxTensor = TensorInt.AllocNoData(new TensorShape(1, outputLogits.shape[1]));
102
- backend.ArgMax(outputLogits, argMaxTensor, axis: 2, selectLastIndex: false);
103
-
104
- var argMaxTensorArray = argMaxTensor.ToReadOnlyArray(); // TODO : investigate on why it's long to process
105
- int nextToken = argMaxTensorArray[outputLogits.shape[1] - 1];
106
-
107
- Debug.Log($"<color=orange>Next token: [ID = {nextToken}, STR = \"{tokenizer.Decode(new[] { nextToken }, true)}\"]</color>");
108
-
109
- return nextToken;
110
- }
111
-
112
// Builds three (1, ids.Length) int tensors: the ids themselves, an all-ones
// attention mask, and position ids 0..ids.Length-1.
- private void RefreshTensors(int[] ids)
113
- {
114
- // Update input tensors with the full context
115
- inputTensor = new TensorInt(new TensorShape(1, ids.Length), ids);
116
- attentionMaskTensor = new TensorInt(new TensorShape(1, ids.Length), Enumerable.Repeat(1, ids.Length).ToArray());
117
- positionIdsTensor = new TensorInt(new TensorShape(1, ids.Length), Enumerable.Range(0, ids.Length).ToArray());
118
- }
119
-
120
// Disposes the three input tensors and outputLogits. The fields are not reset to
// null afterwards.
// NOTE(review): outputLogits was obtained via worker.PeekOutput - confirm whether
// the worker owns that tensor and whether disposing it here is intended.
- private void CleanupTensors()
121
- {
122
- inputTensor?.Dispose();
123
- attentionMaskTensor?.Dispose();
124
- positionIdsTensor?.Dispose();
125
- outputLogits?.Dispose();
126
- }
127
-
128
// Releases any remaining tensors, then the worker and the backend.
- private void OnDestroy() {
129
- CleanupTensors();
130
-
131
- worker?.Dispose();
132
- backend?.Dispose();
133
- }
134
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
using UnityEngine;
using Microsoft.ML.Tokenizers;
using Unity.Sentis;
using System.IO;
using System.Linq;
using System.Collections.Generic;
using System.Collections;

/// <summary>
/// Runs greedy text generation with a Phi-3.5 model through Sentis 2.x.
/// The model and tokenizer are loaded from <c>StreamingAssets/Phi35</c> in
/// <see cref="Start"/>; <see cref="Generate"/> encodes a chat prompt and starts a
/// coroutine that produces one token per loop iteration and logs the result.
/// </summary>
public class Phi3Claude : MonoBehaviour
{
    // Size of the last axis of the model's "logits" output (the axis ArgMax reduces).
    const int VocabSize = 32064;

    // Runs the language model itself.
    Worker modelWorker;
    // Runs the small compiled ArgMax graph used for greedy decoding.
    Worker decodingWorker;
    LlamaTokenizer tokenizer;

    // Running token list: the encoded prompt followed by every generated token.
    List<int> tokens = new();
    // Input tensors created by RefreshTensors; owned (and disposed) by this component.
    Tensor<int> inputTensor, attentionMaskTensor, positionIdsTensor;
    // Output tensors obtained through PeekOutput; owned by the workers, never disposed here.
    Tensor<float> outputLogits;
    Tensor<int> argMaxTensor;

    int maxTokens = 100; // Maximum number of tokens to generate
    List<int> eosTokens; // End of sequence tokens

    /// <summary>
    /// Loads the model, builds the greedy-decoding graph, creates both workers and the
    /// tokenizer, then starts one generation with a fixed demo prompt.
    /// </summary>
    private void Start()
    {
        var tokenizerModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/tokenizer.model");
        var sentisModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/model_Uint8.sentis");

        var model = ModelLoader.Load(sentisModelPath);

        // Create a model that does greedy decoding: ArgMax over the vocabulary axis
        // of a (1, sequence, VocabSize) float input.
        FunctionalGraph graph = new FunctionalGraph();
        FunctionalTensor logits = graph.AddInput<float>(new DynamicTensorShape(1, -1, VocabSize));
        FunctionalTensor argMax = Functional.ArgMax(logits, 2, false);
        Model greedyModel = graph.Compile(argMax);

        modelWorker = new Worker(model, BackendType.GPUCompute);
        decodingWorker = new Worker(greedyModel, BackendType.GPUCompute);

        // Manually set from added_tokens.json
        Dictionary<string, int> specialTokens = new()
        {
            { "<|assistant|>", 32001 },
            { "<|endoftext|>", 32000 },
            { "<|end|>", 32007 },
            { "<|placeholder1|>", 32002 },
            { "<|placeholder2|>", 32003 },
            { "<|placeholder3|>", 32004 },
            { "<|placeholder4|>", 32005 },
            { "<|placeholder5|>", 32008 },
            { "<|placeholder6|>", 32009 },
            { "<|system|>", 32006 },
            { "<|user|>", 32010 }
        };

        using (Stream tokenizerModelStream = new FileStream(tokenizerModelPath, FileMode.Open, FileAccess.Read))
        {
            tokenizer = LlamaTokenizer.Create(
                tokenizerModelStream,
                addBeginOfSentence: true,
                addEndOfSentence: false,
                specialTokens: specialTokens
            );
        }

        // Manually set from generation_config.json
        eosTokens = new() { 32007, 32001, 32000 };

        Generate("What is the capital of France?");
    }

    /// <summary>
    /// Formats a chat prompt, encodes it and starts generating a reply asynchronously.
    /// The generated text is written to the Unity log when the sequence ends.
    /// </summary>
    /// <param name="userPrompt">Text placed in the <c>&lt;|user|&gt;</c> turn.</param>
    /// <param name="systemPrompt">Text placed in the <c>&lt;|system|&gt;</c> turn.</param>
    /// <remarks>
    /// Not re-entrant: the token list and tensors are shared fields, so calling this
    /// again before the previous sequence has finished corrupts both runs.
    /// </remarks>
    public void Generate(string userPrompt, string systemPrompt = "You are a helpful assistant.")
    {
        // Explicit "\n" separators keep the prompt independent of this source file's
        // line endings and indentation (a multi-line verbatim string embeds both).
        string completePrompt =
            $"<|system|>\n{systemPrompt}<|end|>\n<|user|>\n{userPrompt}<|end|>\n<|assistant|>";
        Debug.Log("Complete prompt : " + completePrompt);

        int[] inputIds = tokenizer.EncodeToIds(completePrompt).ToArray();
        Debug.Log($"Tokenized input: [{string.Join(", ", inputIds)}]");
        Debug.Log($"Decoded tokens: [{string.Join(", ", tokenizer.Decode(inputIds, true))}]");

        tokens.Clear();
        tokens.AddRange(inputIds);

        StartCoroutine(GenerateSequence());
    }

    /// <summary>
    /// Coroutine that generates up to <c>maxTokens</c> tokens. Each iteration feeds the
    /// whole token list to the model, runs ArgMax on the GPU, waits for the (small)
    /// ArgMax result to be read back, appends the chosen token and stops early on an
    /// end-of-sequence token. The decoded sequence is logged at the end.
    /// </summary>
    private IEnumerator GenerateSequence()
    {
        for (int i = 0; i < maxTokens; i++)
        {
            RefreshTensors(tokens.ToArray());

            modelWorker.SetInput("input_ids", inputTensor);
            modelWorker.SetInput("attention_mask", attentionMaskTensor);
            modelWorker.SetInput("position_ids", positionIdsTensor);
            modelWorker.Schedule();

            // The logits are passed straight to the decoding worker; only the ArgMax
            // result is read back, not the full logits tensor.
            outputLogits = modelWorker.PeekOutput("logits") as Tensor<float>;
            decodingWorker.SetInput(0, outputLogits);
            decodingWorker.Schedule();

            argMaxTensor = decodingWorker.PeekOutput() as Tensor<int>;
            argMaxTensor.ReadbackRequest();

            // Poll once per frame until the readback has really finished. Yielding the
            // bool returned by IsReadbackRequestDone() would only skip a single frame.
            // do/while guarantees at least one frame per token so the UI stays responsive.
            do
            {
                yield return null;
            }
            while (!argMaxTensor.IsReadbackRequestDone());

            int nextToken = ProcessLogits();
            tokens.Add(nextToken);

            CleanupTensors();

            if (eosTokens.Contains(nextToken))
                break;
        }

        string generatedText = tokenizer.Decode(tokens.ToArray(), true);
        Debug.Log($"Generated sequence: {generatedText}");
    }

    /// <summary>
    /// Reads the ArgMax result for the current step and returns the token id chosen for
    /// the last sequence position. Expects <c>argMaxTensor</c> to have been scheduled
    /// and its readback requested by the caller.
    /// </summary>
    /// <returns>The id of the next token.</returns>
    private int ProcessLogits()
    {
        var argMaxTensorArray = argMaxTensor.DownloadToArray();
        // One ArgMax value per input position; the last one predicts the next token.
        int nextToken = argMaxTensorArray[argMaxTensorArray.Length - 1];

        Debug.Log($"<color=orange>Next token: [ID = {nextToken}, STR = \"{tokenizer.Decode(new[] { nextToken }, true)}\"]</color>");

        return nextToken;
    }

    /// <summary>
    /// Rebuilds the three (1, ids.Length) input tensors from the full context: the token
    /// ids, an all-ones attention mask and position ids 0..ids.Length-1.
    /// </summary>
    /// <param name="ids">Every token of the current context.</param>
    private void RefreshTensors(int[] ids)
    {
        var shape = new TensorShape(1, ids.Length);

        inputTensor = new Tensor<int>(shape, ids);
        attentionMaskTensor = new Tensor<int>(shape, Enumerable.Repeat(1, ids.Length).ToArray());
        positionIdsTensor = new Tensor<int>(shape, Enumerable.Range(0, ids.Length).ToArray());
    }

    /// <summary>
    /// Disposes the input tensors this component created and clears every tensor
    /// reference so that a later call (for example from OnDestroy) is a no-op.
    /// </summary>
    private void CleanupTensors()
    {
        inputTensor?.Dispose();
        attentionMaskTensor?.Dispose();
        positionIdsTensor?.Dispose();
        inputTensor = null;
        attentionMaskTensor = null;
        positionIdsTensor = null;

        // Tensors returned by PeekOutput belong to their worker and are released when
        // the worker is disposed, so only the references are dropped here.
        outputLogits = null;
        argMaxTensor = null;
    }

    /// <summary>
    /// Releases any remaining input tensors, then both workers (which also releases
    /// their output tensors).
    /// </summary>
    private void OnDestroy()
    {
        CleanupTensors();

        modelWorker?.Dispose();
        decodingWorker?.Dispose();
    }
}