import argparse
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

import datasets
from accelerate import Accelerator
from datasets import load_dataset, load_from_disk
from peft import LoraConfig, PeftConfig, PeftModel, TaskType, get_peft_config, get_peft_model, prepare_model_for_int8_training
from transformers import AutoConfig, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, default_data_collator


#### DO NOT EDIT ######
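# Benchmark setup: generate a single new token after a 128-token prompt, so each
# generate() call below measures roughly one decoding step.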

generation_length = 1
context_length = 128

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf").to("cuda")
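# Note: loading the 7B model in full precision as above needs roughly 28 GB of
# GPU memory (7B params x 4 bytes).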
eval_dataset = load_dataset("wikitext", 'wikitext-103-v1', split="test")

# tokenize the dataset and filter out examples that are shorter than the context length
def tokenize_and_filter_function(examples):
    tokenized_examples = tokenizer(examples["text"], truncation=True, max_length=context_length)
    # keep only examples that tokenize to exactly context_length tokens
    # (shorter texts are dropped; longer ones were already truncated above)
    result = {
        "input_ids": [],
        "attention_mask": [],
    }
    for i, input_ids in enumerate(tokenized_examples["input_ids"]):
        if len(input_ids) == context_length:
            result["input_ids"].append(input_ids)
            result["attention_mask"].append(tokenized_examples["attention_mask"][i])
    return result

eval_dataset = eval_dataset.map(tokenize_and_filter_function, batched=True, num_proc=4, remove_columns=["text"])
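# Every surviving example is exactly context_length tokens long, so
# default_data_collator can stack batches without any padding.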


#################
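# The peft imports above are unused in this baseline. As a minimal, commented-out
# sketch (assuming a LoRA adapter saved at the hypothetical path "./lora-adapter"),
# an adapter could be attached to the base model before benchmarking:
#
#   model = PeftModel.from_pretrained(model, "./lora-adapter")
#   model = model.merge_and_unload()  # optionally fold the LoRA weights into the base weights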

batch_size = 4

eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size)

with torch.no_grad():
    model.eval()
    # Measure the average latency of one generate() call over up to 100 batches.
    # With generation_length == 1, each call performs a single decoding step,
    # so this approximates per-token generation time.
    total_time = 0.0
    num_batches = 0
    for idx, batch in enumerate(eval_dataloader):
        if idx == 100:
            break
        input_ids = batch["input_ids"].to("cuda")
        start_time = time.time()
        outputs = model.generate(
            input_ids=input_ids,
            max_length=generation_length + context_length,
        )
        total_time += time.time() - start_time
        num_batches += 1
    print("Average per token generation time: ", total_time / num_batches)