import os

import torch
from diffusers import StableDiffusionInpaintPipeline
def convert_to_onnx(model_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Load the pipeline from the single-file checkpoint
    pipe = StableDiffusionInpaintPipeline.from_single_file(model_path)

    # Move to CPU and ensure float32 weights for the export
    pipe = pipe.to("cpu", torch.float32)

    # Set the exported modules to evaluation mode
    pipe.unet.eval()
    pipe.vae.eval()
    pipe.text_encoder.eval()
    # Build dummy inputs with the correct latent dimensions
    # (512x512 pixel images map to 64x64 latents in SD 1.5: image size / 8)
    with torch.no_grad():
        latent_height = 64
        latent_width = 64

        # The inpainting UNet expects latents, masked-image latents, and the
        # mask concatenated along the channel axis
        latents = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        mask = torch.ones(1, 1, latent_height, latent_width, dtype=torch.float32)
        masked_image_latents = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        masked_latents = torch.cat([latents, masked_image_latents, mask], dim=1)  # 4 + 4 + 1 = 9 channels

        # Timestep
        timestep = torch.tensor([1], dtype=torch.int64)

        # Text embeddings (77 tokens, hidden size 768 for SD 1.5's CLIP encoder)
        text_embeddings = torch.randn(1, 77, 768, dtype=torch.float32)
        # Workaround: pin the first CLIP attention layer's scale (a Python
        # float, 1/sqrt(64) = 0.125 for head_dim 64) as a tensor so it traces
        # during export; this touches the text encoder, not the UNet
        pipe.text_encoder.text_model.encoder.layers[0].self_attn.scale = torch.tensor(0.125, dtype=torch.float32)

        # Export UNet
        torch.onnx.export(
            pipe.unet,
            args=(masked_latents, timestep, text_embeddings),
            f=f"{output_dir}/unet.onnx",
            input_names=["sample", "timestep", "encoder_hidden_states"],
            output_names=["out_sample"],
            dynamic_axes={
                "sample": {0: "batch", 2: "height", 3: "width"},
                "encoder_hidden_states": {0: "batch", 1: "sequence"},
                "out_sample": {0: "batch", 2: "height", 3: "width"},
            },
            opset_version=17,
            export_params=True,
        )
        # Export VAE decoder. Route forward through vae.decode so the exported
        # graph includes post_quant_conv as well as the decoder itself
        vae_latent = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        pipe.vae.forward = pipe.vae.decode
        torch.onnx.export(
            pipe.vae,
            args=(vae_latent, False),  # return_dict=False -> plain tuple output
            f=f"{output_dir}/vae_decoder.onnx",
            input_names=["latent"],
            output_names=["image"],
            dynamic_axes={
                "latent": {0: "batch", 2: "height", 3: "width"},
                "image": {0: "batch", 2: "height", 3: "width"},
            },
            opset_version=17,
            export_params=True,
        )
        # Export text encoder
        input_ids = torch.ones(1, 77, dtype=torch.int64)
        torch.onnx.export(
            pipe.text_encoder,
            args=(input_ids,),
            f=f"{output_dir}/text_encoder.onnx",
            input_names=["input_ids"],
            output_names=["last_hidden_state", "pooler_output"],
            dynamic_axes={
                "input_ids": {0: "batch"},
                "last_hidden_state": {0: "batch"},
                "pooler_output": {0: "batch"},
            },
            opset_version=17,
            export_params=True,
        )
print("Conversion completed successfully!") | |
return True | |
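
# Optional sanity check: a minimal sketch assuming `onnxruntime` is installed
# (pip install onnxruntime); it is not used anywhere else in this script.
# Runs the exported UNet once on dummy inputs shaped like the export inputs.
def smoke_test_unet(output_dir):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(f"{output_dir}/unet.onnx")
    outputs = session.run(
        None,
        {
            "sample": np.random.randn(1, 9, 64, 64).astype(np.float32),
            "timestep": np.array([1], dtype=np.int64),
            "encoder_hidden_states": np.random.randn(1, 77, 768).astype(np.float32),
        },
    )
    # Expect (1, 4, 64, 64): the predicted noise in latent space
    print(f"UNet smoke test output shape: {outputs[0].shape}")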
def verify_paths(model_path):
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at: {model_path}")
    print(f"Model file found at: {model_path}")
    return True
if __name__ == "__main__":
    # Set your paths here
    model_path = "realisticVisionV60B1_v51VAE-inpainting.safetensors"
    output_dir = "onnx_output"

    try:
        verify_paths(model_path)
        success = convert_to_onnx(model_path, output_dir)
        if success:
            print(f"ONNX models saved to: {output_dir}")
    except Exception as e:
        print(f"Error during conversion: {e}")
        raise  # Re-raise to show the full traceback
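
# Usage sketch (the path above is an assumption; point model_path at your own
# checkpoint). A successful run leaves unet.onnx, vae_decoder.onnx, and
# text_encoder.onnx in output_dir; graphs over 2 GB may also emit external
# weight files next to the .onnx. To verify the UNet afterwards, call
# smoke_test_unet(output_dir).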