---
library_name: transformers
tags:
- gpt
- byte-tokenization
- mobile
- embedded
- onnx
license: cc-by-nc-4.0
datasets:
- custom
- web
language: en
widget:
- text: "In order to make pancakes, you need to"
- text: "Once upon a time"
---
<p align="center">
<img src="logo.png" alt="IJK Technology" width="150">
</p>
<h1 align="center">IJK Technology β ByteGPT-small</h1>
**ByteGPT-small** is a small GPT-style language model trained using byte tokenization inspired by the ByT5 paper. It is designed for use on compute- and memory-constrained devices, such as mobile phones and embedded systems.
## 🚀 Overview
- **Model Type:** GPT-style causal language model
- **Tokenizer:** Byte-level tokenization (from ByT5)
- **Intended Use:** Edge devices, mobile phones, embedded systems
- **Size:** Small (initial prototype)
- **Training:** Custom-trained from scratch
## 🧠 Why Byte Tokenization?
Byte tokenization offers several advantages for small-scale, efficient models:
1. **Reduced Memory Footprint:**
Byte-level tokenization drastically reduces the size of the embedding layer, making the model suitable for devices with limited RAM.
2. **No External Dependencies:**
Unlike subword tokenizers (e.g., SentencePiece, BPE), byte tokenization requires no external libraries; a simple Python script can handle encoding and decoding, as sketched below.
3. **Robustness to Noise:**
Byte-level models are more robust to misspellings, typos, and out-of-vocabulary tokens.
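As a concrete illustration of point 2, here is a minimal, self-contained sketch of the byte-level scheme (the same one the Android tokenizer further down implements: three special tokens `<pad>`=0, `</s>`=1, `<unk>`=2, then each UTF-8 byte shifted by an offset of 3):

```python
OFFSET = 3  # ids 0-2 are reserved for <pad>, </s>, <unk>
EOS_ID = 1

def encode(text: str) -> list[int]:
    # Map each UTF-8 byte to byte_value + OFFSET, then append </s>
    return [b + OFFSET for b in text.encode("utf-8")] + [EOS_ID]

def decode(ids: list[int]) -> str:
    # Drop special tokens, shift back, and decode the raw bytes
    return bytes(i - OFFSET for i in ids if i >= OFFSET).decode("utf-8")

print(encode("Hi"))          # [75, 108, 1]
print(decode(encode("Hi")))  # Hi
```

The entire "vocabulary" is the 256 possible byte values plus three special tokens, which is why the embedding table stays tiny.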
## 💡 Future Plans
This is the **first** in a series of models. While this model is not yet highly useful due to its small size, it represents the foundation for future versions. Upcoming releases will include:
- **Larger Models:** Scaled-up versions with better performance
- **Distilled Models:** Using GRPO distillation to create highly efficient small models
- **Benchmark Results:** Comparative performance on mobile devices
## 💻 Usage
### **Quick Start (with `transformers`):**
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code is needed because the model ships its own architecture code
model = AutoModelForCausalLM.from_pretrained("ijktech/ByteGPT-small", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ijktech/ByteGPT-small")

input_text = "What is the capital of France?"
inputs = tokenizer(input_text, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
### Tokenizer
The tokenizer is byte-level and compatible with Hugging Face's `AutoTokenizer`:
```python
tokenizer = AutoTokenizer.from_pretrained("ijktech/ByteGPT-small")
```
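As a quick sanity check, the ids map directly onto UTF-8 bytes (illustrative output, assuming the packaged tokenizer follows the offset-by-3 scheme sketched above):

```python
ids = tokenizer("Hi")["input_ids"]
print(ids)  # e.g. [75, 108, 1] - each UTF-8 byte + 3, then </s>
print(tokenizer.decode(ids, skip_special_tokens=True))  # Hi
```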
### ONNX
The model is also available in ONNX format and can be used with ONNX Runtime:
```python
import onnxruntime as ort
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ijktech/ByteGPT-small")

# Create ONNX Runtime session
ort_session = ort.InferenceSession("model.onnx")

# Assumed context length; match the block size the model was exported with
# (the Android example below also uses 512)
BLOCK_SIZE = 512

# Helper function to generate text using the ONNX model
def generate_with_onnx(prompt_ids, max_new_tokens=50, temperature=1.0):
    input_ids = prompt_ids.clone()
    for _ in range(max_new_tokens):
        # Keep only the last BLOCK_SIZE tokens if the input is too long
        if input_ids.shape[1] > BLOCK_SIZE:
            input_ids = input_ids[:, -BLOCK_SIZE:]

        # Run inference
        ort_inputs = {"input": input_ids.cpu().numpy()}
        logits = ort_session.run(None, ort_inputs)[0]

        # Only take the last token's predictions
        logits = torch.from_numpy(logits)[:, -1, :]

        # Apply temperature
        if temperature != 1.0:
            logits = logits / temperature

        # Sample from the distribution
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append the new token
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return input_ids

# Test the generation
prompt = "Hello"
prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
generated_ids = generate_with_onnx(prompt_ids)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")
# Generated text: Hello everyone!
# A dinner is only available for St. Loui
```
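For reference, here is a sketch of how such an export could be produced with `torch.onnx.export`. The exact settings used for the published `model.onnx` are an assumption, though the `input`/`output` tensor names match the ones used elsewhere in this card:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ijktech/ByteGPT-small", trust_remote_code=True)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("ijktech/ByteGPT-small")
example = tokenizer("Hello", return_tensors="pt")["input_ids"]

# Assumes the custom forward returns raw logits as a single tensor
torch.onnx.export(
    model,
    (example,),
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch", 1: "sequence"},
        "output": {0: "batch", 1: "sequence"},
    },
)
```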
### Android Usage
The model can be used on Android devices using ONNX Runtime Mobile. Here's an example using Kotlin:
```kotlin
import ai.onnxruntime.*
import android.content.Context
import java.nio.LongBuffer

class ByteGPTTokenizer {
    companion object {
        private const val PAD_TOKEN = "<pad>"
        private const val EOS_TOKEN = "</s>"
        private const val UNK_TOKEN = "<unk>"

        // Token IDs for special tokens
        const val PAD_ID = 0L
        const val EOS_ID = 1L
        const val UNK_ID = 2L
        const val OFFSET = 3L // Number of special tokens
    }

    fun encode(text: String): LongArray {
        // Convert text to UTF-8 bytes and add offset
        val bytes = text.encodeToByteArray()
        val ids = bytes.map { (it.toInt() and 0xFF).toLong() + OFFSET }.toLongArray()
        // Add EOS token
        return ids + EOS_ID
    }

    fun decode(ids: LongArray): String {
        // Convert IDs back to bytes, skipping special tokens
        val bytes = ids.mapNotNull { id ->
            when (id) {
                PAD_ID, EOS_ID, UNK_ID -> null
                else -> (id - OFFSET).toByte()
            }
        }.toByteArray()
        return bytes.toString(Charsets.UTF_8)
    }
}

class ByteGPTGenerator(
    private val context: Context,
    private val modelPath: String = "model_mobile.ort",
    private val maxLength: Int = 512
) {
    private val env = OrtEnvironment.getEnvironment()
    private val session: OrtSession
    private val tokenizer = ByteGPTTokenizer()

    init {
        context.assets.open(modelPath).use { modelInput ->
            val modelBytes = modelInput.readBytes()
            session = env.createSession(modelBytes)
        }
    }

    fun generate(prompt: String, maxNewTokens: Int = 50, temperature: Float = 1.0f): String {
        var currentIds = tokenizer.encode(prompt)
        for (i in 0 until maxNewTokens) {
            if (currentIds.size >= maxLength) break

            // Prepare input tensor of shape [1, sequence_length]
            val shape = longArrayOf(1, currentIds.size.toLong())
            val tensorInput = OnnxTensor.createTensor(env, LongBuffer.wrap(currentIds), shape)

            // Run inference
            val output = session.run(mapOf("input" to tensorInput), setOf("output"))

            // Get logits for the last token; ORT returns float tensors as nested arrays
            val logits = output[0].value as Array<Array<FloatArray>>
            val lastTokenLogits = logits[0].last()
            tensorInput.close()
            output.close()

            // Apply temperature
            if (temperature != 1.0f) {
                for (j in lastTokenLogits.indices) {
                    lastTokenLogits[j] /= temperature
                }
            }

            // Convert to probabilities using softmax
            val expLogits = lastTokenLogits.map { Math.exp(it.toDouble()) }
            val sum = expLogits.sum()
            val probs = expLogits.map { it / sum }

            // Sample from the distribution
            val random = Math.random()
            var cumsum = 0.0
            var nextToken = 0
            for (j in probs.indices) {
                cumsum += probs[j]
                if (random < cumsum) {
                    nextToken = j
                    break
                }
            }

            // Append the new token
            currentIds = currentIds.plus(nextToken.toLong())

            // Stop if we generate EOS
            if (nextToken.toLong() == ByteGPTTokenizer.EOS_ID) break
        }
        return tokenizer.decode(currentIds)
    }
}

// Usage example:
val generator = ByteGPTGenerator(context)
val result = generator.generate("Once upon a time")
println(result)
```
Make sure to:
1. Add the ONNX Runtime Mobile dependency to your `build.gradle`:
```gradle
dependencies {
implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
}
```
2. Place the `model_mobile.ort` file in your app's assets folder.
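If you are starting from the `.onnx` file, ONNX Runtime ships a converter that produces the mobile-optimized `.ort` format (rename the output to `model_mobile.ort`, the file name the example above expects):

```bash
python -m onnxruntime.tools.convert_onnx_models_to_ort model.onnx
```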
## 📄 License
🆓 **CC-BY-NC-4.0**: Free for non-commercial use.
💼 **Commercial Use**: Contact IJK Technology Ltd for licensing at [[email protected]](mailto:[email protected]).
## 🛠️ About IJK Technology Ltd
IJK Technology Ltd (IJKTech) develops innovative machine learning models optimized for on-device inference. Our focus is on efficiency, privacy, and usability across mobile and embedded platforms.