Spaces:
Build error
Build error
File size: 1,641 Bytes
a446b0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch
import torch.optim
import torch.nn.functional as F
import copy
def update_generation_losses(losses, nums, micro, macro, bs, length, loss):
    """Append a new running micro/macro loss entry that includes this batch.

    The latest entry of ``losses[micro]`` and ``losses[macro]`` is duplicated.
    The copy is scaled back up by the corresponding count in ``nums``, and then
    this batch's contribution is folded in by
    ``update_indiv_generation_losses`` (scalar ``length``) or
    ``update_tensor_generation_losses`` (per-sequence ``length`` tensor).
    Both of those divide by the updated counts, so the new entry ends up
    holding the updated running average.

    Args:
        losses: Mapping of key -> list of loss values. A new last element is
            appended to ``losses[micro]`` and ``losses[macro]``; earlier
            elements are left untouched.
        nums: Mapping of key -> running count. ``nums[macro]`` is increased by
            ``bs`` here, and ``nums[micro]`` is increased by the helper.
        micro: Key of the micro-averaged (per-token) entries.
        macro: Key of the macro-averaged (per-sequence) entries.
        bs: Number of sequences in the batch.
        length: Either a scalar length shared by every sequence in the batch,
            or (presumably) a tensor of per-sequence lengths. Confirm the
            shape against the caller.
        loss: Batch loss. It is a scalar when ``length`` is a scalar, and a
            per-sequence tensor otherwise (assumed from how the two helpers
            consume it).
    """
    import numbers

    # Start a fresh history entry from a copy of the previous average, so the
    # in-place updates below never alias or overwrite the earlier entry.
    losses[micro] += [copy.deepcopy(losses[micro][-1])]
    losses[macro] += [copy.deepcopy(losses[macro][-1])]

    # Convert the copied averages back to running sums. The helpers add this
    # batch's total and then re-divide by the updated counts.
    losses[micro][-1] *= nums[micro]
    losses[macro][-1] *= nums[macro]

    nums[macro] += bs

    # Any plain scalar (int, float, numpy scalar) means "every sequence has
    # this length". Previously only `int` took this path, so floats and numpy
    # scalars fell into the tensor helper and crashed on `.sum()`/`.float()`.
    if isinstance(length, numbers.Number):
        update_indiv_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        update_tensor_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
def update_indiv_generation_losses(losses, nums, micro,
                                   macro, bs, length, loss):
    """Fold a batch whose sequences all share one scalar length into the
    running micro/macro averages.

    When called from ``update_generation_losses``, the last entries of
    ``losses[micro]`` and ``losses[macro]`` hold running sums (already
    multiplied by the previous counts), and ``nums[macro]`` already includes
    ``bs``. This function adds the batch totals and divides each entry by its
    count, which turns the entries back into averages.

    Args:
        losses: Mapping of key -> list of loss values. Only the last element
            of ``losses[micro]`` and ``losses[macro]`` is modified.
        nums: Mapping of key -> running count. ``nums[micro]`` is increased by
            ``bs * length``; ``nums[macro]`` is only read.
        micro: Key of the micro-averaged (per-token) entry.
        macro: Key of the macro-averaged (per-sequence) entry.
        bs: Number of sequences in the batch.
        length: Length shared by every sequence in the batch.
        loss: Batch loss scalar. It is multiplied by ``bs`` to form the batch
            total (presumably the batch mean of per-sequence summed losses;
            confirm against the caller).
    """
    token_count = bs * length
    nums[micro] += token_count

    batch_total = loss * bs
    micro_history = losses[micro]
    macro_history = losses[macro]

    # Micro average: total loss over every token seen / total token count.
    micro_history[-1] += batch_total
    micro_history[-1] /= nums[micro]

    # Macro average: each sequence contributes its per-token loss, and the
    # sum is then divided by the number of sequences seen.
    macro_history[-1] += batch_total / length
    macro_history[-1] /= nums[macro]
def update_tensor_generation_losses(losses, nums, micro,
                                    macro, bs, length, loss):
    """Fold a batch with per-sequence lengths into the running micro/macro
    averages.

    When called from ``update_generation_losses``, the last entries of
    ``losses[micro]`` and ``losses[macro]`` hold running sums, and
    ``nums[macro]`` already includes the batch size. The batch totals are
    added and each entry is divided by its count.

    Args:
        losses: Mapping of key -> list of loss values. Only the last element
            of ``losses[micro]`` and ``losses[macro]`` is modified.
        nums: Mapping of key -> running count. ``nums[micro]`` is increased by
            the summed lengths; ``nums[macro]`` is only read.
        micro: Key of the micro-averaged (per-token) entry.
        macro: Key of the macro-averaged (per-sequence) entry.
        bs: Not used here; accepted so both update helpers share one
            signature.
        length: Tensor of sequence lengths (presumably one per sequence;
            confirm against the caller).
        loss: Tensor of losses aligned element-wise with ``length``
            (presumably summed over each sequence's tokens).
    """
    token_total = length.sum().item()
    loss_total = loss.sum().item()
    # Sum of each sequence's per-token loss; `.float()` keeps the division
    # from operating on an integer lengths tensor.
    per_token_total = (loss / length.float()).sum().item()

    nums[micro] += token_total

    micro_history = losses[micro]
    micro_history[-1] += loss_total
    micro_history[-1] /= nums[micro]

    macro_history = losses[macro]
    macro_history[-1] += per_token_total
    macro_history[-1] /= nums[macro]
def modify_output_for_loss_fn(loss_fn, output, dim):
    """Apply the output transform that matches the named loss function.

    Args:
        loss_fn: Loss-function name.
            - ``"ce"`` returns ``output`` unchanged (the same object).
            - ``"mse"`` applies softmax over ``dim``.
            - ``"nll"`` applies log-softmax over ``dim``.
            - ``"bce"``, ``"wbce"`` and ``"wbce1"`` apply an element-wise
              sigmoid, and ``dim`` is ignored.
        output: Tensor of raw scores (presumably logits).
        dim: Dimension passed to softmax / log-softmax.

    Returns:
        The transformed tensor, or ``None`` when ``loss_fn`` matches none of
        the names above.
    """
    transforms = (
        (("ce",), lambda scores: scores),
        (("mse",), lambda scores: F.softmax(scores, dim=dim)),
        (("nll",), lambda scores: F.log_softmax(scores, dim=dim)),
        (("bce", "wbce", "wbce1"), torch.sigmoid),
    )
    for names, transform in transforms:
        if loss_fn in names:
            return transform(output)
    # Unrecognised loss names fall through without transforming the output.
    return None
|