Spaces:
Runtime error
Runtime error
| # ---- INT8 (optional) ---- | |
| from demo_utils.vae import ( | |
| VAEDecoderWrapperSingle, # main nn.Module | |
| ZERO_VAE_CACHE # helper constants shipped with your code base | |
| ) | |
| import pycuda.driver as cuda # β add | |
| import pycuda.autoinit # noqa | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| import tensorrt as trt | |
| from utils.dataset import ShardingLMDBDataset | |
| data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard" | |
| dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8)) | |
| dataloader = torch.utils.data.DataLoader( | |
| dataset, | |
| batch_size=1, | |
| num_workers=0 | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 1οΈβ£ Bring the PyTorch model into scope | |
| # (all code you pasted lives in `vae_decoder.py`) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # --- dummy tensors (exact shapes you posted) --- | |
| dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda() | |
| is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16) | |
| dummy_cache_input = [ | |
| torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s | |
| for s in ZERO_VAE_CACHE # keep exactly the same ordering | |
| ] | |
| inputs = [dummy_input, is_first_frame, *dummy_cache_input] | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 2οΈβ£ Export β ONNX | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| model = VAEDecoderWrapperSingle().half().cuda().eval() | |
| vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu") | |
| decoder_state_dict = {} | |
| for key, value in vae_state_dict.items(): | |
| if 'decoder.' in key or 'conv2' in key: | |
| decoder_state_dict[key] = value | |
| model.load_state_dict(decoder_state_dict) | |
| model = model.half().cuda().eval() # only batch dim dynamic | |
| onnx_path = Path("vae_decoder.onnx") | |
| feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))] | |
| all_inputs_names = ["z", "use_cache"] + feat_names | |
| with torch.inference_mode(): | |
| torch.onnx.export( | |
| model, | |
| tuple(inputs), # must be a tuple | |
| onnx_path.as_posix(), | |
| input_names=all_inputs_names, | |
| output_names=["rgb_out", "cache_out"], | |
| opset_version=17, | |
| do_constant_folding=True, | |
| dynamo=True | |
| ) | |
| print(f"β ONNX graph saved to {onnx_path.resolve()}") | |
| # (Optional) quick sanity-check with ONNX-Runtime | |
| try: | |
| import onnxruntime as ort | |
| sess = ort.InferenceSession(onnx_path.as_posix(), | |
| providers=["CUDAExecutionProvider"]) | |
| ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)} | |
| _ = sess.run(None, ort_inputs) | |
| print("β ONNX graph is executable") | |
| except Exception as e: | |
| print("β οΈ ONNX check failed:", e) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 3οΈβ£ Build the TensorRT engine | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| TRT_LOGGER = trt.Logger(trt.Logger.WARNING) | |
| builder = trt.Builder(TRT_LOGGER) | |
| network = builder.create_network( | |
| 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) | |
| parser = trt.OnnxParser(network, TRT_LOGGER) | |
| with open(onnx_path, "rb") as f: | |
| if not parser.parse(f.read()): | |
| for i in range(parser.num_errors): | |
| print(parser.get_error(i)) | |
| sys.exit("β ONNX β TRT parsing failed") | |
| config = builder.create_builder_config() | |
| def set_workspace(config, bytes_): | |
| """Version-agnostic workspace limit.""" | |
| if hasattr(config, "max_workspace_size"): # TRT 8 / 9 | |
| config.max_workspace_size = bytes_ | |
| else: # TRT 10+ | |
| config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, bytes_) | |
| # β¦ | |
| config = builder.create_builder_config() | |
| set_workspace(config, 4 << 30) # 4 GB | |
| # 4 GB | |
| if builder.platform_has_fast_fp16: | |
| config.set_flag(trt.BuilderFlag.FP16) | |
| # ---- INT8 (optional) ---- | |
| # provide a calibrator if you need an INT8 engine; comment this | |
| # block if you only care about FP16. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # helper: version-agnostic workspace limit | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30): | |
| """ | |
| TRT < 10.x β config.max_workspace_size | |
| TRT β₯ 10.x β config.set_memory_pool_limit(...) | |
| """ | |
| if hasattr(config, "max_workspace_size"): # TRT 8 / 9 | |
| config.max_workspace_size = bytes_ | |
| else: # TRT 10+ | |
| config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, | |
| bytes_) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # (optional) INT-8 calibrator | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # βΌ Only keep this block if you really need INT-8 βΌ # gracefully skip if PyCUDA not present | |
| class VAECalibrator(trt.IInt8EntropyCalibrator2): | |
| def __init__(self, loader, cache="calibration.cache", max_batches=10): | |
| super().__init__() | |
| self.loader = iter(loader) | |
| self.batch_size = loader.batch_size or 1 | |
| self.max_batches = max_batches | |
| self.count = 0 | |
| self.cache_file = cache | |
| self.stream = cuda.Stream() | |
| self.dev_ptrs = {} | |
| # --- TRT 10 needs BOTH spellings --- | |
| def get_batch_size(self): | |
| return self.batch_size | |
| def getBatchSize(self): | |
| return self.batch_size | |
| def get_batch(self, names): | |
| if self.count >= self.max_batches: | |
| return None | |
| # Randomly sample a number from 1 to 10 | |
| import random | |
| vae_idx = random.randint(0, 10) | |
| data = next(self.loader) | |
| latent = data['ode_latent'][0][:, :1] | |
| is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16) | |
| feat_cache = ZERO_VAE_CACHE | |
| for i in range(vae_idx): | |
| inputs = [latent, is_first_frame, *feat_cache] | |
| with torch.inference_mode(): | |
| outputs = model(*inputs) | |
| latent = data['ode_latent'][0][:, i + 1:i + 2] | |
| is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16) | |
| feat_cache = outputs[1:] | |
| # -------- ensure context is current -------- | |
| z_np = latent.cpu().numpy().astype('float32') | |
| ptrs = [] # list[int] β one entry per name | |
| for name in names: # <-- match TRT's binding order | |
| if name == "z": | |
| arr = z_np | |
| elif name == "use_cache": | |
| arr = is_first_frame.cpu().numpy().astype('float32') | |
| else: | |
| idx = int(name.split('_')[-1]) # "vae_cache_17" -> 17 | |
| arr = feat_cache[idx].cpu().numpy().astype('float32') | |
| if name not in self.dev_ptrs: | |
| self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes) | |
| cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream) | |
| ptrs.append(int(self.dev_ptrs[name])) # ***int() is required*** | |
| self.stream.synchronize() | |
| self.count += 1 | |
| print(f"Calibration batch {self.count}/{self.max_batches}") | |
| return ptrs | |
| # --- calibration-cache helpers (both spellings) --- | |
| def read_calibration_cache(self): | |
| try: | |
| with open(self.cache_file, "rb") as f: | |
| return f.read() | |
| except FileNotFoundError: | |
| return None | |
| def readCalibrationCache(self): | |
| return self.read_calibration_cache() | |
| def write_calibration_cache(self, cache): | |
| with open(self.cache_file, "wb") as f: | |
| f.write(cache) | |
| def writeCalibrationCache(self, cache): | |
| self.write_calibration_cache(cache) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Builder-config + optimisation profile | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| config = builder.create_builder_config() | |
| set_workspace(config, 4 << 30) # 4 GB | |
| # βΊ enable FP16 if possible | |
| if builder.platform_has_fast_fp16: | |
| config.set_flag(trt.BuilderFlag.FP16) | |
| # βΊ enable INT-8 (delete this block if you donβt need it) | |
| if cuda is not None: | |
| config.set_flag(trt.BuilderFlag.INT8) | |
| # supply any representative batch you like β here we reuse the latent z | |
| calib = VAECalibrator(dataloader) | |
| # TRT-10 renamed the setter: | |
| if hasattr(config, "set_int8_calibrator"): # TRT 10+ | |
| config.set_int8_calibrator(calib) | |
| else: # TRT β€ 9 | |
| config.int8_calibrator = calib | |
| # ---- optimisation profile ---- | |
| profile = builder.create_optimization_profile() | |
| profile.set_shape(all_inputs_names[0], # latent z | |
| min=(1, 1, 16, 60, 104), | |
| opt=(1, 1, 16, 60, 104), | |
| max=(1, 1, 16, 60, 104)) | |
| profile.set_shape("use_cache", # scalar flag | |
| min=(1,), opt=(1,), max=(1,)) | |
| for name, tensor in zip(all_inputs_names[2:], dummy_cache_input): | |
| profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape) | |
| config.add_optimization_profile(profile) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Build the engine (API changed in TRT-10) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("βοΈ Building engine β¦ (can take a minute)") | |
| if hasattr(builder, "build_serialized_network"): # TRT 10+ | |
| serialized_engine = builder.build_serialized_network(network, config) | |
| assert serialized_engine is not None, "build_serialized_network() failed" | |
| plan_path = Path("checkpoints/vae_decoder_int8.trt") | |
| plan_path.write_bytes(serialized_engine) | |
| engine_bytes = serialized_engine # keep for smoke-test | |
| else: # TRT β€ 9 | |
| engine = builder.build_engine(network, config) | |
| assert engine is not None, "build_engine() returned None" | |
| plan_path = Path("checkpoints/vae_decoder_int8.trt") | |
| plan_path.write_bytes(engine.serialize()) | |
| engine_bytes = engine.serialize() | |
| print(f"β TensorRT engine written to {plan_path.resolve()}") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 4οΈβ£ Quick smoke-test with the brand-new engine | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with trt.Runtime(TRT_LOGGER) as rt: | |
| engine = rt.deserialize_cuda_engine(engine_bytes) | |
| context = engine.create_execution_context() | |
| stream = torch.cuda.current_stream().cuda_stream | |
| # pre-allocate device buffers once | |
| device_buffers, outputs = {}, [] | |
| dtype_map = {trt.float32: torch.float32, | |
| trt.float16: torch.float16, | |
| trt.int8: torch.int8, | |
| trt.int32: torch.int32} | |
| for name, tensor in zip(all_inputs_names, inputs): | |
| if -1 in engine.get_tensor_shape(name): # dynamic input | |
| context.set_input_shape(name, tensor.shape) | |
| context.set_tensor_address(name, int(tensor.data_ptr())) | |
| device_buffers[name] = tensor | |
| context.infer_shapes() # propagate β’ outputs | |
| for i in range(engine.num_io_tensors): | |
| name = engine.get_tensor_name(i) | |
| if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT: | |
| shape = tuple(context.get_tensor_shape(name)) | |
| dtype = dtype_map[engine.get_tensor_dtype(name)] | |
| out = torch.empty(shape, dtype=dtype, device="cuda") | |
| context.set_tensor_address(name, int(out.data_ptr())) | |
| outputs.append(out) | |
| print(f"output {name} shape: {shape}") | |
| context.execute_async_v3(stream_handle=stream) | |
| torch.cuda.current_stream().synchronize() | |
| print("β TRT execution OK β first output shape:", outputs[0].shape) | |