# Spaces: Running  (web-page header captured with the script; not part of the program)
import argparse
import time

import datasets
import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model, TaskType, get_peft_config
from peft import PeftModel, PeftConfig
from peft import prepare_model_for_int8_training
from torch.utils.data import DataLoader
from transformers import LlamaForCausalLM
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoConfig
from transformers import default_data_collator
#### DO NOT EDIT ######
# Fixed benchmark inputs. The values and calls below are kept exactly as the
# author wrote them; only stray trailing characters were removed so the
# statements parse.
generation_length = 1  # added to context_length to form generate()'s max_length
context_length = 128  # tokenizer truncation limit and required example length
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
# No dtype argument is passed here; the loaded model is moved to the GPU as-is.
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf").to("cuda")
eval_dataset = load_dataset("wikitext", 'wikitext-103-v1', split="test")
# tokenize the dataset and filter out examples that are shorter than the context length
def tokenize_and_filter_function(examples):
    """Tokenize a batch of texts and keep only the full-length examples.

    Relies on the module-level ``tokenizer`` and ``context_length``.

    Args:
        examples: Batch mapping whose ``"text"`` entry is a list of strings
            (the only key read here).

    Returns:
        A dict with ``"input_ids"`` and ``"attention_mask"`` lists containing
        only the examples whose tokenized length equals ``context_length``.
        The output may therefore hold fewer rows than the input batch.
    """
    tokenized_examples = tokenizer(examples["text"], truncation=True, max_length=context_length)
    # Truncation caps every example at context_length tokens, so an exact
    # length match keeps the texts that were long enough to fill the whole
    # context and drops the shorter ones.
    result = {
        "input_ids": [],
        "attention_mask": [],
    }
    for input_ids, attention_mask in zip(
        tokenized_examples["input_ids"], tokenized_examples["attention_mask"]
    ):
        if len(input_ids) == context_length:
            result["input_ids"].append(input_ids)
            result["attention_mask"].append(attention_mask)
    return result
# Run the tokenize/filter step over the dataset in batches using 4 worker
# processes; the raw "text" column is removed from the result.
eval_dataset = eval_dataset.map(tokenize_and_filter_function, batched=True, num_proc=4, remove_columns=["text"])
#################
batch_size = 4
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size)

# Upper bound on the number of batches that are timed.
max_steps = 100

model.eval()
total_time = 0.0  # accumulated wall-clock seconds spent inside generate()
steps_run = 0  # batches actually timed; may be below max_steps on a short dataset
with torch.no_grad():
    for idx, batch in enumerate(eval_dataloader):
        if idx == max_steps:
            break
        # Host-to-device copies happen before the timer starts so that only
        # generation itself is measured.
        input_ids = batch["input_ids"].to("cuda")
        attention_mask = batch["attention_mask"].to("cuda")
        # CUDA kernels are launched asynchronously: drain pending work before
        # starting the clock and wait for generate() to finish before stopping it.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=generation_length + context_length,
        )
        torch.cuda.synchronize()
        total_time += time.perf_counter() - start_time
        steps_run += 1

if steps_run == 0:
    raise RuntimeError("No evaluation batches were produced; nothing to time.")

# Average over the steps that really ran rather than a hard-coded 100.
# With generation_length == 1 each step requests one new token per sequence,
# so the per-step average is also the per-token average.
print("Average per token generation time: ", total_time / steps_run)