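"""Convert a Stable Diffusion inpainting checkpoint (a single .safetensors
file) to ONNX: the UNet, the VAE decoder, and the text encoder are each
exported as separate graphs."""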
import torch
from diffusers import StableDiffusionInpaintPipeline
import os


def convert_to_onnx(model_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Load the pipeline from the single-file checkpoint
    pipe = StableDiffusionInpaintPipeline.from_single_file(model_path)

    # Move to CPU and ensure float32
    pipe = pipe.to("cpu")
    pipe = pipe.to(torch.float32)

    # Set to evaluation mode
    pipe.unet.eval()
    pipe.vae.eval()
    pipe.text_encoder.eval()

    with torch.no_grad():
        # Latent spatial size: the VAE downsamples by a factor of 8, so
        # 64x64 latents correspond to the 512x512 SD 1.5 resolution.
        latent_height = 64
        latent_width = 64

        # Create sample inputs for the UNet. An inpainting UNet takes the
        # noisy latents concatenated with the mask and the masked-image
        # latents along the channel axis: 4 + 1 + 4 = 9 channels.
        latents = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        mask = torch.ones(1, 1, latent_height, latent_width, dtype=torch.float32)
        masked_image_latents = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        unet_sample = torch.cat([latents, mask, masked_image_latents], dim=1)

        # Timestep
        timestep = torch.tensor([1], dtype=torch.int64)

        # Text embeddings (77 is the CLIP sequence length, 768 the hidden
        # size of the SD 1.5 text encoder)
        text_embeddings = torch.randn(1, 77, 768, dtype=torch.float32)

        # Export UNet. The trailing dict is passed to forward() as keyword
        # arguments; return_dict=False makes the UNet return a plain tuple,
        # which the ONNX tracer handles cleanly.
        torch.onnx.export(
            pipe.unet,
            args=(unet_sample, timestep, text_embeddings, {"return_dict": False}),
            f=f"{output_dir}/unet.onnx",
            input_names=["sample", "timestep", "encoder_hidden_states"],
            output_names=["out_sample"],
            dynamic_axes={
                "sample": {0: "batch", 2: "height", 3: "width"},
                "encoder_hidden_states": {0: "batch", 1: "sequence"},
                "out_sample": {0: "batch", 2: "height", 3: "width"}
            },
            opset_version=17,
            export_params=True
        )

        # Export VAE decoder. Route forward() through decode() so that
        # post_quant_conv is part of the exported graph rather than just
        # the raw Decoder module.
        vae = pipe.vae
        vae.forward = lambda latent: vae.decode(latent, return_dict=False)
        vae_latent = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        torch.onnx.export(
            vae,
            args=(vae_latent,),
            f=f"{output_dir}/vae_decoder.onnx",
            input_names=["latent"],
            output_names=["image"],
            dynamic_axes={
                "latent": {0: "batch", 2: "height", 3: "width"},
                "image": {0: "batch", 2: "height", 3: "width"}
            },
            opset_version=17,
            export_params=True
        )

        # Export text encoder. Pin the CLIP attention scale as a float32
        # tensor in every layer (not only the first) so it is baked into
        # the graph consistently; 0.125 = 1/sqrt(head_dim=64) for the
        # SD 1.5 text encoder.
        for layer in pipe.text_encoder.text_model.encoder.layers:
            layer.self_attn.scale = torch.tensor(0.125, dtype=torch.float32)

        input_ids = torch.ones(1, 77, dtype=torch.int64)
        torch.onnx.export(
            pipe.text_encoder,
            args=(input_ids, {"return_dict": False}),
            f=f"{output_dir}/text_encoder.onnx",
            input_names=["input_ids"],
            output_names=["last_hidden_state", "pooler_output"],
            dynamic_axes={
                "input_ids": {0: "batch"},
                "last_hidden_state": {0: "batch"},
                "pooler_output": {0: "batch"}
            },
            opset_version=17,
            export_params=True
        )
| print("Conversion completed successfully!") | |
| return True | |
| def verify_paths(model_path): | |
| if not os.path.exists(model_path): | |
| raise FileNotFoundError(f"Model file not found at: {model_path}") | |
| print(f"Model file found at: {model_path}") | |
| return True | |
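
# Optional sanity check, sketched as a helper that is not called by default.
# The helper name and the use of onnxruntime are illustrative additions, not
# part of the conversion itself: it runs the exported UNet once with random
# inputs, using the input/output names declared in convert_to_onnx, and
# confirms the predicted-noise shape. Example, after a successful run:
#   sanity_check_unet("onnx_output/unet.onnx")
def sanity_check_unet(onnx_path="onnx_output/unet.onnx"):
    import numpy as np
    import onnxruntime as ort  # requires: pip install onnxruntime

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (noise_pred,) = session.run(
        ["out_sample"],
        {
            "sample": np.random.randn(1, 9, 64, 64).astype(np.float32),
            "timestep": np.array([1], dtype=np.int64),
            "encoder_hidden_states": np.random.randn(1, 77, 768).astype(np.float32),
        },
    )
    print(f"UNet output shape: {noise_pred.shape}")  # expect (1, 4, 64, 64)
    return noise_pred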

if __name__ == "__main__":
    # Set your paths here
    model_path = "realisticVisionV60B1_v51VAE-inpainting.safetensors"
    output_dir = "onnx_output"

    try:
        verify_paths(model_path)
        success = convert_to_onnx(model_path, output_dir)
        if success:
            print(f"ONNX models saved to: {output_dir}")
    except Exception as e:
        print(f"Error during conversion: {str(e)}")
        raise  # Re-raise the exception to see the full traceback