add files

- compile.ipynb +320 -0
- inference.ipynb +171 -0
- text_encoder.pt +3 -0
- unet.pt +3 -0
- vae_decoder.pt +3 -0
- vae_post_quant_conv.pt +3 -0
compile.ipynb
ADDED
@@ -0,0 +1,320 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9c66a150-b2f7-4c34-b93a-ca70a0855169",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2023-Aug-18 23:10:12.0532 67649:67649 ERROR TDRV:tdrv_get_dev_info No neuron device available\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "os.environ[\"NEURON_FUSE_SOFTMAX\"] = \"1\"\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch_neuronx\n",
+    "import numpy as np\n",
+    "\n",
+    "from matplotlib import pyplot as plt\n",
+    "from matplotlib import image as mpimg\n",
+    "import time\n",
+    "import copy\n",
+    "from IPython.display import clear_output\n",
+    "\n",
+    "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
+    "from diffusers.models.unet_2d_condition import UNet2DConditionOutput\n",
+    "from diffusers.models.cross_attention import CrossAttention\n",
+    "\n",
+    "# Define datatype\n",
+    "DTYPE = torch.float32\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "54c2839b-44b5-4d27-8e83-7cc3d69a53df",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class UNetWrap(nn.Module):\n",
+    "    def __init__(self, unet):\n",
+    "        super().__init__()\n",
+    "        self.unet = unet\n",
+    "\n",
+    "    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):\n",
+    "        out_tuple = self.unet(sample, timestep, encoder_hidden_states, return_dict=False)\n",
+    "        return out_tuple\n",
+    "\n",
+    "class NeuronUNet(nn.Module):\n",
+    "    def __init__(self, unetwrap):\n",
+    "        super().__init__()\n",
+    "        self.unetwrap = unetwrap\n",
+    "        self.config = unetwrap.unet.config\n",
+    "        self.in_channels = unetwrap.unet.in_channels\n",
+    "        self.device = unetwrap.unet.device\n",
+    "\n",
+    "    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):\n",
+    "        sample = self.unetwrap(sample, timestep.to(dtype=DTYPE).expand((sample.shape[0],)), encoder_hidden_states)[0]\n",
+    "        return UNet2DConditionOutput(sample=sample)\n",
+    "\n",
+    "class NeuronTextEncoder(nn.Module):\n",
+    "    def __init__(self, text_encoder):\n",
+    "        super().__init__()\n",
+    "        self.neuron_text_encoder = text_encoder\n",
+    "        self.config = text_encoder.config\n",
+    "        self.dtype = text_encoder.dtype\n",
+    "        self.device = text_encoder.device\n",
+    "\n",
+    "    def forward(self, emb, attention_mask=None):\n",
+    "        return [self.neuron_text_encoder(emb)['last_hidden_state']]\n",
+    "\n",
+    "# Optimized attention\n",
+    "def get_attention_scores(self, query, key, attn_mask):\n",
+    "    dtype = query.dtype\n",
+    "\n",
+    "    if self.upcast_attention:\n",
+    "        query = query.float()\n",
+    "        key = key.float()\n",
+    "\n",
+    "    # Check for square matmuls\n",
+    "    if query.size() == key.size():\n",
+    "        attention_scores = custom_badbmm(\n",
+    "            key,\n",
+    "            query.transpose(-1, -2)\n",
+    "        )\n",
+    "\n",
+    "        if self.upcast_softmax:\n",
+    "            attention_scores = attention_scores.float()\n",
+    "\n",
+    "        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)\n",
+    "        attention_probs = attention_probs.to(dtype)\n",
+    "\n",
+    "    else:\n",
+    "        attention_scores = custom_badbmm(\n",
+    "            query,\n",
+    "            key.transpose(-1, -2)\n",
+    "        )\n",
+    "\n",
+    "        if self.upcast_softmax:\n",
+    "            attention_scores = attention_scores.float()\n",
+    "\n",
+    "        attention_probs = attention_scores.softmax(dim=-1)\n",
+    "        attention_probs = attention_probs.to(dtype)\n",
+    "\n",
+    "    return attention_probs\n",
+    "\n",
+    "# In the original baddbmm the bias is all zeros, so only apply the scale\n",
+    "def custom_badbmm(a, b):\n",
+    "    bmm = torch.bmm(a, b)\n",
+    "    scaled = bmm * 0.125  # 1 / sqrt(attention head dim of 64)\n",
+    "    return scaled"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "e1eb8d1b-7b4e-4d55-996e-482e8f18d5e0",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "89d0ef19f2d84ac8bf742de97c95617b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# For saving compiler artifacts\n",
+    "COMPILER_WORKDIR_ROOT = 'sd2_compile_dir_768'\n",
+    "\n",
+    "# Model ID for SD version pipeline\n",
+    "model_id = \"stabilityai/stable-diffusion-2-1\"\n",
+    "\n",
+    "# --- Compile UNet and save ---\n",
+    "\n",
+    "pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE)\n",
+    "\n",
+    "# Replace original cross-attention module with custom cross-attention module for better performance\n",
+    "CrossAttention.get_attention_scores = get_attention_scores\n",
+    "\n",
+    "# Apply double wrapper to deal with custom return type\n",
+    "pipe.unet = NeuronUNet(UNetWrap(pipe.unet))\n",
+    "\n",
+    "# Only keep the model being compiled in RAM to minimize memory pressure\n",
+    "unet = copy.deepcopy(pipe.unet.unetwrap)\n",
+    "\n",
+    "# Compile unet - FP32\n",
+    "# Latent spatial size is height/8 x width/8, i.e. 135x240 for 1920x1080 output\n",
+    "sample_1b = torch.randn([1, 4, 135, 240], dtype=DTYPE)\n",
+    "timestep_1b = torch.tensor(999, dtype=DTYPE).expand((1,))\n",
+    "encoder_hidden_states_1b = torch.randn([1, 77, 1024], dtype=DTYPE)\n",
+    "example_inputs = sample_1b, timestep_1b, encoder_hidden_states_1b\n",
+    "\n",
+    "unet_neuron = torch_neuronx.trace(\n",
+    "    unet,\n",
+    "    example_inputs,\n",
+    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),\n",
+    "    compiler_args=[\"--model-type=unet-inference\", \"--enable-fast-loading-neuron-binaries\"]\n",
+    ")\n",
+    "\n",
+    "# Enable asynchronous and lazy loading to speed up model load\n",
+    "torch_neuronx.async_load(unet_neuron)\n",
+    "torch_neuronx.lazy_load(unet_neuron)\n",
+    "\n",
+    "# Save the compiled unet\n",
+    "unet_filename = 'unet.pt'\n",
+    "torch.jit.save(unet_neuron, unet_filename)\n",
+    "\n",
+    "# Delete unused objects\n",
+    "del unet\n",
+    "del unet_neuron\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e1301369-2008-496f-a52f-65309ab138ac",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# --- Compile text encoder and save ---\n",
+    "\n",
+    "text_encoder = copy.deepcopy(pipe.text_encoder)\n",
+    "\n",
+    "# Apply the wrapper to deal with custom return type\n",
+    "text_encoder = NeuronTextEncoder(text_encoder)\n",
+    "\n",
+    "# Compile text encoder\n",
+    "# These token ids index a lookup table in torch.nn.Embedding,\n",
+    "# so using random numbers may give errors (out of range).\n",
+    "emb = torch.tensor([[49406, 18376, 525, 7496, 49407, 0, 0, 0, 0, 0,\n",
+    "                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+    "                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+    "                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+    "                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+    "                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+    "                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+    "                     0, 0, 0, 0, 0, 0, 0]])\n",
+    "text_encoder_neuron = torch_neuronx.trace(\n",
+    "    text_encoder.neuron_text_encoder,\n",
+    "    emb,\n",
+    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder'),\n",
+    "    compiler_args=[\"--enable-fast-loading-neuron-binaries\"]\n",
+    ")\n",
+    "\n",
+    "# Enable asynchronous loading to speed up model load\n",
+    "torch_neuronx.async_load(text_encoder_neuron)\n",
+    "\n",
+    "# Save the compiled text encoder\n",
+    "text_encoder_filename = 'text_encoder.pt'\n",
+    "torch.jit.save(text_encoder_neuron, text_encoder_filename)\n",
+    "\n",
+    "# Delete unused objects\n",
+    "del text_encoder\n",
+    "del text_encoder_neuron\n",
+    "\n",
+    "# --- Compile VAE decoder and save ---\n",
+    "\n",
+    "# Only keep the model being compiled in RAM to minimize memory pressure\n",
+    "decoder = copy.deepcopy(pipe.vae.decoder)\n",
+    "\n",
+    "# Compile vae decoder; the input must match the latent size the UNet was traced with\n",
+    "decoder_in = torch.randn([1, 4, 135, 240], dtype=DTYPE)\n",
+    "decoder_neuron = torch_neuronx.trace(\n",
+    "    decoder,\n",
+    "    decoder_in,\n",
+    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder'),\n",
+    "    compiler_args=[\"--enable-fast-loading-neuron-binaries\"]\n",
+    ")\n",
+    "\n",
+    "# Enable asynchronous loading to speed up model load\n",
+    "torch_neuronx.async_load(decoder_neuron)\n",
+    "\n",
+    "# Save the compiled vae decoder\n",
+    "decoder_filename = 'vae_decoder.pt'\n",
+    "torch.jit.save(decoder_neuron, decoder_filename)\n",
+    "\n",
+    "# Delete unused objects\n",
+    "del decoder\n",
+    "del decoder_neuron\n",
+    "\n",
+    "# --- Compile VAE post_quant_conv and save ---\n",
+    "\n",
+    "post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)\n",
+    "\n",
+    "# Compile vae post_quant_conv with the same latent size as the decoder input\n",
+    "post_quant_conv_in = torch.randn([1, 4, 135, 240], dtype=DTYPE)\n",
+    "post_quant_conv_neuron = torch_neuronx.trace(\n",
+    "    post_quant_conv,\n",
+    "    post_quant_conv_in,\n",
+    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv'),\n",
+    ")\n",
+    "\n",
+    "# Enable asynchronous loading to speed up model load\n",
+    "torch_neuronx.async_load(post_quant_conv_neuron)\n",
+    "\n",
+    "# Save the compiled vae post_quant_conv\n",
+    "post_quant_conv_filename = 'vae_post_quant_conv.pt'\n",
+    "torch.jit.save(post_quant_conv_neuron, post_quant_conv_filename)\n",
+    "\n",
+    "# Delete unused objects\n",
+    "del post_quant_conv\n",
+    "del post_quant_conv_neuron"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python (torch-neuronx)",
+   "language": "python",
+   "name": "aws_neuron_venv_pytorch"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
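A note on the constants in compile.ipynb: the 0.125 scale in custom_badbmm and the [1, 4, 135, 240] latent shape are not arbitrary. A minimal sketch of where they come from, assuming SD 2.1's attention head dimension of 64 and the VAE's 8x spatial downsampling:

import math

head_dim = 64                            # SD 2.1 UNet per-head width (assumed)
assert 1 / math.sqrt(head_dim) == 0.125  # the scale applied in custom_badbmm

width, height = 1920, 1080               # target output resolution
vae_scale = 8                            # the VAE downsamples each spatial dim by 8
assert (height // vae_scale, width // vae_scale) == (135, 240)  # sample_1b's trailing dims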
inference.ipynb
ADDED
@@ -0,0 +1,171 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "07b2bef9-bbaf-41b8-b960-7ac373ff3e8d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install diffusers==0.14.0 transformers==4.26.1 accelerate==0.16.0 safetensors==0.3.1 matplotlib"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6ebecb44-f796-4c76-8385-888a2f46fd6a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "os.environ[\"NEURON_FUSE_SOFTMAX\"] = \"1\"\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch_neuronx\n",
+    "import numpy as np\n",
+    "\n",
+    "from matplotlib import pyplot as plt\n",
+    "from matplotlib import image as mpimg\n",
+    "import time\n",
+    "import copy\n",
+    "\n",
+    "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
+    "from diffusers.models.unet_2d_condition import UNet2DConditionOutput\n",
+    "from diffusers.models.cross_attention import CrossAttention\n",
+    "\n",
+    "# Define datatype\n",
+    "DTYPE = torch.float32"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9950025f-877a-4c11-b30e-9c32f0825e94",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class UNetWrap(nn.Module):\n",
+    "    def __init__(self, unet):\n",
+    "        super().__init__()\n",
+    "        self.unet = unet\n",
+    "\n",
+    "    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):\n",
+    "        out_tuple = self.unet(sample, timestep, encoder_hidden_states, return_dict=False)\n",
+    "        return out_tuple\n",
+    "\n",
+    "class NeuronUNet(nn.Module):\n",
+    "    def __init__(self, unetwrap):\n",
+    "        super().__init__()\n",
+    "        self.unetwrap = unetwrap\n",
+    "        self.config = unetwrap.unet.config\n",
+    "        self.in_channels = unetwrap.unet.in_channels\n",
+    "        self.device = unetwrap.unet.device\n",
+    "\n",
+    "    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):\n",
+    "        sample = self.unetwrap(sample, timestep.to(dtype=DTYPE).expand((sample.shape[0],)), encoder_hidden_states)[0]\n",
+    "        return UNet2DConditionOutput(sample=sample)\n",
+    "\n",
+    "class NeuronTextEncoder(nn.Module):\n",
+    "    def __init__(self, text_encoder):\n",
+    "        super().__init__()\n",
+    "        self.neuron_text_encoder = text_encoder\n",
+    "        self.config = text_encoder.config\n",
+    "        self.dtype = text_encoder.dtype\n",
+    "        self.device = text_encoder.device\n",
+    "\n",
+    "    def forward(self, emb, attention_mask=None):\n",
+    "        return [self.neuron_text_encoder(emb)['last_hidden_state']]\n",
+    "\n",
+    "# Optimized attention\n",
+    "def get_attention_scores(self, query, key, attn_mask):\n",
+    "    dtype = query.dtype\n",
+    "\n",
+    "    if self.upcast_attention:\n",
+    "        query = query.float()\n",
+    "        key = key.float()\n",
+    "\n",
+    "    # Check for square matmuls\n",
+    "    if query.size() == key.size():\n",
+    "        attention_scores = custom_badbmm(\n",
+    "            key,\n",
+    "            query.transpose(-1, -2)\n",
+    "        )\n",
+    "\n",
+    "        if self.upcast_softmax:\n",
+    "            attention_scores = attention_scores.float()\n",
+    "\n",
+    "        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)\n",
+    "        attention_probs = attention_probs.to(dtype)\n",
+    "\n",
+    "    else:\n",
+    "        attention_scores = custom_badbmm(\n",
+    "            query,\n",
+    "            key.transpose(-1, -2)\n",
+    "        )\n",
+    "\n",
+    "        if self.upcast_softmax:\n",
+    "            attention_scores = attention_scores.float()\n",
+    "\n",
+    "        attention_probs = attention_scores.softmax(dim=-1)\n",
+    "        attention_probs = attention_probs.to(dtype)\n",
+    "\n",
+    "    return attention_probs\n",
+    "\n",
+    "# In the original baddbmm the bias is all zeros, so only apply the scale\n",
+    "def custom_badbmm(a, b):\n",
+    "    bmm = torch.bmm(a, b)\n",
+    "    scaled = bmm * 0.125  # 1 / sqrt(attention head dim of 64)\n",
+    "    return scaled"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ffc64d14-f48c-488c-b60a-36e3ebfdab83",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_id = \"stabilityai/stable-diffusion-2-1\"\n",
+    "text_encoder_filename = 'text_encoder.pt'\n",
+    "decoder_filename = 'vae_decoder.pt'\n",
+    "unet_filename = 'unet.pt'\n",
+    "post_quant_conv_filename = 'vae_post_quant_conv.pt'\n",
+    "\n",
+    "pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE)\n",
+    "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
+    "\n",
+    "# Load the compiled UNet onto two neuron cores.\n",
+    "pipe.unet = NeuronUNet(UNetWrap(pipe.unet))\n",
+    "device_ids = [0, 1]\n",
+    "pipe.unet.unetwrap = torch_neuronx.DataParallel(torch.jit.load(unet_filename), device_ids, set_dynamic_batching=False)\n",
+    "\n",
+    "# Load other compiled models onto a single neuron core.\n",
+    "pipe.text_encoder = NeuronTextEncoder(pipe.text_encoder)\n",
+    "pipe.text_encoder.neuron_text_encoder = torch.jit.load(text_encoder_filename)\n",
+    "pipe.vae.decoder = torch.jit.load(decoder_filename)\n",
+    "pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python (torch-neuronx)",
+   "language": "python",
+   "name": "aws_neuron_venv_pytorch"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
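inference.ipynb stops after wiring the compiled modules into the pipeline. A hypothetical follow-on cell to actually generate an image could look like the sketch below; the prompt, step count, and guidance scale are illustrative (not from the commit), and height/width must match the 1920x1080 shapes the models were traced with:

# Hypothetical generation cell (not part of this commit); values are illustrative.
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt, height=1080, width=1920,
             num_inference_steps=25, guidance_scale=7.5).images[0]
image.save("generated.png")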
text_encoder.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96834421e918e0c762d7157bb5629a8c220dcd450e645a14fc81578e73b9e62b
+size 862847068
unet.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:284ed1961f753da632a5d8369bbb40edadc01c513bc98c9a77653e27d4cdbb0e
+size 2073112440
vae_decoder.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74e0a2daa1572c36ba56c846f1210451e46ea383a0dc5598881c35696fb0aef7
+size 532453276
vae_post_quant_conv.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:459cd0aaabee7b2ec9610b55d8d6ba71db8300c16cd9beeeddcf822a6d2b66e8
+size 35228
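The four .pt entries above are Git LFS pointer files: per the LFS v1 spec, oid is the SHA-256 of the actual artifact and size is its byte length. A sketch of verifying a downloaded artifact against its pointer (local paths assumed):

import hashlib

def sha256_of(path, chunk=1 << 20):
    # Stream the file so multi-GB artifacts like unet.pt need not fit in RAM.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk):
            h.update(block)
    return h.hexdigest()

# oid taken from the vae_post_quant_conv.pt pointer above
expected = "459cd0aaabee7b2ec9610b55d8d6ba71db8300c16cd9beeeddcf822a6d2b66e8"
assert sha256_of("vae_post_quant_conv.pt") == expected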