Minh Q. Le
Pushed COSMIC code
a446b0b
def update_classification_losses(losses, nums, name, bs, loss):
if not isinstance(loss, float):
print(type(loss))
raise
nums[name] += bs
losses[name] += loss * bs
def update_generation_losses(losses, nums, micro, macro, bs, length, loss):
# Update Losses
nums[macro] += bs
if isinstance(length, int):
update_indiv_generation_losses(
losses, nums, micro, macro, bs, length, loss)
else:
update_tensor_generation_losses(
losses, nums, micro, macro, bs, length, loss)
def update_indiv_generation_losses(losses, nums, micro,
macro, bs, length, loss):
nums[micro] += bs * length
batch_loss = loss * bs
losses[micro] += batch_loss
losses[macro] += batch_loss / length
def update_tensor_generation_losses(losses, nums, micro,
macro, bs, length, loss):
nums[micro] += length.sum().item()
losses[micro] += loss.sum().item()
losses[macro] += (loss / length.float()).sum().item()