---
library_name: transformers
tags:
  - gpt
  - byte-tokenization
  - mobile
  - embedded
  - onnx
license: cc-by-nc-4.0
datasets:
  - custom
  - web
language: en
widget:
  - text: In order to make pancakes, you need to
  - text: Once upon a time
---

# IJK Technology – ByteGPT-small

ByteGPT-small is a small GPT-style language model trained using byte tokenization inspired by the ByT5 paper. It is designed for use on compute- and memory-constrained devices, such as mobile phones and embedded systems.

## 🚀 Overview

- **Model Type:** GPT-style causal language model
- **Tokenizer:** byte-level tokenization (from ByT5)
- **Intended Use:** edge devices, mobile phones, embedded systems
- **Size:** small (initial prototype)
- **Training:** custom-trained from scratch

## 🧠 Why Byte Tokenization?

Byte tokenization offers several advantages for small-scale, efficient models:

1. **Reduced Memory Footprint:**
   Byte-level tokenization drastically reduces the size of the embedding layer, making the model suitable for devices with limited RAM.

2. **No External Dependencies:**
   Unlike subword tokenizers (e.g., SentencePiece, BPE), byte tokenization requires no external libraries; a simple Python script can handle it (see the sketch after this list).

3. **Robustness to Noise:**
   Byte-level models are more robust to misspellings, typos, and otherwise out-of-vocabulary input.
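
To make points 1 and 2 concrete, below is a minimal sketch of byte tokenization in plain Python. The special-token layout (`<pad>` = 0, `</s>` = 1, `<unk>` = 2, raw bytes shifted by 3) mirrors the Kotlin tokenizer later in this card and should be treated as an assumption unless checked against the released tokenizer config. It also shows why the embedding table stays tiny: 256 byte values plus 3 special tokens give just 259 rows, versus tens of thousands for a typical subword vocabulary.

```python
# Minimal byte-level tokenizer sketch. Assumed layout: <pad>=0, </s>=1,
# <unk>=2, raw UTF-8 bytes shifted by 3 (matching the Kotlin example below).
OFFSET = 3  # number of special tokens
EOS_ID = 1

def encode(text: str) -> list[int]:
    # One ID per UTF-8 byte, plus an end-of-sequence marker
    return [b + OFFSET for b in text.encode("utf-8")] + [EOS_ID]

def decode(ids: list[int]) -> str:
    # Drop special tokens, undo the offset, decode bytes back to text
    return bytes(i - OFFSET for i in ids if i >= OFFSET).decode("utf-8", errors="replace")

assert decode(encode("héllo, 世界")) == "héllo, 世界"
```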

## 💡 Future Plans

This is the first in a series of models. While this model is not yet highly useful due to its small size, it represents the foundation for future versions. Upcoming releases will include:

- **Larger Models:** scaled-up versions with better performance
- **Distilled Models:** using GRPO distillation to create highly efficient small models
- **Benchmark Results:** comparative performance on mobile devices

## 💻 Usage

### Quick Start (with `transformers`)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ijktech/ByteGPT-small", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ijktech/ByteGPT-small")

input_text = "What is the capital of France?"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### Tokenizer

The tokenizer is byte-level and compatible with `AutoTokenizer` from Hugging Face:

```python
tokenizer = AutoTokenizer.from_pretrained("ijktech/ByteGPT-small")
```
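
Because every UTF-8 byte maps to its own token, arbitrary Unicode input survives a round trip through the tokenizer. A quick sanity check (assuming the tokenizer follows the usual byte-level behavior):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ijktech/ByteGPT-small")

# Byte-level round trip: accents and emoji survive intact
text = "naïve café 😊"
ids = tokenizer(text, return_tensors="pt")["input_ids"]
print(ids.shape[1])  # roughly one ID per UTF-8 byte of the input
print(tokenizer.decode(ids[0], skip_special_tokens=True))  # -> naïve café 😊
```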

### ONNX

The model is also available in ONNX format and can be used with ONNX Runtime:

```python
import onnxruntime as ort
import torch

# Create ONNX Runtime session
ort_session = ort.InferenceSession("model.onnx")

# Helper function to generate text using the ONNX model.
# `model` and `tokenizer` are the objects loaded in the Quick Start above;
# `model.block_size` is the maximum context length the model accepts.
def generate_with_onnx(prompt_ids, max_new_tokens=50, temperature=1.0):
    input_ids = prompt_ids.clone()

    for _ in range(max_new_tokens):
        # Keep only the last block_size tokens if the input is too long
        if input_ids.shape[1] > model.block_size:
            input_ids = input_ids[:, -model.block_size:]

        # Run inference
        ort_inputs = {
            'input': input_ids.cpu().numpy()
        }
        logits = ort_session.run(None, ort_inputs)[0]

        # Take only the last position's logits (the next-token predictions)
        logits = torch.from_numpy(logits)
        logits = logits[:, -1, :]

        # Apply temperature
        if temperature != 1.0:
            logits = logits / temperature

        # Sample from the distribution
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append the new token
        input_ids = torch.cat([input_ids, next_token], dim=1)

    return input_ids

# Test the generation
prompt = "Hello"
prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
generated_ids = generate_with_onnx(prompt_ids)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")
# Generated text: Hello everyone!
# A dinner is only available for St. Loui
```
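
The snippet above assumes a local `model.onnx`. If you haven't downloaded it yet, one option is to fetch it from the Hub with `huggingface_hub`; the filename here is an assumption, so adjust it to whatever the repository actually contains:

```python
import onnxruntime as ort
from huggingface_hub import hf_hub_download

# Download the ONNX export from the model repo (assumed filename "model.onnx";
# check the repository's file list if this fails)
onnx_path = hf_hub_download(repo_id="ijktech/ByteGPT-small", filename="model.onnx")
ort_session = ort.InferenceSession(onnx_path)
```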

### Android Usage

The model can run on Android devices via ONNX Runtime Mobile. Here's an example in Kotlin:

```kotlin
import ai.onnxruntime.*
import android.content.Context
import java.nio.LongBuffer

class ByteGPTTokenizer {
    companion object {
        private const val PAD_TOKEN = "<pad>"
        private const val EOS_TOKEN = "</s>"
        private const val UNK_TOKEN = "<unk>"

        // Token IDs for special tokens (non-private so callers can
        // check for EOS during generation)
        const val PAD_ID = 0L
        const val EOS_ID = 1L
        const val UNK_ID = 2L
        const val OFFSET = 3L // Number of special tokens
    }

    fun encode(text: String): LongArray {
        // Convert text to UTF-8 bytes and add offset
        val bytes = text.encodeToByteArray()
        val ids = bytes.map { (it.toInt() and 0xFF).toLong() + OFFSET }.toLongArray()

        // Add EOS token
        return ids + EOS_ID
    }

    fun decode(ids: LongArray): String {
        // Convert IDs back to bytes, skipping special tokens
        val bytes = ids.mapNotNull { id ->
            when (id) {
                PAD_ID, EOS_ID, UNK_ID -> null
                else -> (id - OFFSET).toByte()
            }
        }.toByteArray()

        return bytes.toString(Charsets.UTF_8)
    }
}

class ByteGPTGenerator(
    private val context: Context,
    private val modelPath: String = "model_mobile.ort",
    private val maxLength: Int = 512
) {
    private val env = OrtEnvironment.getEnvironment()
    private val session: OrtSession
    private val tokenizer = ByteGPTTokenizer()

    init {
        context.assets.open(modelPath).use { modelInput ->
            val modelBytes = modelInput.readBytes()
            session = env.createSession(modelBytes)
        }
    }

    fun generate(prompt: String, maxNewTokens: Int = 50, temperature: Float = 1.0f): String {
        var currentIds = tokenizer.encode(prompt)

        for (i in 0 until maxNewTokens) {
            if (currentIds.size >= maxLength) break

            // Prepare input tensor of shape [1, sequence_length]
            val shape = longArrayOf(1, currentIds.size.toLong())
            val tensorInput = OnnxTensor.createTensor(
                env,
                LongBuffer.wrap(currentIds),
                shape
            )

            // Run inference
            val output = session.run(
                mapOf("input" to tensorInput),
                setOf("output")
            )

            // Get logits for the last token; a 3-D float tensor maps to
            // Array<Array<FloatArray>> in the ONNX Runtime Java API
            val logits = output[0].value as Array<Array<FloatArray>>
            val lastTokenLogits = logits[0].last()

            // Free native resources now that the logits are copied out
            output.close()
            tensorInput.close()

            // Apply temperature
            if (temperature != 1.0f) {
                for (j in lastTokenLogits.indices) {
                    lastTokenLogits[j] /= temperature
                }
            }

            // Convert to probabilities using softmax
            val expLogits = lastTokenLogits.map { Math.exp(it.toDouble()) }
            val sum = expLogits.sum()
            val probs = expLogits.map { it / sum }

            // Sample from the distribution
            val random = Math.random()
            var cumsum = 0.0
            var nextToken = 0
            for (j in probs.indices) {
                cumsum += probs[j]
                if (random < cumsum) {
                    nextToken = j
                    break
                }
            }

            // Append new token
            currentIds = currentIds.plus(nextToken.toLong())

            // Stop if we generate EOS
            if (nextToken.toLong() == ByteGPTTokenizer.EOS_ID) break
        }

        return tokenizer.decode(currentIds)
    }
}

// Usage example:
val generator = ByteGPTGenerator(context)
val result = generator.generate("Once upon a time")
println(result)
```

Make sure to:

1. Add the ONNX Runtime Mobile dependency to your `build.gradle`:

   ```gradle
   dependencies {
       implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
   }
   ```

2. Place the `model_mobile.ort` file in your app's assets folder.
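
If you only have the `.onnx` export, recent `onnxruntime` releases ship a converter that produces the mobile-optimized `.ort` format, e.g. `python -m onnxruntime.tools.convert_onnx_models_to_ort model.onnx`; check the documentation for your installed version, as the tool's options vary between releases.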

## 📜 License

📝 **CC-BY-NC-4.0**: free for non-commercial use.

💼 **Commercial Use**: contact IJK Technology Ltd for licensing at [email protected].

## 🛠️ About IJK Technology Ltd

IJK Technology Ltd (IJKTech) develops innovative machine learning models optimized for on-device inference. Our focus is on efficiency, privacy, and usability across mobile and embedded platforms.