Spaces:
Build error
Build error
| import torch | |
| import torch.optim | |
| import torch.nn.functional as F | |
| import copy | |
def update_generation_losses(losses, nums, micro, macro, bs, length, loss):
    """Append a new running-average entry for the ``micro`` and ``macro`` losses.

    ``losses[micro]`` and ``losses[macro]`` are lists whose last element is
    the current running average. ``nums[micro]`` and ``nums[macro]`` hold the
    counts those averages were taken over.

    This function does the following, in order:

    1. Appends a copy of each current average.
    2. Turns each copy back into a running sum by multiplying it by its count.
    3. Adds ``bs`` to ``nums[macro]``.
    4. Delegates to a helper, which adds this batch's contribution and
       re-normalises both entries.

    Args:
        losses: mapping from key to a list of running averages. Both
            ``losses[micro]`` and ``losses[macro]`` must be non-empty.
        nums: mapping from key to the count each average is taken over.
        micro: key of the average normalised by ``nums[micro]``, which grows
            by ``bs * length`` per batch (presumably a token count).
        macro: key of the average normalised by ``nums[macro]``, which grows
            by ``bs`` per batch (presumably an example count).
        bs: batch size (presumably); added to ``nums[macro]``.
        length: a scalar length shared by the whole batch, or a tensor of
            per-example lengths.
        loss: the batch loss. For a scalar ``length`` it is scaled by ``bs``.
            For a tensor ``length`` it is summed element-wise, presumably
            per-example summed losses (confirm against the caller).
    """
    # Start a new entry from the current average.
    losses[micro].append(copy.deepcopy(losses[micro][-1]))
    losses[macro].append(copy.deepcopy(losses[macro][-1]))
    # Convert the new entries from averages back to running sums.
    losses[micro][-1] *= nums[micro]
    losses[macro][-1] *= nums[macro]
    nums[macro] += bs
    # Only real tensors take the tensor path. Every other value uses the
    # scalar arithmetic path: ints (as before), plus Python floats and NumPy
    # scalars. Previously those crashed in the tensor helper (no
    # ``.sum()`` / ``.float()``) after the state above had been mutated.
    if torch.is_tensor(length):
        update_tensor_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        update_indiv_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
def update_indiv_generation_losses(losses, nums, micro,
                                   macro, bs, length, loss):
    """Fold one batch with a single shared ``length`` into both averages.

    Called from ``update_generation_losses`` after it has turned
    ``losses[micro][-1]`` and ``losses[macro][-1]`` into running sums and
    updated ``nums[macro]``. This helper does three things:

    1. Adds ``bs * length`` to ``nums[micro]``.
    2. Adds ``loss * bs`` to the micro sum and divides by ``nums[micro]``.
    3. Adds ``loss * bs / length`` to the macro sum and divides by
       ``nums[macro]``.

    Args:
        losses: mapping from key to a list of running values; only the last
            element of ``losses[micro]`` and ``losses[macro]`` is touched.
        nums: mapping from key to count; ``nums[micro]`` is incremented here.
        micro: key for the average normalised by ``nums[micro]``.
        macro: key for the average normalised by ``nums[macro]``.
        bs: multiplier applied to ``loss`` (presumably the batch size).
        length: scalar length shared by every example in the batch.
        loss: scalar loss for the batch.
    """
    micro_hist = losses[micro]
    macro_hist = losses[macro]

    nums[micro] += bs * length
    # Scale the batch loss back up to cover every example.
    batch_total = loss * bs

    # Augmented assignment keeps the original in-place semantics.
    micro_hist[-1] += batch_total
    micro_hist[-1] /= nums[micro]

    macro_hist[-1] += batch_total / length
    macro_hist[-1] /= nums[macro]
def update_tensor_generation_losses(losses, nums, micro,
                                    macro, bs, length, loss):
    """Fold one batch with per-example lengths into both averages.

    Called from ``update_generation_losses`` after it has turned
    ``losses[micro][-1]`` and ``losses[macro][-1]`` into running sums and
    updated ``nums[macro]``. All tensor reductions are converted to Python
    numbers with ``.item()`` before being accumulated.

    This helper does three things:

    1. Adds ``length.sum()`` to ``nums[micro]``.
    2. Adds ``loss.sum()`` to the micro sum and divides by ``nums[micro]``.
    3. Adds ``sum(loss / length)`` to the macro sum and divides by
       ``nums[macro]``.

    Args:
        losses: mapping from key to a list of running values; only the last
            element of ``losses[micro]`` and ``losses[macro]`` is touched.
        nums: mapping from key to count; ``nums[micro]`` is incremented here.
        micro: key for the average normalised by ``nums[micro]``.
        macro: key for the average normalised by ``nums[macro]``.
        bs: unused here; kept for a signature parallel to the scalar helper.
        length: tensor of lengths, element-wise aligned with ``loss``
            (presumably one entry per example).
        loss: tensor of losses with the same shape as ``length``.
    """
    nums[micro] += length.sum().item()

    micro_hist = losses[micro]
    micro_hist[-1] += loss.sum().item()
    micro_hist[-1] /= nums[micro]

    # Divide each loss by its own length before summing (a per-example mean).
    per_example = loss / length.float()
    macro_hist = losses[macro]
    macro_hist[-1] += per_example.sum().item()
    macro_hist[-1] /= nums[macro]
def modify_output_for_loss_fn(loss_fn, output, dim):
    """Transform raw ``output`` into the form the named loss expects.

    The transform depends on ``loss_fn``:

    - ``"ce"``: ``output`` is returned unchanged.
    - ``"mse"``: softmax over ``dim``.
    - ``"nll"``: log-softmax over ``dim``.
    - ``"bce"``, ``"wbce"``, ``"wbce1"``: element-wise sigmoid; ``dim`` is
      ignored.

    Any other ``loss_fn`` matches nothing, and the function returns ``None``.
    """
    transforms = (
        (("ce",), lambda out: out),
        (("mse",), lambda out: F.softmax(out, dim=dim)),
        (("nll",), lambda out: F.log_softmax(out, dim=dim)),
        (("bce", "wbce", "wbce1"), torch.sigmoid),
    )
    for names, transform in transforms:
        if loss_fn in names:
            return transform(output)
    return None