## Introduction

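This tutorial demonstrates how to evaluate a gpt-j-6B-int8 ONNX model with ONNX Runtime: it measures last-token prediction accuracy on the LAMBADA dataset and then runs greedy text generation.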

## Prerequisite

In [None]:
!pip install onnx onnxruntime torch transformers datasets accelerate

## Run

### 1. Measure accuracy on LAMBADA

In [None]:
from transformers import AutoTokenizer
import torch
from datasets import load_dataset
import onnxruntime as ort
from torch.nn.functional import pad

# Load the tokenizer of the original GPT-J checkpoint. No model weights are
# loaded here; inference runs through the ONNX Runtime session created below.
model_id = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)

def tokenize_function(examples):
    """Tokenize the ``text`` field of a batch of dataset examples.

    Used as the callback for ``dataset.map(..., batched=True)``; the
    tokenizer's output for ``examples['text']`` is returned unchanged.
    """
    return tokenizer(examples['text'])

# Build the evaluation set: LAMBADA validation split, shuffled with a fixed
# seed, then tokenized in batches.
dataset = (
    load_dataset('lambada', split='validation')
    .shuffle(seed=42)
    .map(tokenize_function, batched=True)
)
# When iterating, expose only `input_ids`, formatted as torch tensors.
dataset.set_format(type='torch', columns=['input_ids'])

# create session
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession('/path/to/model.onnx', options, providers=ort.get_available_providers())
total, hit = 0, 0

# Number of padding tokens (id 1) appended on the right of every sample.
# 0 disables padding. NOTE: the attention mask below is all ones, so any
# padding added here would still be attended to by the model.
pad_len = 0

# inference: one sample per run (batch dimension of size 1)
for batch in dataset:
    input_ids = batch['input_ids'].unsqueeze(0)
    # The last token of each sample is the target to predict.
    label = input_ids[:, -1]
    input_ids = pad(input_ids, (0, pad_len), value=1)
    ort_inputs = {
        'input_ids': input_ids.detach().cpu().numpy(),
        # Build the mask as int64 directly instead of float32 followed by a cast.
        'attention_mask': torch.ones(input_ids.shape, dtype=torch.int64).numpy(),
    }
    predictions = session.run(None, ort_inputs)
    outputs = torch.from_numpy(predictions[0])
    # Take the output at the second-to-last unpadded position; its argmax is
    # compared against the final token of the sample.
    last_token_logits = outputs[:, -2 - pad_len, :]
    pred = last_token_logits.argmax(dim=-1)
    total += label.size(0)
    hit += (pred == label).sum().item()
acc = hit / total
print('acc: ', acc)

### 2. Text Generation

In [None]:
import os
import time
import sys

# create session
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Pass the execution providers explicitly, as the accuracy cell does: some
# ONNX Runtime releases refuse to create a session without them when more
# than one provider is available.
session = ort.InferenceSession('/path/to/model.onnx', sess_options, providers=ort.get_available_providers())

# input prompt
# 32 tokens input
prompt = ("Once upon a time, there existed a little girl, who liked to have adventures."
          " She wanted to go to places and meet new people, and have fun.")

print("prompt: ", prompt)

# Number of tokens to generate after the prompt.
max_new_tokens = 32

# start: greedy decoding. Every step feeds the whole sequence produced so far
# back into the session and appends the highest-scoring next token.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
for _ in range(max_new_tokens):
    inp = {
        'input_ids': input_ids.detach().cpu().numpy(),
        # Build the mask as int64 directly instead of float32 followed by a cast.
        'attention_mask': torch.ones(input_ids.shape, dtype=torch.int64).numpy(),
    }
    output = session.run(None, inp)
    logits = torch.from_numpy(output[0])
    next_token_logits = logits[:, -1, :]
    # argmax on the raw logits: softmax is monotonic, so applying it first
    # would not change which token wins.
    next_tokens = torch.argmax(next_token_logits, dim=-1)
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
print(tokenizer.decode(input_ids[0]))