File size: 14,896 Bytes
9382e3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 |
<!--Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
β οΈ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# λλ²κΉ
[[debugging]]
## Multi-GPU λ€νΈμν¬ λ¬Έμ λλ²κ·Έ [[multigpu-network-issues-debug]]
`DistributedDataParallel` λ° λ€μ€ GPUλ₯Ό μ¬μ©νμ¬ νλ ¨νκ±°λ μΆλ‘ ν λ, νλ‘μΈμ€ λ°/λλ λ
Έλ κ°μ μνΈ ν΅μ λ¬Έμ κ° λ°μνλ κ²½μ°, λ€μ μ€ν¬λ¦½νΈλ₯Ό μ¬μ©νμ¬ λ€νΈμν¬ λ¬Έμ λ₯Ό μ§λ¨ν μ μμ΅λλ€.
```bash
wget https://raw.githubusercontent.com/huggingface/transformers/main/scripts/distributed/torch-distributed-gpu-test.py
```
μλ₯Ό λ€μ΄, 2κ°μ GPUκ° μνΈ μμ©νλ λ°©μμ ν
μ€νΈνλ €λ©΄ λ€μμ μ€ννμΈμ:
```bash
python -m torch.distributed.run --nproc_per_node 2 --nnodes 1 torch-distributed-gpu-test.py
```
λ νλ‘μΈμ€κ° μλ‘ ν΅μ νκ³ GPU λ©λͺ¨λ¦¬λ₯Ό ν λΉνλ κ²½μ°, κ°κ° "OK" μνλ₯Ό μΆλ ₯ν©λλ€.
λ λ§μ GPU λλ λ
Έλμ κ²½μ° μ€ν¬λ¦½νΈμ μΈμλ₯Ό μ‘°μ νλ©΄ λ©λλ€.
μ§λ¨ μ€ν¬λ¦½νΈ λ΄μμ λ λ§μ μΈλΆ μ 보μ SLURM νκ²½μμ μ€ννλ λ°©λ²μ λν λ μνΌλ₯Ό μ°Ύμ μ μμ΅λλ€.
μΆκ°μ μΈ λλ²κ·Έ μμ€μ λ€μκ³Ό κ°μ΄ `NCCL_DEBUG=INFO` νκ²½ λ³μλ₯Ό μΆκ°νλ κ²μ
λλ€:
```bash
NCCL_DEBUG=INFO python -m torch.distributed.run --nproc_per_node 2 --nnodes 1 torch-distributed-gpu-test.py
```
μ΄λ κ² νλ©΄ NCCL κ΄λ ¨ λλ²κ·Έ μ λ³΄κ° λ§μ΄ μΆλ ₯λλ©°, λ¬Έμ κ° λ³΄κ³ λ κ²½μ°μλ μΈν°λ·μμ κ²μν μ μμ΅λλ€. λλ μΆλ ₯μ ν΄μνλ λ°©λ²μ μ λͺ¨λ₯΄λ κ²½μ° λ‘κ·Έ νμΌμ μ΄μμ 곡μ ν μ μμ΅λλ€.
## μΈλνλ‘ λ° μ€λ²νλ‘ κ°μ§ [[underflow-and-overflow-detection]]
<Tip>
μ΄ κΈ°λ₯μ νμ¬ PyTorchμμλ§ μ¬μ©ν μ μμ΅λλ€.
</Tip>
<Tip>
λ€μ€ GPU νλ ¨μ μν΄μλ DDP (`torch.distributed.launch`)κ° νμν©λλ€.
</Tip>
<Tip>
μ΄ κΈ°λ₯μ `nn.Module`μ κΈ°λ°μΌλ‘ νλ λͺ¨λΈκ³Ό ν¨κ» μ¬μ©ν μ μμ΅λλ€.
</Tip>
`loss=NaN`μ΄ λνλκ±°λ λͺ¨λΈμ΄ `inf` λλ `nan`μΌλ‘ μΈν΄ λ€λ₯Έ μ΄μν λμμ νλ κ²½μ°, μΈλνλ‘ λλ μ€λ²νλ‘μ 첫 λ²μ§Έ λ°μ μμΉμ κ·Έ μμΈμ νμ
ν΄μΌ ν©λλ€. λ€ννλ μ΄λ₯Ό μλμΌλ‘ κ°μ§νλ νΉμ λͺ¨λμ νμ±ννμ¬ μ½κ² μμλΌ μ μμ΅λλ€.
[`Trainer`]λ₯Ό μ¬μ©νλ κ²½μ°, λ€μμ κΈ°μ‘΄μ λͺ
λ Ήμ€ μΈμμ μΆκ°νλ©΄ λ©λλ€.
```bash
--debug underflow_overflow
```
λλ [`TrainingArguments`] κ°μ²΄λ₯Ό μμ±ν λ `debug="underflow_overflow"`λ₯Ό μ λ¬ν©λλ€.
μ체 νλ ¨ 루νλ λ€λ₯Έ Trainerλ₯Ό μ¬μ©νλ κ²½μ°, λ€μκ³Ό κ°μ΄ μνν μ μμ΅λλ€.
```python
from transformers.debug_utils import DebugUnderflowOverflow
debug_overflow = DebugUnderflowOverflow(model)
```
[`~debug_utils.DebugUnderflowOverflow`]λ λͺ¨λΈμ νν¬λ₯Ό μ½μ
νμ¬ κ° forward νΈμΆ μ§νμ μ
λ ₯ λ° μΆλ ₯ λ³μ λ° ν΄λΉ λͺ¨λμ κ°μ€μΉλ₯Ό ν
μ€νΈν©λλ€. νμ±νλ κ°μ€μΉμ μ΅μν νλμ μμμμ `inf` λλ `nan`μ΄ κ°μ§λλ©΄ νλ‘κ·Έλ¨μ΄ μ΄μ€νΈλκ³ λ€μκ³Ό κ°μ λ³΄κ³ μκ° μΆλ ₯λ©λλ€. (μ΄ μμ λ fp16 νΌν© μ λ°λμμ `google/mt5-small`μμ μΊ‘μ²λ κ²μ
λλ€):
```
Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min abs max metadata
encoder.block.1.layer.1.DenseReluDense.dropout Dropout
0.00e+00 2.57e+02 input[0]
0.00e+00 2.85e+02 output
[...]
encoder.block.2.layer.0 T5LayerSelfAttention
6.78e-04 3.15e+03 input[0]
2.65e-04 3.42e+03 output[0]
None output[1]
2.25e-01 1.00e+04 output[2]
encoder.block.2.layer.1.layer_norm T5LayerNorm
8.69e-02 4.18e-01 weight
2.65e-04 3.42e+03 input[0]
1.79e-06 4.65e+00 output
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
encoder.block.2.layer.1.DenseReluDense.dropout Dropout
0.00e+00 8.76e+03 input[0]
0.00e+00 9.74e+03 output
encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00 inf output
```
μμ μΆλ ₯μ κ°λ΅μ±μ μν΄ μ€κ° λΆλΆμ΄ μλ € μμ΅λλ€.
λ λ²μ§Έ μ΄μ μ λμ μΌλ‘ κ°μ₯ ν° μμμ κ°μ΄λ©°, λ°λΌμ λ§μ§λ§ λͺ κ°μ νλ μμ μμΈν μ΄ν΄λ³΄λ©΄ μ
λ ₯κ³Ό μΆλ ₯μ΄ `1e4` λ²μμ μμμ μ μ μμ΅λλ€. λ°λΌμ μ΄ νλ ¨μ `fp16` νΌν© μ λ°λλ‘ μνλ λ κ°μ₯ λ§μ§λ§ λ¨κ³μμ μ€λ²νλ‘μ°κ° λ°μνμ΅λλ€ (`fp16`μμ `inf` μ΄μ μ κ°μ₯ ν° μ«μλ `64e3`μ
λλ€). `fp16` μλμμ μ€λ²νλ‘μ°λ₯Ό νΌνκΈ° μν΄μλ νμ±νλ `1e4`λ³΄λ€ ν¨μ¬ μμμΌ ν©λλ€. μλνλ©΄ `1e4 * 1e4 = 1e8`μ΄κΈ° λλ¬Έμ ν° νμ±νμμ νλ ¬ κ³±μ μμΉμ μΈ μ€λ²νλ‘μ° μ‘°κ±΄μΌλ‘ μ΄μ΄μ§ κ²μ
λλ€.
μΆμ μ 맨 μ²μμμ μ΄λ λ°°μΉ λ²νΈμμ λ¬Έμ κ° λ°μνλμ§ μ μ μμ΅λλ€ (μ¬κΈ°μ `Detected inf/nan during batch_number=0`μ λ¬Έμ κ° μ²« λ²μ§Έ λ°°μΉμμ λ°μνμμ μλ―Έν©λλ€).
κ° λ³΄κ³ λ νλ μμ ν΄λΉ νλ μμ΄ λ³΄κ³ νλ ν΄λΉ λͺ¨λμ λν μμ ν νλͺ©μ μ μΈνλ©°, μ΄ νλ μλ§ μ΄ν΄λ³΄λ©΄ λ€μκ³Ό κ°μ΅λλ€.
```
encoder.block.2.layer.1.layer_norm T5LayerNorm
8.69e-02 4.18e-01 weight
2.65e-04 3.42e+03 input[0]
1.79e-06 4.65e+00 output
```
μ¬κΈ°μ `encoder.block.2.layer.1.layer_norm`μ μΈμ½λμ λ λ²μ§Έ λΈλ‘μ 첫 λ²μ§Έ λ μ΄μ΄μ λν λ μ΄μ΄ μ κ·νλ₯Ό μλ―Ένλ©°, `forward`μ νΉμ νΈμΆμ `T5LayerNorm`μ
λλ€.
μ΄ λ³΄κ³ μμ λ§μ§λ§ λͺ κ° νλ μμ μ΄ν΄λ³΄κ² μ΅λλ€:
```
Detected inf/nan during batch_number=0
Last 21 forward frames:
abs min abs max metadata
[...]
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
2.17e-07 4.50e+00 weight
1.79e-06 4.65e+00 input[0]
2.68e-06 3.70e+01 output
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
8.08e-07 2.66e+01 weight
1.79e-06 4.65e+00 input[0]
1.27e-04 2.37e+02 output
encoder.block.2.layer.1.DenseReluDense.wo Linear
1.01e-06 6.44e+00 weight
0.00e+00 9.74e+03 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
1.79e-06 4.65e+00 input[0]
3.18e-04 6.27e+04 output
encoder.block.2.layer.1.dropout Dropout
3.18e-04 6.27e+04 input[0]
0.00e+00 inf output
```
λ§μ§λ§ νλ μμ `Dropout.forward` ν¨μμ λν λ³΄κ³ μ
λλ€. 첫 λ²μ§Έ νλͺ©μ μ μΌν μ
λ ₯μ λνλ΄κ³ λ λ²μ§Έ νλͺ©μ μ μΌν μΆλ ₯μ λνλ
λλ€. μ΄ ν¨μκ° `DenseReluDense` ν΄λμ€ λ΄λΆμ `dropout` μμ±μμ νΈμΆλ κ²μ λ³Ό μ μμ΅λλ€. μ΄λ 첫 λ²μ§Έ λ μ΄μ΄μ λ λ²μ§Έ λΈλ‘μμ 첫 λ²μ§Έ λ°°μΉ μ€μ λ°μνλ€λ κ²μ μ μ μμ΅λλ€. λ§μ§λ§μΌλ‘, μ λμ μΌλ‘ κ°μ₯ ν° μ
λ ₯ μμλ `6.27e+04`μ΄κ³ μΆλ ₯λ λ§μ°¬κ°μ§λ‘ `inf`μ
λλ€.
μ¬κΈ°μμλ `T5DenseGatedGeluDense.forward`κ° μΆλ ₯ νμ±νλ₯Ό μμ±νλλ°, μ λμ μΌλ‘ κ°μ₯ ν° κ°μ΄ μ½ 62.7KμΈ κ²μ λ³Ό μ μμ΅λλ€. μ΄ κ°μ fp16μ μ΅λ μ νμΈ 64Kμ λ§€μ° κ·Όμ ν©λλ€. λ€μ νλ μμμλ μΌλΆ μμλ₯Ό 0μΌλ‘ λ§λ ν κ°μ€μΉλ₯Ό μ¬μ κ·ννλ `Dropout`μ΄ μμ΅λλ€. μ΄λ‘ μΈν΄ μ λ μ΅λκ°μ΄ 64Kλ₯Ό μ΄κ³Όνκ³ μ€λ²νλ‘μ°(`inf`)κ° λ°μν©λλ€.
보μλ€μνΌ, fp16 μ«μμ κ²½μ° μ«μκ° λ§€μ° μ»€μ§ λ μ΄μ νλ μμ μ΄ν΄λ³΄μμΌ ν©λλ€.
λ³΄κ³ μλ₯Ό `models/t5/modeling_t5.py`μ μ½λμ μΌμΉμμΌ λ³΄κ² μ΅λλ€.
```python
class T5DenseGatedGeluDense(nn.Module):
def __init__(self, config):
super().__init__()
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.gelu_act = ACT2FN["gelu_new"]
def forward(self, hidden_states):
hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
```
μ΄μ `dropout` νΈμΆκ³Ό μ΄μ μ λͺ¨λ νΈμΆμ μ½κ² νμΈν μ μμ΅λλ€.
κ°μ§λ `forward` νν¬μμ λ°μνλ―λ‘, μ΄λ¬ν λ³΄κ³ μλ κ° `forward`κ° λ°νλ μ§νμ μ¦μ μΆλ ₯λ©λλ€.
μ 체 λ³΄κ³ μλ‘ λμκ°μ λ¬Έμ μ λν μ‘°μΉ λ° μμ μ νλ €λ©΄, μ«μκ° μ¦κ°νκΈ° μμν λͺ κ°μ νλ μ μλ‘ μ΄λν΄μ μ¬κΈ°μ `fp32` λͺ¨λλ‘ μ νν΄μΌ ν©λλ€. μ΄λ κ² ν΄μΌ μ«μκ° κ³±ν΄μ§κ±°λ ν©μ³μ§ λ μ€λ²νλ‘μ°λμ§ μμ κ°λ₯μ±μ΄ λμ΅λλ€. λ¬Όλ‘ λ€λ₯Έ ν΄κ²°μ±
λ μμ μ μμ΅λλ€. μλ₯Ό λ€μ΄, `amp`κ° νμ±νλ κ²½μ° μΌμμ μΌλ‘ λκ³ μλμ `forward`λ₯Ό λμ°λ―Έ λνΌλ‘ μ΄λν ν λ€μκ³Ό κ°μ΄ ν μ μμ΅λλ€:
```python
def _forward(self, hidden_states):
hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
import torch
def forward(self, hidden_states):
if torch.is_autocast_enabled():
with torch.cuda.amp.autocast(enabled=False):
return self._forward(hidden_states)
else:
return self._forward(hidden_states)
```
μλ κ°μ§κΈ°λ μ 체 νλ μμ μ
λ ₯κ³Ό μΆλ ₯μ λν΄μλ§ λ³΄κ³ νλ―λ‘, μ΄λλ₯Ό μ΄ν΄λ΄μΌ νλμ§ μλ©΄ νΉμ `forward` ν¨μμ μ€κ° λ¨κ³λ λΆμν μ μμ΅λλ€. μ΄ κ²½μ°μλ `detect_overflow` λμ°λ―Έ ν¨μλ₯Ό μ¬μ©νμ¬ μνλ μμΉμ κ°μ§κΈ°λ₯Ό μ½μ
ν μ μμ΅λλ€. μλ₯Ό λ€μ΄:
```python
from debug_utils import detect_overflow
class T5LayerFF(nn.Module):
[...]
def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states)
detect_overflow(forwarded_states, "after layer_norm")
forwarded_states = self.DenseReluDense(forwarded_states)
detect_overflow(forwarded_states, "after DenseReluDense")
return hidden_states + self.dropout(forwarded_states)
```
μ¬κΈ°μλ μ΄λ₯Ό μΆκ°νμ¬ 2κ°μ κ²μ μΆμ νκ³ μ΄μ `forwarded_states`μ `inf` λλ `nan`μ΄ μ€κ°μ κ°μ§λμλμ§λ₯Ό μΆμ ν©λλ€.
μ€μ λ‘ μμ μμ μμ κ° νΈμΆμ΄ `nn.Module`μ΄κΈ° λλ¬Έμ νμ§κΈ°κ° μ΄λ―Έ μ΄λ₯Ό λ³΄κ³ ν©λλ€. λ‘컬μμ μ§μ κ³μ°νλ κ²½μ° μ΄λ κ² μννλ€κ³ κ°μ ν΄ λ΄
μλ€.
λν, μ체 μ½λμμ λλ²κ±°λ₯Ό μΈμ€ν΄μ€ννλ κ²½μ° κΈ°λ³Έκ°μμ μΆλ ₯λλ νλ μ μλ₯Ό μ‘°μ ν μ μμ΅λλ€. μλ₯Ό λ€μ΄:
```python
from transformers.debug_utils import DebugUnderflowOverflow
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
```
### νΉμ λ°°μΉμ μ λκ° μ΅μ λ° μ΅λ κ° μΆμ [[specific-batch-absolute-min-and-max-value-tracing]]
λμΌν λλ²κΉ
ν΄λμ€λ μΈλνλ‘μ°/μ€λ²νλ‘μ° κ°μ§ κΈ°λ₯μ΄ κΊΌμ§ μνμμ λ°°μΉλ³ μΆμ μλ μ¬μ©ν μ μμ΅λλ€.
μλ₯Ό λ€μ΄, νΉμ λ°°μΉμ κ° `forward` νΈμΆμ λͺ¨λ κ΅¬μ± μ±λΆμ λν μ λ μ΅μκ°κ³Ό μ΅λκ°μ νμΈνκ³ , μ΄λ₯Ό λ°°μΉ 1κ³Ό 3μ λν΄μλ§ μννλ €λ©΄ λ€μκ³Ό κ°μ΄ μ΄ ν΄λμ€λ₯Ό μΈμ€ν΄μ€νν©λλ€:
```python
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
```
κ·Έλ¬λ©΄ μ΄μ λ°°μΉ 1κ³Ό 3 μ μ²΄κ° μΈλνλ‘μ°/μ€λ²νλ‘μ° κ°μ§κΈ°μ λμΌν νμμΌλ‘ μΆμ λ©λλ€.
λ°°μΉλ 0λΆν° μμν©λλ€.
μ΄λ νλ‘κ·Έλ¨μ΄ νΉμ λ°°μΉ λ²νΈ μ΄νμ μ€μλνκΈ° μμνλ κ²μ μκ³ μλ κ²½μ°μ μ μ©ν©λλ€. κ·Έλ κΈ° λλ¬Έμ ν΄λΉ μμμΌλ‘ λ°λ‘ μ΄λν μ μμ΅λλ€. μ΄λ° ꡬμ±μ λν μν μΆμλ μΆλ ₯μ λ€μκ³Ό κ°μ΅λλ€.
```
*** Starting batch number=1 ***
abs min abs max metadata
shared Embedding
1.01e-06 7.92e+02 weight
0.00e+00 2.47e+04 input[0]
5.36e-05 7.92e+02 output
[...]
decoder.dropout Dropout
1.60e-07 2.27e+01 input[0]
0.00e+00 2.52e+01 output
decoder T5Stack
not a tensor output
lm_head Linear
1.01e-06 7.92e+02 weight
0.00e+00 1.11e+00 input[0]
6.06e-02 8.39e+01 output
T5ForConditionalGeneration
not a tensor output
*** Starting batch number=3 ***
abs min abs max metadata
shared Embedding
1.01e-06 7.92e+02 weight
0.00e+00 2.78e+04 input[0]
5.36e-05 7.92e+02 output
[...]
```
μ¬κΈ°μμλ λͺ¨λΈμ forward νΈμΆ μμ λμΌν μμ νλ μμ΄ λ€νλλ―λ‘ λ§μ μμ νλ μμ΄ μμ±λ©λλ€. λ°λΌμ μνλ κ²μΌ μλ μκ³ μλ μλ μμ΅λλ€. κ·Έλ¬λ λλ‘λ μΌλ° λλ²κ±°λ³΄λ€ λλ²κΉ
λͺ©μ μΌλ‘ λ μ½κ² μ¬μ©ν μ μμ΅λλ€. μλ₯Ό λ€μ΄, λ¬Έμ κ° λ°°μΉ λ²νΈ 150μμ μμνλ κ²½μ° 149μ 150μ μΆμ μ λ€ννκ³ μ«μκ° μ΄λμλΆν° λ€λ₯΄κ² λμλμ§ λΉκ΅ν μ μμ΅λλ€.
λν, νλ ¨μ μ€μ§ν λ°°μΉ λ²νΈλ₯Ό μ§μ ν μλ μμ΅λλ€. λ€μκ³Ό κ°μ΄ μ§μ ν μ μμ΅λλ€.
```python
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
```
|