aashraychegu commited on
Commit
66b7238
·
1 Parent(s): 9fbbef8

Upload semanticallysegmentdeezglaciers.ipynb

Browse files
semanticallysegmentdeezglaciers.ipynb CHANGED
@@ -333,8 +333,11 @@
333
  "def compute_metrics(eval_pred):\n",
334
  " # Ensure that gradient computation is turned off, as it is not needed for evaluation\n",
335
  " with torch.no_grad():\n",
 
 
336
  " logits, labels = eval_pred\n",
337
  " logits_tensor = torch.from_numpy(logits)\n",
 
338
  " logits_tensor = nn.functional.interpolate(\n",
339
  " logits_tensor,\n",
340
  " size=labels.shape[-2:],\n",
@@ -343,8 +346,10 @@
343
  " )\n",
344
  " # Take the argmax of the logits tensor along dimension 1 to get the predicted labels\n",
345
  " logits_tensor = logits_tensor.argmax(dim=1)\n",
346
- " # Detach the predicted labels from the computation graph and move them to the CPU to save memory and to use numpy features\n",
 
347
  " pred_labels = logits_tensor.detach().cpu().numpy()\n",
 
348
  " metrics = metric.compute(\n",
349
  " predictions=pred_labels,\n",
350
  " references=labels,\n",
 
333
  "def compute_metrics(eval_pred):\n",
334
  " # Ensure that gradient computation is turned off, as it is not needed for evaluation\n",
335
  " with torch.no_grad():\n",
336
+ " # This computes the final logits tensor by interpolating the output logits to the size of the labels tensor from an input of size (batch_size, num_labels, height, width)\n",
337
+ " # This is input that has gone through the model's forward pass\n",
338
  " logits, labels = eval_pred\n",
339
  " logits_tensor = torch.from_numpy(logits)\n",
340
+ " # this can lead to very high ram usage for the upscaling\n",
341
  " logits_tensor = nn.functional.interpolate(\n",
342
  " logits_tensor,\n",
343
  " size=labels.shape[-2:],\n",
 
346
  " )\n",
347
  " # Take the argmax of the logits tensor along dimension 1 to get the predicted labels\n",
348
  " logits_tensor = logits_tensor.argmax(dim=1)\n",
349
+ " # Detach the predicted labels from the computation graph and move them to the CPU \n",
350
+ " # (although they are already on the CPU) to save memory and to use numpy features like the metrics module\n",
351
  " pred_labels = logits_tensor.detach().cpu().numpy()\n",
352
+ " # Computes metrics\n",
353
  " metrics = metric.compute(\n",
354
  " predictions=pred_labels,\n",
355
  " references=labels,\n",