leaderboard / benchmark.py
Jae-Won Chung
Do not print input and output with rich
0d43f69
raw
history blame
8.5 kB
"""Perform inference of one model on one input prompt and measure time and energy."""
from __future__ import annotations
import os
import json
import copy
import atexit
from typing import Generator, Literal
import tyro
import torch
import rich
from rich.table import Table
from fastchat.serve.inference import generate_stream
from fastchat.model.model_adapter import load_model, get_conversation_template
from zeus.monitor import ZeusMonitor
SYSTEM_PROMPTS = {
"chat": (
"A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
"chat-concise": (
"A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions. "
"The assistnat's answers are concise but high-quality."
),
"instruct": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
"instruct-concise": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
"The response should be concise but high-quality."
),
}
def main(
model_path: str,
input_file: str,
device_index: int = 0,
task: Literal[tuple(SYSTEM_PROMPTS)] = "chat", # type: ignore
load_8bit: bool = False,
temperature: float = 0.7,
repitition_penalty: float = 1.0,
max_new_tokens: int = 512,
) -> None:
"""Run the main routine.
Code structure is based on
https://github.com/lm-sys/FastChat/blob/57dea54055/fastchat/serve/inference.py#L249
Args:
model_path: Path to or Huggingface Hub Id of the model.
input_file: Path to the input JSON file. Assumed to be our cleaned ShareGPT data.
device_index: Index of the GPU to use for inference.
task: Type of task to perform inference on.
load_8bit: Whether to load the model in 8-bit mode.
temperature: Temperature to use for sampling.
repitition_penalty: Repitition penalty to use for the model.
max_new_tokens: Maximum numbers of tokens to generate, ignoring the prompt.
"""
# NOTE(JW): ChatGLM is implemented as a special case in FastChat inference.
# Also, it's primarily a model that's fine-tuned for Chinese, so it doesn't
# make sense to prompt it in English and talk about its verbosity.
if "chatglm" in model_path.lower():
raise ValueError("ChatGLM is not supported.")
# Print out what we're about to do.
if model_path.endswith("/"):
model_path = model_path[:-1]
model_name_cleaned = "--".join(model_path.split("/")[-2:])
output_dir = f"data/{task}/{model_name_cleaned}"
output_csv_path = f"{output_dir}/benchmark.json"
config_json_path = f"{output_dir}/config.json"
table = Table(title="Benchmark")
table.add_column("Configuration")
table.add_column("Value")
table.add_row("Model", f"{model_name_cleaned} (path: {model_path})")
table.add_row("Input", input_file)
table.add_row("Device", f"cuda:{device_index}")
table.add_row("Task", task)
table.add_row("8-bit", str(load_8bit))
table.add_row("Temperature", f"{temperature:.2f}")
table.add_row("Repitition Penalty", f"{repitition_penalty:.2f}")
table.add_row("Max New Tokens", str(max_new_tokens))
table.add_row("Output CSV", output_csv_path)
table.add_row("Config JSON", config_json_path)
rich.get_console().print(table)
# Set the device.
torch.cuda.set_device(f"cuda:{device_index}")
# Load the model (Huggingface PyTorch) and tokenizer (Huggingface).
model, tokenizer = load_model(
model_path=model_path,
device="cuda",
num_gpus=1,
max_gpu_memory=None,
load_8bit=load_8bit,
cpu_offloading=False,
gptq_config=None,
debug=False,
)
# Chats are accumulated in a conversation helper object.
conv_base = get_conversation_template(model_path)
# Standardize the system prompt for every model.
conv_base.system = SYSTEM_PROMPTS[task]
conv_base.messages = []
conv_base.offset = 0
gen_params = {
"model": model_path,
"prompt": "EMPTY",
"temperature": temperature,
"repitition_penalty": repitition_penalty,
"max_new_tokens": max_new_tokens,
"stop": conv_base.stop_str,
"stop_token_ids": conv_base.stop_token_ids,
"echo": False,
}
monitor = ZeusMonitor(gpu_indices=[torch.cuda.current_device()])
# Output files.
# Leave only the last two path components and replace slashes with double dashes.
os.makedirs(output_dir, exist_ok=True)
output_json = open(output_csv_path, "w")
output_json.write("[\n")
output_json.flush()
# Conclude the JSON file format with a closing bracket. Using `atexit` will
# handle all cases of the program exiting, including Ctrl-C and errors.
atexit.register(lambda: output_json.write("\n]\n"))
# Dump the configuration to a JSON file.
with open(config_json_path, "w") as config_json:
json.dump(
{
"model_path": model_path,
"input_file": input_file,
"device_index": device_index,
"task": task,
"load_8bit": load_8bit,
"temperature": temperature,
"repitition_penalty": repitition_penalty,
"max_new_tokens": max_new_tokens,
},
config_json,
indent=4,
)
config_json.write("\n")
def dataloader(input_file: str) -> Generator[tuple[bool, str], None, None]:
"""Yields a tuple of whether this is a warmup run and the input prompt."""
for _ in range(3):
yield True, "Say something long and random. I don't care about the content."
for item in json.load(open(input_file, "r")):
input_prompt = item["conversations"][0]["value"]
yield False, input_prompt
# Warm up the GPU with some random prompts.
# Forward through all the prompts.
is_first = True
for is_warmup, input_prompt in dataloader(input_file):
# Construct the input prompt.
conv = copy.deepcopy(conv_base)
conv.append_message(conv.roles[0], input_prompt)
conv.append_message(conv.roles[1], "")
prompt = conv.get_prompt()
gen_params["prompt"] = prompt
# Print input prompt.
rich.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Prompt[/u cyan]:")
print(prompt.strip() + "\n")
# Generate the ouptut from the model.
output_stream = generate_stream(model, tokenizer, gen_params, device="cuda")
output = {}
#################################################
# Inference and measurement zone!
#################################################
monitor.begin_window("inference")
for output in output_stream:
pass
measurements = monitor.end_window("inference")
#################################################
# Record numbers.
output_text = output["text"]
if not is_warmup:
response_length = len(tokenizer.encode(output_text)) # number of tokens
latency = measurements.time
throughput = response_length / latency
energy = measurements.total_energy
output = {
"model": model_name_cleaned,
"throughput": throughput,
"response_length": response_length,
"latency": latency,
"energy": energy,
"input": prompt.strip(),
"output": output_text.strip(),
}
output_str = json.dumps(output, indent=4)
if not is_warmup:
if not is_first:
output_json.write(",\n" + output_str)
else:
is_first = False
output_json.write(output_str)
output_json.flush()
# Print the response.
rich.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Response[/u cyan]:")
print(output_text.strip() + "\n")
# Print measurement.
rich.print(measurements)
if __name__ == "__main__":
tyro.cli(main)