Commit
·
66b7238
1
Parent(s):
9fbbef8
Upload semanticallysegmentdeezglaciers.ipynb
Browse files
semanticallysegmentdeezglaciers.ipynb
CHANGED
@@ -333,8 +333,11 @@
|
|
333 |
"def compute_metrics(eval_pred):\n",
|
334 |
" # Ensure that gradient computation is turned off, as it is not needed for evaluation\n",
|
335 |
" with torch.no_grad():\n",
|
|
|
|
|
336 |
" logits, labels = eval_pred\n",
|
337 |
" logits_tensor = torch.from_numpy(logits)\n",
|
|
|
338 |
" logits_tensor = nn.functional.interpolate(\n",
|
339 |
" logits_tensor,\n",
|
340 |
" size=labels.shape[-2:],\n",
|
@@ -343,8 +346,10 @@
|
|
343 |
" )\n",
|
344 |
" # Take the argmax of the logits tensor along dimension 1 to get the predicted labels\n",
|
345 |
" logits_tensor = logits_tensor.argmax(dim=1)\n",
|
346 |
-
" # Detach the predicted labels from the computation graph and move them to the CPU
|
|
|
347 |
" pred_labels = logits_tensor.detach().cpu().numpy()\n",
|
|
|
348 |
" metrics = metric.compute(\n",
|
349 |
" predictions=pred_labels,\n",
|
350 |
" references=labels,\n",
|
|
|
333 |
"def compute_metrics(eval_pred):\n",
|
334 |
" # Ensure that gradient computation is turned off, as it is not needed for evaluation\n",
|
335 |
" with torch.no_grad():\n",
|
336 |
+
" # This computes the final logits tensor by interpolating the output logits to the size of the labels tensor from an input of size (batch_size, num_labels, height, width)\n",
|
337 |
+
" # This is input that has gone through the model's forward pass\n",
|
338 |
" logits, labels = eval_pred\n",
|
339 |
" logits_tensor = torch.from_numpy(logits)\n",
|
340 |
+
" # this can lead to very high ram usage for the upscaling\n",
|
341 |
" logits_tensor = nn.functional.interpolate(\n",
|
342 |
" logits_tensor,\n",
|
343 |
" size=labels.shape[-2:],\n",
|
|
|
346 |
" )\n",
|
347 |
" # Take the argmax of the logits tensor along dimension 1 to get the predicted labels\n",
|
348 |
" logits_tensor = logits_tensor.argmax(dim=1)\n",
|
349 |
+
" # Detach the predicted labels from the computation graph and move them to the CPU \n",
|
350 |
+
" # (although they are already on the CPU) to save memory and to use numpy features like the metrics module\n",
|
351 |
" pred_labels = logits_tensor.detach().cpu().numpy()\n",
|
352 |
+
" # Computes metrics\n",
|
353 |
" metrics = metric.compute(\n",
|
354 |
" predictions=pred_labels,\n",
|
355 |
" references=labels,\n",
|