import torch
from diffusers import StableDiffusionInpaintPipeline
import os

def convert_to_onnx(model_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    
    # Load the pipeline from the single-file .safetensors checkpoint
    pipe = StableDiffusionInpaintPipeline.from_single_file(model_path)

    # Move to CPU and cast to float32 for export
    pipe = pipe.to("cpu", torch.float32)
    
    # Set to evaluation mode
    pipe.unet.eval()
    pipe.vae.eval()
    pipe.text_encoder.eval()
    
    # Build dummy inputs at the standard SD latent resolution (512 px / 8 = 64)
    with torch.no_grad():
        latent_height = 64  # 512-pixel image height / VAE downscale factor of 8
        latent_width = 64   # 512-pixel image width / VAE downscale factor of 8
        
        # Create sample inputs for the UNet.
        # The inpainting UNet expects 9 input channels: 4 noisy-latent channels,
        # 1 mask channel, and 4 masked-image latent channels, concatenated in
        # the order the pipeline uses.
        latents = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        mask = torch.ones(1, 1, latent_height, latent_width, dtype=torch.float32)
        masked_image_latents = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        masked_latents = torch.cat([latents, mask, masked_image_latents], dim=1)  # 4 + 1 + 4 = 9 channels

        # Timestep input (the UNet computes the time embedding internally)
        timestep = torch.tensor([1], dtype=torch.int64)
        
        # Text encoder hidden states (77 tokens x 768 dims for SD 1.5's CLIP ViT-L)
        text_embeddings = torch.randn(1, 77, 768, dtype=torch.float32)
        
        # Export UNet
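        # Note: the fp32 SD 1.5 UNet is roughly 3.4 GB, beyond ONNX's 2 GB
        # protobuf limit, so recent torch.onnx.export versions automatically
        # spill the weights into external data files next to unet.onnx; keep
        # those files alongside the graph when loading it later.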
        torch.onnx.export(
            pipe.unet,
            args=(masked_latents, timestep, text_embeddings),
            f=f"{output_dir}/unet.onnx",
            input_names=["sample", "timestep", "encoder_hidden_states"],
            output_names=["out_sample"],
            dynamic_axes={
                "sample": {0: "batch", 2: "height", 3: "width"},
                "encoder_hidden_states": {0: "batch", 1: "sequence"},
                "out_sample": {0: "batch", 2: "height", 3: "width"}
            },
            opset_version=17,
            export_params=True
        )
        
        # Export VAE Decoder
        # AutoencoderKL.decode runs latents through post_quant_conv before the
        # decoder network, so export the full decode path rather than
        # pipe.vae.decoder alone.
        class VaeDecodeWrapper(torch.nn.Module):
            def __init__(self, vae):
                super().__init__()
                self.vae = vae

            def forward(self, latent):
                return self.vae.decode(latent, return_dict=False)[0]

        vae_latent = torch.randn(1, 4, latent_height, latent_width, dtype=torch.float32)
        torch.onnx.export(
            VaeDecodeWrapper(pipe.vae),
            args=(vae_latent,),
            f=f"{output_dir}/vae_decoder.onnx",
            input_names=["latent"],
            output_names=["image"],
            dynamic_axes={
                "latent": {0: "batch", 2: "height", 3: "width"},
                "image": {0: "batch", 2: "height", 3: "width"}
            },
            opset_version=17,
            export_params=True
        )
        
        # Export Text Encoder
        # Workaround: cast each attention layer's scalar `scale` to a tensor
        # so the tracer records it as a graph constant (applied to every
        # layer, not only the first).
        for layer in pipe.text_encoder.text_model.encoder.layers:
            layer.self_attn.scale = torch.tensor(layer.self_attn.scale, dtype=torch.float32)

        input_ids = torch.ones(1, 77, dtype=torch.int64)
        torch.onnx.export(
            pipe.text_encoder,
            args=(input_ids,),
            f=f"{output_dir}/text_encoder.onnx",
            input_names=["input_ids"],
            output_names=["last_hidden_state", "pooler_output"],
            dynamic_axes={
                "input_ids": {0: "batch"},
                "last_hidden_state": {0: "batch"},
                "pooler_output": {0: "batch"}
            },
            opset_version=17,
            export_params=True
        )
    
    print("Conversion completed successfully!")
    return True

def verify_paths(model_path):
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at: {model_path}")
    
    print(f"Model file found at: {model_path}")
    return True
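
# Optional sanity check: load the exported UNet with onnxruntime and run it on
# the same dummy shapes used during export. This is a minimal sketch and
# assumes onnxruntime is installed (pip install onnxruntime); it is not needed
# for the conversion itself.
def sanity_check_unet(output_dir):
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(
        f"{output_dir}/unet.onnx", providers=["CPUExecutionProvider"]
    )
    outputs = sess.run(
        None,
        {
            "sample": np.random.randn(1, 9, 64, 64).astype(np.float32),
            "timestep": np.array([1], dtype=np.int64),
            "encoder_hidden_states": np.random.randn(1, 77, 768).astype(np.float32),
        },
    )
    # The UNet predicts latent-space noise: (batch, 4, height, width)
    print(f"UNet output shape: {outputs[0].shape}")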

if __name__ == "__main__":
    # Set your paths here
    model_path = "realisticVisionV60B1_v51VAE-inpainting.safetensors"
    output_dir = "onnx_output"
    
    try:
        verify_paths(model_path)
        success = convert_to_onnx(model_path, output_dir)
        if success:
            print(f"ONNX models saved to: {output_dir}")
    except Exception as e:
        print(f"Error during conversion: {str(e)}")
        raise  # Re-raise the exception to see full traceback
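
# To spot-check the export afterwards (requires onnxruntime), call:
#   sanity_check_unet("onnx_output")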