Jae-Won Chung
Updated diffusion benchmark and data
c97bae1
raw
history blame
14.2 kB
from __future__ import annotations
import os
import time
import json
import argparse
import multiprocessing as mp
from pprint import pprint
from pathlib import Path
from contextlib import suppress
from dataclasses import dataclass, field, asdict
import torch
import pynvml
import numpy as np
import pandas as pd
from PIL import Image
from datasets import load_dataset, Dataset
from transformers.trainer_utils import set_seed
from transformers import CLIPModel, CLIPProcessor
from diffusers import (
ModelMixin, # type: ignore
AutoPipelineForText2Image, # type: ignore
DiffusionPipeline, # type: ignore
StableCascadeCombinedPipeline, # type: ignore
)
from zeus.monitor import ZeusMonitor
# This script only runs inference, so autograd is switched off for the whole process.
torch.set_grad_enabled(mode=False)
# HuggingFace Hub ID of the CLIP model used to score generated images against prompts.
CLIP = "openai/clip-vit-large-patch14"
@dataclass
class Results:
    """Summary of one benchmark run together with its per-prompt records.

    The first seven fields describe the run configuration and have no default;
    the aggregate metrics default to zero.
    """

    # Run configuration.
    model: str
    num_parameters: dict[str, int]  # parameter count keyed by pipeline component name
    gpu_model: str
    power_limit: int
    batch_size: int
    num_inference_steps: int
    num_prompts: int
    # Aggregate metrics.
    average_clip_score: float = field(default=0.0)
    total_runtime: float = field(default=0.0)
    total_energy: float = field(default=0.0)
    average_batch_latency: float = field(default=0.0)
    average_images_per_second: float = field(default=0.0)
    average_batch_energy: float = field(default=0.0)
    average_power_consumption: float = field(default=0.0)
    peak_memory: float = field(default=0.0)
    # Per-prompt records; left out of repr so printed summaries stay short.
    results: list[Result] = field(repr=False, default_factory=list)
@dataclass
class ResultIntermediateBatched:
    """Measurements and generated images for one batch of prompts."""

    batch_latency: float = 0.0
    batch_energy: float = 0.0
    prompts: list[str] = field(default_factory=list)
    # A fresh empty array per instance. A bare `np.empty(0)` default would be one
    # object shared by every instance, and Python 3.11+ dataclasses reject it at
    # class creation because ndarrays are unhashable (treated as mutable defaults).
    images: np.ndarray = field(default_factory=lambda: np.empty(0))
@dataclass
class Result:
    """Record for a single prompt: its image, CLIP score, and the batch's cost."""

    batch_latency: float  # latency of the whole batch this prompt was generated in
    sample_energy: float  # batch energy divided evenly across the prompts in the batch
    prompt: str
    image_path: str | None  # None when this image was not written to disk
    clip_score: float
def get_pipeline(model_id: str):
    """Instantiate a Diffusers pipeline from a model's HuggingFace Hub ID.

    Extra `from_pretrained` arguments are read from `models/{model_id}/kwargs.json`,
    and the revision to load is read from `models/{model_id}/revision.txt`. The
    safety checker is always disabled.

    Args:
        model_id: HuggingFace Hub ID of the model.

    Returns:
        The instantiated pipeline, moved to `cuda:0`.
    """
    model_dir = Path("models") / model_id

    # Load args to give to `from_pretrained` from the model's kwargs.json file.
    with open(model_dir / "kwargs.json") as f:
        kwargs = json.load(f)
    with suppress(KeyError):
        # The dtype is stored as a string expression (presumably e.g. "torch.float16").
        # NOTE(review): `eval` executes arbitrary code from kwargs.json; this is only
        # acceptable because the model configs are repository-controlled.
        kwargs["torch_dtype"] = eval(kwargs["torch_dtype"])

    # Add additional args.
    kwargs["safety_checker"] = None
    with open(model_dir / "revision.txt") as f:
        kwargs["revision"] = f.read().strip()

    # Hack for stable-cascade, which defaults to only a part of the model.
    if model_id == "stabilityai/stable-cascade":
        pipeline = StableCascadeCombinedPipeline.from_pretrained(model_id, **kwargs).to("cuda:0")
        print("\nInstantiated pipeline via StableCascadeCombinedPipeline:\n", pipeline)
    else:
        try:
            pipeline = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to("cuda:0")
            print("\nInstantiated pipeline via AutoPipelineForText2Image:\n", pipeline)
        except ValueError:
            # AutoPipelineForText2Image raised (presumably the model's pipeline class
            # is not supported by it); fall back to the generic loader.
            pipeline = DiffusionPipeline.from_pretrained(model_id, **kwargs).to("cuda:0")
            print("\nInstantiated pipeline via DiffusionPipeline:\n", pipeline)
    return pipeline
def load_partiprompts(
    batch_size: int,
    seed: int,
    num_batches: int | None = None,
) -> tuple[int, list[list[str]]]:
    """Load the parti-prompts dataset and return it as a list of batches of prompts.

    The dataset is shuffled with `seed`. Depending on the batch size, the final
    batch may not be full; it is dropped in that case. If `num_batches` is not
    None, at most that many batches are returned. If `num_batches` is None, all
    full batches are returned.

    Args:
        batch_size: Number of prompts per batch.
        seed: Seed used to shuffle the dataset.
        num_batches: Optional cap on the number of batches.

    Returns:
        Total number of prompts and a list of batches of prompts. The list is
        empty when not even one full batch can be formed.
    """
    dataset = load_dataset("nateraw/parti-prompts", split="train").shuffle(seed=seed)
    assert isinstance(dataset, Dataset)
    if num_batches is not None:
        dataset = dataset.select(range(min(num_batches * batch_size, len(dataset))))
    # `list(...)` guarantees a plain list of strings; column access in some
    # `datasets` releases may return a lazy column object instead.
    prompts: list[str] = list(dataset["Prompt"])
    batched = [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]
    # Drop a trailing partial batch; `batched` may be empty (e.g. num_batches=0).
    if batched and len(batched[-1]) < batch_size:
        batched.pop()
    return len(batched) * batch_size, batched
def power_monitor(csv_path: str, gpu_indices: list[int], chan: mp.SimpleQueue) -> None:
    """Sample GPU and memory power of the given GPUs until `chan` receives a message.

    Meant to run in its own process. Roughly every 0.1 s, one row is recorded:
    a `time.monotonic()` timestamp followed by the `NVML_FI_DEV_POWER_AVERAGE`
    readings for the GPU scope and the memory scope of each GPU. Once anything is
    put on `chan`, sampling stops and all rows are written to `csv_path`.

    Args:
        csv_path: Destination of the CSV file (columns: timestamp, gpu<i>, vram<i>, ...).
        gpu_indices: NVML indices of the GPUs to sample.
        chan: Stop signal; sampling continues while it is empty.
    """
    fields = [
        (pynvml.NVML_FI_DEV_POWER_AVERAGE, pynvml.NVML_POWER_SCOPE_GPU),
        (pynvml.NVML_FI_DEV_POWER_AVERAGE, pynvml.NVML_POWER_SCOPE_MEMORY),
    ]
    columns = ["timestamp"] + [name for i in gpu_indices for name in (f"gpu{i}", f"vram{i}")]
    power: list[list] = []

    pynvml.nvmlInit()
    try:
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in gpu_indices]
        while chan.empty():
            row = [time.monotonic()]
            values = [pynvml.nvmlDeviceGetFieldValues(h, fields) for h in handles]
            for value in values:
                row.extend((value[0].value.uiVal, value[1].value.uiVal))
            power.append(row)
            # Sleep for whatever remains of the 0.1 s sampling period.
            time.sleep(max(0.0, 0.1 - (time.monotonic() - row[0])))
    finally:
        # Release NVML even if sampling fails.
        pynvml.nvmlShutdown()

    pd.DataFrame(power, columns=columns).to_csv(csv_path, index=False)
def calculate_clip_score(
    model: CLIPModel,
    processor: CLIPProcessor,
    images_np: np.ndarray,
    text: list[str],
) -> torch.Tensor:
    """Calculate the CLIP score for each image and prompt pair.

    The score of a pair is 100 times the dot product of the L2-normalized CLIP
    image and text features, clamped below at zero.

    `images_np` is assumed to be already scaled to [0, 255] and in uint8 format,
    laid out as (batch, height, width, channels).

    Args:
        model: CLIP model; it is moved to `cuda:0`.
        processor: Matching CLIP processor (tokenizer + image preprocessing).
        images_np: Batch of images, one per prompt.
        text: Prompts, one per image.

    Returns:
        A tensor of shape (batch size,) on `cuda:0` with one score per pair.
    """
    model = model.to("cuda:0")
    # The processor is given a list of channels-first image tensors.
    images = list(torch.from_numpy(images_np).permute(0, 3, 1, 2))
    assert len(images) == len(text)

    # Truncate over-long prompts inside the tokenizer instead of slicing the token
    # ids afterwards. Slicing drops the end-of-text token, which (per the
    # transformers CLIP implementation) marks the position the text embedding is
    # pooled from.
    max_position_embeddings = model.config.text_config.max_position_embeddings
    processed_input = processor(
        text=text,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_position_embeddings,
    )

    img_features = model.get_image_features(processed_input["pixel_values"].to("cuda:0"))
    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)

    txt_features = model.get_text_features(
        processed_input["input_ids"].to("cuda:0"), processed_input["attention_mask"].to("cuda:0")
    )
    txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)

    scores = 100 * (img_features * txt_features).sum(axis=-1)
    # Negative scores are clamped to zero.
    scores = torch.max(scores, torch.zeros_like(scores))
    return scores
def count_parameters(pipeline) -> dict[str, int]:
    """Return the parameter count of every model component attached to `pipeline`.

    Diffusers models (`ModelMixin`) report their own count with embeddings
    excluded; any other `torch.nn.Module` counts all of its parameters. Other
    attributes are skipped.
    """
    counts: dict[str, int] = {}
    for component_name, component in vars(pipeline).items():
        if isinstance(component, ModelMixin):
            total = component.num_parameters(only_trainable=False, exclude_embeddings=True)
        elif isinstance(component, torch.nn.Module):
            total = sum(param.numel() for param in component.parameters())
        else:
            continue
        counts[component_name] = total
    return counts
def _set_power_limit(power_limit: int) -> str:
    """Enable persistence mode and set the power limit of GPU 0; return its name.

    Args:
        power_limit: Power limit in Watts (converted to milliwatts for NVML).

    Returns:
        The name of GPU 0 as reported by NVML, always as `str`.
    """
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        gpu_model = pynvml.nvmlDeviceGetName(handle)
        pynvml.nvmlDeviceSetPersistenceMode(handle, pynvml.NVML_FEATURE_ENABLED)
        pynvml.nvmlDeviceSetPowerManagementLimit(handle, power_limit * 1000)
    finally:
        pynvml.nvmlShutdown()
    # Some pynvml releases return the name as bytes, which `json.dumps` rejects.
    if isinstance(gpu_model, bytes):
        gpu_model = gpu_model.decode()
    return gpu_model


def _image_filename(prompt: str) -> str:
    """Turn a prompt into a PNG file name that can be created inside one directory."""
    # Path separators would otherwise point into non-existent subdirectories.
    stem = prompt.replace("/", "_").replace(os.sep, "_")
    # Common filesystems (e.g. ext4) limit a file name to 255 bytes; keep room for ".png".
    max_bytes = 255 - len(".png")
    encoded = stem.encode("utf-8")
    if len(encoded) > max_bytes:
        stem = encoded[:max_bytes].decode("utf-8", errors="ignore")
    return f"{stem}.png"


def benchmark(args: argparse.Namespace) -> None:
    """Run the text-to-image energy benchmark described by `args`.

    Steps:
    - Set the power limit of GPU 0.
    - Generate images for batches of PartiPrompts while measuring time and energy
      with Zeus (and optionally per-0.1 s power in a child process).
    - Score every image/prompt pair with CLIP.
    - Save every `image_save_every`-th image.
    - Write `+args.json` and `+results.json` under `result_root/model`.

    Raises:
        ValueError: If not a single full batch of prompts can be formed.
    """
    # `os.environ` values must be strings, so only export a token that was given.
    if args.huggingface_token is not None:
        os.environ["HF_TOKEN"] = args.huggingface_token

    # Accept `models/<org>/<name>/` as well as the bare Hub ID.
    if args.model.startswith("models/"):
        args.model = args.model[len("models/") :]
    if args.model.endswith("/"):
        args.model = args.model[:-1]

    set_seed(args.seed)

    results_dir = Path(args.result_root) / args.model
    results_dir.mkdir(parents=True, exist_ok=True)
    run_name = f"bs{args.batch_size}+pl{args.power_limit}+steps{args.num_inference_steps}"
    benchmark_name = str(results_dir / run_name)
    image_dir = results_dir / f"{run_name}+generated"
    image_dir.mkdir(exist_ok=True)

    # Record the arguments, but never write the HuggingFace token to disk or stdout.
    logged_args = dict(vars(args))
    if logged_args.get("huggingface_token") is not None:
        logged_args["huggingface_token"] = "<redacted>"
    arg_out_filename = f"{benchmark_name}+args.json"
    with open(arg_out_filename, "w") as f:
        f.write(json.dumps(logged_args, indent=2))
    print(argparse.Namespace(**logged_args))
    print("Benchmark args written to", arg_out_filename)

    zeus_monitor = ZeusMonitor()
    gpu_model = _set_power_limit(args.power_limit)

    num_prompts, batched_prompts = load_partiprompts(args.batch_size, args.seed, args.num_batches)
    if not batched_prompts:
        raise ValueError(
            f"Not enough prompts to form a single batch of size {args.batch_size} "
            f"(num_batches={args.num_batches})."
        )
    pipeline = get_pipeline(args.model)

    # Warmup on up to five batches; slicing tolerates runs with fewer batches.
    warmup_batches = batched_prompts[:5]
    print(f"Warming up with {len(warmup_batches)} batches...")
    for warmup_prompts in warmup_batches:
        _ = pipeline(
            warmup_prompts,
            num_inference_steps=args.num_inference_steps,
            output_type="np",
        )

    rng = torch.manual_seed(args.seed)
    intermediates: list[ResultIntermediateBatched] = [
        ResultIntermediateBatched(prompts=batch) for batch in batched_prompts
    ]

    pmon = None
    pmon_chan = None
    if args.monitor_power:
        # Create the queue from the same "spawn" context as the process rather than
        # mixing the default start method's primitives with a spawned child.
        ctx = mp.get_context("spawn")
        pmon_chan = ctx.SimpleQueue()
        pmon = ctx.Process(
            target=power_monitor,
            args=(f"{benchmark_name}+power.csv", [g.gpu_index for g in zeus_monitor.gpus.gpus], pmon_chan),
        )
        pmon.start()

    try:
        torch.cuda.reset_peak_memory_stats(device="cuda:0")
        zeus_monitor.begin_window("benchmark", sync_execution=False)
        for ind, intermediate in enumerate(intermediates):
            print(f"Batch {ind + 1}/{len(intermediates)}")
            zeus_monitor.begin_window("batch", sync_execution=False)
            images = pipeline(
                intermediate.prompts,
                generator=rng,
                num_inference_steps=args.num_inference_steps,
                output_type="np",
            ).images
            batch_measurements = zeus_monitor.end_window("batch", sync_execution=False)
            intermediate.images = images
            intermediate.batch_latency = batch_measurements.time
            intermediate.batch_energy = batch_measurements.total_energy
        measurements = zeus_monitor.end_window("benchmark", sync_execution=False)
        peak_memory = torch.cuda.max_memory_allocated(device="cuda:0")
    finally:
        # Stop the monitor even when generation fails. It samples until it receives
        # a message, and a non-daemonic child left running keeps the parent alive.
        if pmon is not None and pmon_chan is not None:
            pmon_chan.put("stop")
            pmon.join(timeout=5.0)
            pmon.terminate()

    # Scale images to [0, 255] and convert to uint8 (assumes the pipeline's "np"
    # output lies in [0, 1]).
    for intermediate in intermediates:
        intermediate.images = (intermediate.images * 255).astype("uint8")

    # Compute the CLIP score for each image and prompt pair.
    # Code was mostly inspired from torchmetrics.multimodal.clip_score, but
    # adapted here to calculate the CLIP score for each image and prompt pair.
    clip_model: CLIPModel = CLIPModel.from_pretrained(CLIP).cuda()  # type: ignore
    clip_processor: CLIPProcessor = CLIPProcessor.from_pretrained(CLIP)  # type: ignore
    batch_clip_scores = [
        calculate_clip_score(
            clip_model,
            clip_processor,
            intermediate.images,
            intermediate.prompts,
        ).tolist()
        for intermediate in intermediates
    ]

    results: list[Result] = []
    ind = 0
    for intermediate, batch_clip_score in zip(intermediates, batch_clip_scores, strict=True):
        for image, prompt, clip_score in zip(
            intermediate.images,
            intermediate.prompts,
            batch_clip_score,
            strict=True,
        ):
            # `--image-save-every 0` disables saving instead of dividing by zero.
            if args.image_save_every and ind % args.image_save_every == 0:
                image_path = str(image_dir / _image_filename(prompt))
                Image.fromarray(image).save(image_path)
            else:
                image_path = None
            results.append(
                Result(
                    batch_latency=intermediate.batch_latency,
                    sample_energy=intermediate.batch_energy / len(intermediate.prompts),
                    prompt=prompt,
                    image_path=image_path,
                    clip_score=clip_score,
                )
            )
            ind += 1

    final_results = Results(
        model=args.model,
        num_parameters=count_parameters(pipeline),
        gpu_model=gpu_model,
        power_limit=args.power_limit,
        batch_size=args.batch_size,
        num_inference_steps=args.num_inference_steps,
        num_prompts=num_prompts,
        average_clip_score=sum(r.clip_score for r in results) / len(results),
        total_runtime=measurements.time,
        total_energy=measurements.total_energy,
        average_batch_latency=measurements.time / len(batched_prompts),
        average_images_per_second=num_prompts / measurements.time,
        average_batch_energy=measurements.total_energy / len(batched_prompts),
        average_power_consumption=measurements.total_energy / measurements.time,
        peak_memory=peak_memory,
        results=results,
    )

    with open(f"{benchmark_name}+results.json", "w") as f:
        f.write(json.dumps(asdict(final_results), indent=2))
    print("Benchmark results written to", f"{benchmark_name}+results.json")

    print("Benchmark results:")
    pprint(final_results)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="The model to benchmark.")
    # Required: every output path is built from this directory, so running without
    # it previously failed only later, inside `Path(None)`.
    parser.add_argument("--result-root", type=str, required=True, help="The root directory to save results to.")
    parser.add_argument("--batch-size", type=int, default=1, help="The size of each batch of prompts.")
    parser.add_argument("--power-limit", type=int, default=300, help="The power limit to set for the GPU in Watts.")
    parser.add_argument("--num-inference-steps", type=int, default=50, help="The number of denoising steps.")
    parser.add_argument("--num-batches", type=int, default=None, help="The number of batches to use from the dataset.")
    parser.add_argument("--image-save-every", type=int, default=10, help="Save images to file every N prompts.")
    parser.add_argument("--seed", type=int, default=0, help="The seed to use for the RNG.")
    parser.add_argument("--huggingface-token", type=str, help="The HuggingFace token to use.")
    parser.add_argument("--monitor-power", default=False, action="store_true", help="Whether to monitor power over time.")
    args = parser.parse_args()
    benchmark(args)