Spaces:
Build error
Build error
File size: 1,064 Bytes
a446b0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
def update_classification_losses(losses, nums, name, bs, loss):
    """Accumulate a batch's classification loss into running totals.

    The loss is weighted by batch size so that ``losses[name] / nums[name]``
    yields the example-weighted mean loss over all batches seen so far.

    Args:
        losses: Mapping of running loss sums; ``losses[name]`` is incremented.
        nums: Mapping of running example counts; ``nums[name]`` is incremented.
        name: Key under which to accumulate.
        bs: Number of examples in the batch.
        loss: Batch loss as a Python ``float``. Presumably the caller has
            already converted any tensor to a scalar (e.g. via ``.item()``);
            TODO confirm against callers.

    Raises:
        TypeError: If ``loss`` is not a ``float``. Validation happens before
            either mapping is mutated.
    """
    if not isinstance(loss, float):
        # The original used a bare ``raise`` with no active exception, which
        # only produced "RuntimeError: No active exception to reraise".
        # Report the actual problem instead.
        raise TypeError(
            "update_classification_losses expects loss as a Python float, "
            f"got {type(loss).__name__}"
        )
    nums[name] += bs
    losses[name] += loss * bs
def update_generation_losses(losses, nums, micro, macro, bs, length, loss):
    """Accumulate a generation batch's loss into micro and macro totals.

    The macro count is always advanced by the batch size here. The remaining
    bookkeeping is delegated according to how ``length`` is given:

    * an ``int`` -> ``update_indiv_generation_losses`` (one shared length);
    * anything else -> ``update_tensor_generation_losses``. That helper
      calls ``.sum().item()`` and ``.float()`` on ``length``, so it
      presumably expects a torch tensor of per-sequence lengths (confirm
      with callers).

    Args:
        losses: Mapping of running loss sums, keyed by ``micro``/``macro``.
        nums: Mapping of running counts, keyed by ``micro``/``macro``.
        micro: Key for token-level (micro-averaged) accumulation.
        macro: Key for sequence-level (macro-averaged) accumulation.
        bs: Number of sequences in the batch.
        length: Sequence length(s); see above.
        loss: Batch loss in the form expected by the chosen helper.
    """
    nums[macro] += bs

    # Pick the helper that matches the representation of ``length``.
    if isinstance(length, int):
        accumulate = update_indiv_generation_losses
    else:
        accumulate = update_tensor_generation_losses
    accumulate(losses, nums, micro, macro, bs, length, loss)
def update_indiv_generation_losses(losses, nums, micro,
                                   macro, bs, length, loss):
    """Accumulate loss for a batch whose sequences share one scalar length.

    Updates performed:

    * ``nums[micro]``   += ``bs * length``         (total tokens)
    * ``losses[micro]`` += ``loss * bs``
    * ``losses[macro]`` += ``loss * bs / length``

    ``nums[macro]`` is not touched here. ``length`` must be nonzero,
    otherwise the division raises ``ZeroDivisionError``.

    Note (review): these updates are consistent with ``loss`` being a
    batch-mean of per-sequence *summed* token losses. That is an inference
    from the arithmetic; confirm with callers.
    """
    token_total = bs * length
    nums[micro] += token_total

    summed_loss = loss * bs
    losses[micro] += summed_loss
    # Normalise by length to get the sequence-level (macro) contribution.
    losses[macro] += summed_loss / length
def update_tensor_generation_losses(losses, nums, micro,
                                    macro, bs, length, loss):
    """Accumulate loss for a batch given per-sequence lengths and losses.

    ``length`` and ``loss`` are used through ``.sum().item()`` and
    ``.float()``, so they are presumably torch tensors with one entry per
    sequence (confirm with callers). Updates performed:

    * ``nums[micro]``   += total of ``length``
    * ``losses[micro]`` += total of ``loss``
    * ``losses[macro]`` += total of ``loss / length`` (element-wise)

    ``bs`` is accepted for signature parity with
    ``update_indiv_generation_losses`` but is unused, and ``nums[macro]``
    is not touched here.
    """
    token_total = length.sum().item()
    nums[micro] += token_total

    loss_total = loss.sum().item()
    losses[micro] += loss_total

    # Per-sequence length-normalised loss, summed over the batch.
    normalized = loss / length.float()
    losses[macro] += normalized.sum().item()
|