import os

import torch
from diffusers import StableDiffusionInpaintPipeline

def convert_to_onnx(model_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Load the pipeline from a single .safetensors checkpoint
    pipe = StableDiffusionInpaintPipeline.from_single_file(model_path)

    # Move to CPU and make sure every submodule is float32 for export
    pipe = pipe.to("cpu", torch.float32)

    # Put the exported submodules in evaluation mode
    pipe.unet.eval()
    pipe.vae.eval()
    pipe.text_encoder.eval()
    with torch.no_grad():
        # SD 1.x latents are 1/8 of the pixel resolution, so a 512x512
        # image corresponds to a 64x64 latent grid
        latent_height = 64
        latent_width = 64

        # Sample inputs for the UNet. An inpainting UNet takes the noisy
        # latents concatenated with the mask and the masked-image latents,
        # in that order, for 4 + 1 + 4 = 9 input channels.
        latents = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        mask = torch.ones(1, 1, latent_height, latent_width, dtype=torch.float32)
        masked_image_latents = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        masked_latents = torch.cat([latents, mask, masked_image_latents], dim=1)

        # A single integer timestep
        timestep = torch.tensor([1], dtype=torch.int64)

        # Text embeddings: 77 is CLIP's fixed sequence length and 768 the
        # hidden size of the SD 1.x text encoder
        text_embeddings = torch.randn(1, 77, 768, dtype=torch.float32)
        # Export the UNet. The trailing dict passes return_dict=False so the
        # traced module returns a plain tuple that lines up with output_names.
        torch.onnx.export(
            pipe.unet,
            args=(masked_latents, timestep, text_embeddings, {"return_dict": False}),
            f=f"{output_dir}/unet.onnx",
            input_names=["sample", "timestep", "encoder_hidden_states"],
            output_names=["out_sample"],
            dynamic_axes={
                "sample": {0: "batch", 2: "height", 3: "width"},
                "encoder_hidden_states": {0: "batch", 1: "sequence"},
                "out_sample": {0: "batch", 2: "height", 3: "width"}
            },
            opset_version=17,
            export_params=True
        )
        # Export the VAE decoder. Routing forward() through vae.decode keeps
        # the post_quant_conv layer in the exported graph, matching what the
        # pipeline actually runs when it decodes latents.
        vae_latent = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        pipe.vae.forward = pipe.vae.decode
        torch.onnx.export(
            pipe.vae,
            args=(vae_latent, {"return_dict": False}),
            f=f"{output_dir}/vae_decoder.onnx",
            input_names=["latent"],
            output_names=["image"],
            dynamic_axes={
                "latent": {0: "batch", 2: "height", 3: "width"},
                "image": {0: "batch", 2: "height", 3: "width"}
            },
            opset_version=17,
            export_params=True
        )
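        # Note: at inference time the latents still need the SD 1.x VAE
        # scaling factor applied before decoding. A sketch, assuming a
        # hypothetical onnxruntime session named vae_session built from
        # vae_decoder.onnx:
        #   image = vae_session.run(None, {"latent": latents / 0.18215})[0]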
        # Export the text encoder. Storing the first attention layer's scale
        # (1/sqrt(64) = 0.125) as a tensor is a workaround so the tracer
        # records it rather than folding in a Python float.
        pipe.text_encoder.text_model.encoder.layers[0].self_attn.scale = torch.tensor(0.125, dtype=torch.float32)
        input_ids = torch.ones(1, 77, dtype=torch.int64)
        torch.onnx.export(
            pipe.text_encoder,
            args=(input_ids, {"return_dict": False}),
            f=f"{output_dir}/text_encoder.onnx",
            input_names=["input_ids"],
            output_names=["last_hidden_state", "pooler_output"],
            dynamic_axes={
                "input_ids": {0: "batch"},
                "last_hidden_state": {0: "batch"},
                "pooler_output": {0: "batch"}
            },
            opset_version=17,
            export_params=True
        )
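        # Real prompts would be tokenized with the pipeline's own CLIP
        # tokenizer, padded to the fixed length of 77. A usage sketch:
        #   ids = pipe.tokenizer("a prompt", padding="max_length",
        #                        max_length=77, truncation=True,
        #                        return_tensors="pt").input_ids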
print("Conversion completed successfully!")
return True
def verify_paths(model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found at: {model_path}")
print(f"Model file found at: {model_path}")
return True
if __name__ == "__main__":
    # Set your paths here
    model_path = "realisticVisionV60B1_v51VAE-inpainting.safetensors"
    output_dir = "onnx_output"

    try:
        verify_paths(model_path)
        success = convert_to_onnx(model_path, output_dir)
        if success:
            print(f"ONNX models saved to: {output_dir}")
    except Exception as e:
        print(f"Error during conversion: {str(e)}")
        raise  # Re-raise the exception to see the full traceback
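
# Optional sanity check for the exported UNet, a minimal sketch assuming
# onnxruntime is installed (pip install onnxruntime):
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("onnx_output/unet.onnx")
#   noise = sess.run(None, {
#       "sample": np.random.randn(1, 9, 64, 64).astype(np.float32),
#       "timestep": np.array([1], dtype=np.int64),
#       "encoder_hidden_states": np.random.randn(1, 77, 768).astype(np.float32),
#   })[0]
#   print(noise.shape)  # expect (1, 4, 64, 64): predicted noise in latent space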