File size: 53,681 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 |
{{py:
"""
Template file to easily generate loops over samples using Tempita
(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).
Generated file: _loss.pyx
Each loss class is generated by a cdef functions on single samples.
The keywords between double braces are substituted during the build.
"""
doc_HalfSquaredError = (
"""Half Squared Error with identity link.
Domain:
y_true and y_pred all real numbers
Link:
y_pred = raw_prediction
"""
)
doc_AbsoluteError = (
"""Absolute Error with identity link.
Domain:
y_true and y_pred all real numbers
Link:
y_pred = raw_prediction
"""
)
doc_PinballLoss = (
"""Quantile Loss aka Pinball Loss with identity link.
Domain:
y_true and y_pred all real numbers
quantile in (0, 1)
Link:
y_pred = raw_prediction
Note: 2 * cPinballLoss(quantile=0.5) equals cAbsoluteError()
"""
)
doc_HuberLoss = (
"""Huber Loss with identity link.
Domain:
y_true and y_pred all real numbers
delta in positive real numbers
Link:
y_pred = raw_prediction
"""
)
doc_HalfPoissonLoss = (
"""Half Poisson deviance loss with log-link.
Domain:
y_true in non-negative real numbers
y_pred in positive real numbers
Link:
y_pred = exp(raw_prediction)
Half Poisson deviance with log-link is
y_true * log(y_true/y_pred) + y_pred - y_true
= y_true * log(y_true) - y_true * raw_prediction
+ exp(raw_prediction) - y_true
Dropping constant terms, this gives:
exp(raw_prediction) - y_true * raw_prediction
"""
)
doc_HalfGammaLoss = (
"""Half Gamma deviance loss with log-link.
Domain:
y_true and y_pred in positive real numbers
Link:
y_pred = exp(raw_prediction)
Half Gamma deviance with log-link is
log(y_pred/y_true) + y_true/y_pred - 1
= raw_prediction - log(y_true) + y_true * exp(-raw_prediction) - 1
Dropping constant terms, this gives:
raw_prediction + y_true * exp(-raw_prediction)
"""
)
doc_HalfTweedieLoss = (
"""Half Tweedie deviance loss with log-link.
Domain:
y_true in real numbers if p <= 0
y_true in non-negative real numbers if 0 < p < 2
y_true in positive real numbers if p >= 2
y_pred and power in positive real numbers
Link:
y_pred = exp(raw_prediction)
Half Tweedie deviance with log-link and p=power is
max(y_true, 0)**(2-p) / (1-p) / (2-p)
- y_true * y_pred**(1-p) / (1-p)
+ y_pred**(2-p) / (2-p)
= max(y_true, 0)**(2-p) / (1-p) / (2-p)
- y_true * exp((1-p) * raw_prediction) / (1-p)
+ exp((2-p) * raw_prediction) / (2-p)
Dropping constant terms, this gives:
exp((2-p) * raw_prediction) / (2-p)
- y_true * exp((1-p) * raw_prediction) / (1-p)
Notes:
- Poisson with p=1 and and Gamma with p=2 have different terms dropped such
that cHalfTweedieLoss is not continuous in p=power at p=1 and p=2.
- While the Tweedie distribution only exists for p<=0 or p>=1, the range
0<p<1 still gives a strictly consistent scoring function for the
expectation.
"""
)
doc_HalfTweedieLossIdentity = (
"""Half Tweedie deviance loss with identity link.
Domain:
y_true in real numbers if p <= 0
y_true in non-negative real numbers if 0 < p < 2
y_true in positive real numbers if p >= 2
y_pred and power in positive real numbers, y_pred may be negative for p=0.
Link:
y_pred = raw_prediction
Half Tweedie deviance with identity link and p=power is
max(y_true, 0)**(2-p) / (1-p) / (2-p)
- y_true * y_pred**(1-p) / (1-p)
+ y_pred**(2-p) / (2-p)
Notes:
- Here, we do not drop constant terms in contrast to the version with log-link.
"""
)
doc_HalfBinomialLoss = (
"""Half Binomial deviance loss with logit link.
Domain:
y_true in [0, 1]
y_pred in (0, 1), i.e. boundaries excluded
Link:
y_pred = expit(raw_prediction)
"""
)
doc_ExponentialLoss = (
""""Exponential loss with (half) logit link
Domain:
y_true in [0, 1]
y_pred in (0, 1), i.e. boundaries excluded
Link:
y_pred = expit(2 * raw_prediction)
"""
)
# loss class name, docstring, param,
# cy_loss, cy_loss_grad,
# cy_grad, cy_grad_hess,
class_list = [
("CyHalfSquaredError", doc_HalfSquaredError, None,
"closs_half_squared_error", None,
"cgradient_half_squared_error", "cgrad_hess_half_squared_error"),
("CyAbsoluteError", doc_AbsoluteError, None,
"closs_absolute_error", None,
"cgradient_absolute_error", "cgrad_hess_absolute_error"),
("CyPinballLoss", doc_PinballLoss, "quantile",
"closs_pinball_loss", None,
"cgradient_pinball_loss", "cgrad_hess_pinball_loss"),
("CyHuberLoss", doc_HuberLoss, "delta",
"closs_huber_loss", None,
"cgradient_huber_loss", "cgrad_hess_huber_loss"),
("CyHalfPoissonLoss", doc_HalfPoissonLoss, None,
"closs_half_poisson", "closs_grad_half_poisson",
"cgradient_half_poisson", "cgrad_hess_half_poisson"),
("CyHalfGammaLoss", doc_HalfGammaLoss, None,
"closs_half_gamma", "closs_grad_half_gamma",
"cgradient_half_gamma", "cgrad_hess_half_gamma"),
("CyHalfTweedieLoss", doc_HalfTweedieLoss, "power",
"closs_half_tweedie", "closs_grad_half_tweedie",
"cgradient_half_tweedie", "cgrad_hess_half_tweedie"),
("CyHalfTweedieLossIdentity", doc_HalfTweedieLossIdentity, "power",
"closs_half_tweedie_identity", "closs_grad_half_tweedie_identity",
"cgradient_half_tweedie_identity", "cgrad_hess_half_tweedie_identity"),
("CyHalfBinomialLoss", doc_HalfBinomialLoss, None,
"closs_half_binomial", "closs_grad_half_binomial",
"cgradient_half_binomial", "cgrad_hess_half_binomial"),
("CyExponentialLoss", doc_ExponentialLoss, None,
"closs_exponential", "closs_grad_exponential",
"cgradient_exponential", "cgrad_hess_exponential"),
]
}}
# Design:
# See https://github.com/scikit-learn/scikit-learn/issues/15123 for reasons.
# a) Merge link functions into loss functions for speed and numerical
# stability, i.e. use raw_prediction instead of y_pred in signature.
# b) Pure C functions (nogil) calculate single points (single sample)
# c) Wrap C functions in a loop to get Python functions operating on ndarrays.
# - Write loops manually---use Tempita for this.
# Reason: There is still some performance overhead when using a wrapper
# function "wrap" that carries out the loop and gets as argument a function
# pointer to one of the C functions from b), e.g.
# wrap(closs_half_poisson, y_true, ...)
# - Pass n_threads as argument to prange and propagate option to all callers.
# d) Provide classes (Cython extension types) per loss (names start with Cy) in
#    order to have semantically structured objects.
# - Member functions for single points just call the C function from b).
# These are used e.g. in SGD `_plain_sgd`.
# - Member functions operating on ndarrays, see c), looping over calls to C
# functions from b).
# e) Provide convenience Python classes that compose from these extension types
# elsewhere (see loss.py)
# - Example: loss.gradient calls CyLoss.gradient but does some input
# checking like None -> np.empty().
#
# Note: We require 1-dim ndarrays to be contiguous.
from cython.parallel import parallel, prange
import numpy as np
from libc.math cimport exp, fabs, log, log1p, pow
from libc.stdlib cimport malloc, free
# -------------------------------------
# Helper functions
# -------------------------------------
# log(1 + exp(x)) evaluated stably in double precision, after Eq. (10) of
# https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
# Accuracy only depends on the cutoff at x = 18; the other cutoffs exist to skip
# work. On top of the reference we split off x <= -2 so that the middle range
# (-2, 18] can use the cheaper log instead of log1p. Every branch stays accurate
# to double machine precision.
cdef inline double log1pexp(double x) noexcept nogil:
    # Cutoffs are tested from small to large x.
    if x <= -37:
        # exp(x) is far below machine epsilon, so log1p(exp(x)) == exp(x).
        return exp(x)
    if x <= -2:
        return log1p(exp(x))
    if x <= 18:
        # 1 + exp(x) is not close to 1 here, so plain log loses no precision.
        return log(1. + exp(x))
    if x <= 33.3:
        # log(1 + exp(x)) = x + log1p(exp(-x)) and log1p(exp(-x)) ~ exp(-x).
        return x + exp(-x)
    # The correction exp(-x) is negligible relative to x in double precision.
    return x
cdef inline double_pair sum_exp_minus_max(
    const int i,
    const floating_in[:, :] raw_prediction,  # IN
    floating_out *p                          # OUT
) noexcept nogil:
    # For row i of raw_prediction, fill the caller-provided buffer p (which must
    # have length n_classes) with
    #     p[k] = exp(raw_prediction[i, k] - row_max)
    # and return
    #     val1 = row_max = max over k of raw_prediction[i, k]
    #     val2 = sum over k of p[k]  (sum of the stored exponentials)
    # Subtracting the row maximum keeps all exponents <= 0 for finite inputs.
    # p is intentionally not divided by the sum; skipping the normalization
    # saves one loop over k.
    # i is passed in (and kept constant) because otherwise Cython does not
    # generate optimal code, see
    # https://github.com/scikit-learn/scikit-learn/issues/17299
    cdef:
        int n_classes = raw_prediction.shape[1]
        int col
        double row_max = raw_prediction[i, 0]
        double total = 0
        double_pair result
    # Pass 1: maximum of the row.
    for col in range(1, n_classes):
        if raw_prediction[i, col] > row_max:
            row_max = raw_prediction[i, col]
    # Pass 2: shifted exponentials and their running sum (summing the stored values).
    for col in range(n_classes):
        p[col] = exp(raw_prediction[i, col] - row_max)
        total += p[col]
    result.val1 = row_max
    result.val2 = total
    return result
# -------------------------------------
# Single point inline C functions
# -------------------------------------
# Half Squared Error
# loss = 0.5 * (raw_prediction - y_true)**2, gradient = residual, hessian = 1
cdef inline double closs_half_squared_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double residual = raw_prediction - y_true
    return 0.5 * residual * residual
cdef inline double cgradient_half_squared_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double residual = raw_prediction - y_true
    return residual
cdef inline double_pair cgrad_hess_half_squared_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double_pair result
    result.val1 = raw_prediction - y_true  # gradient
    result.val2 = 1.                       # hessian (constant)
    return result
# Absolute Error
# loss = |raw_prediction - y_true|; gradient = sign(raw_prediction - y_true),
# where the tie raw_prediction == y_true (and nan) maps to -1.
cdef inline double closs_absolute_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    return fabs(y_true - raw_prediction)
cdef inline double cgradient_absolute_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    if raw_prediction > y_true:
        return 1.
    return -1.
cdef inline double_pair cgrad_hess_absolute_error(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double_pair result
    if raw_prediction > y_true:
        result.val1 = 1.   # gradient
    else:
        result.val1 = -1.  # gradient
    # The exact hessian is 0 almost everywhere, but optimizers such as the one in
    # HGBT need a strictly positive hessian, so the constant 1 is used instead.
    result.val2 = 1.
    return result
# Quantile Loss / Pinball Loss
# loss = quantile * (y_true - raw_prediction)        if y_true >= raw_prediction
#        (1 - quantile) * (raw_prediction - y_true)  otherwise
cdef inline double closs_pinball_loss(
    double y_true,
    double raw_prediction,
    double quantile
) noexcept nogil:
    if y_true >= raw_prediction:
        return quantile * (y_true - raw_prediction)
    return (1. - quantile) * (raw_prediction - y_true)
cdef inline double cgradient_pinball_loss(
    double y_true,
    double raw_prediction,
    double quantile
) noexcept nogil:
    if y_true >= raw_prediction:
        return -quantile
    return 1. - quantile
cdef inline double_pair cgrad_hess_pinball_loss(
    double y_true,
    double raw_prediction,
    double quantile
) noexcept nogil:
    cdef double_pair result
    if y_true >= raw_prediction:
        result.val1 = -quantile      # gradient
    else:
        result.val1 = 1. - quantile  # gradient
    # The exact hessian is 0 almost everywhere, but optimizers such as the one in
    # HGBT need a strictly positive hessian, so the constant 1 is used instead.
    result.val2 = 1.
    return result
# Huber Loss
# Quadratic for |residual| <= delta, linear with slope delta beyond.
cdef inline double closs_huber_loss(
    double y_true,
    double raw_prediction,
    double delta
) noexcept nogil:
    cdef double abs_residual = fabs(y_true - raw_prediction)
    if abs_residual <= delta:
        return 0.5 * abs_residual**2
    return delta * (abs_residual - 0.5 * delta)
cdef inline double cgradient_huber_loss(
    double y_true,
    double raw_prediction,
    double delta
) noexcept nogil:
    cdef double residual = raw_prediction - y_true
    if fabs(residual) <= delta:
        return residual
    # Outside the quadratic zone the gradient is clipped to +/- delta.
    if residual >= 0:
        return delta
    return -delta
cdef inline double_pair cgrad_hess_huber_loss(
    double y_true,
    double raw_prediction,
    double delta
) noexcept nogil:
    cdef double residual = raw_prediction - y_true
    cdef double_pair result
    if fabs(residual) <= delta:
        result.val1 = residual  # gradient
        result.val2 = 1         # hessian
    else:
        result.val1 = delta if residual >= 0 else -delta  # gradient
        result.val2 = 0                                   # hessian
    return result
# Half Poisson Deviance with Log-Link, dropping constant terms
# With y_pred = exp(raw_prediction):
#   loss = y_pred - y_true * raw_prediction, gradient = y_pred - y_true,
#   hessian = y_pred
cdef inline double closs_half_poisson(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double y_pred = exp(raw_prediction)
    return y_pred - y_true * raw_prediction
cdef inline double cgradient_half_poisson(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double y_pred = exp(raw_prediction)
    return y_pred - y_true
cdef inline double_pair closs_grad_half_poisson(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double y_pred = exp(raw_prediction)
    cdef double_pair result
    result.val1 = y_pred - y_true * raw_prediction  # loss
    result.val2 = y_pred - y_true                   # gradient
    return result
cdef inline double_pair cgrad_hess_half_poisson(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double y_pred = exp(raw_prediction)
    cdef double_pair result
    result.val1 = y_pred - y_true  # gradient
    result.val2 = y_pred           # hessian
    return result
# Half Gamma Deviance with Log-Link, dropping constant terms
# With inv = exp(-raw_prediction) = 1 / y_pred:
#   loss = raw_prediction + y_true * inv, gradient = 1 - y_true * inv,
#   hessian = y_true * inv
cdef inline double closs_half_gamma(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double inv_y_pred = exp(-raw_prediction)
    return raw_prediction + y_true * inv_y_pred
cdef inline double cgradient_half_gamma(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double inv_y_pred = exp(-raw_prediction)
    return 1. - y_true * inv_y_pred
cdef inline double_pair closs_grad_half_gamma(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double inv_y_pred = exp(-raw_prediction)
    cdef double_pair result
    result.val1 = raw_prediction + y_true * inv_y_pred  # loss
    result.val2 = 1. - y_true * inv_y_pred              # gradient
    return result
cdef inline double_pair cgrad_hess_half_gamma(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double inv_y_pred = exp(-raw_prediction)
    cdef double_pair result
    result.val1 = 1. - y_true * inv_y_pred  # gradient
    result.val2 = y_true * inv_y_pred       # hessian
    return result
# Half Tweedie Deviance with Log-Link, dropping constant terms
# Note that by dropping constants this is no longer continuous in parameter power.
# power == 0 evaluates half squared error on y_pred = exp(raw_prediction);
# power == 1 and power == 2 delegate to the Poisson and Gamma functions.
# For any other p = power, with e1 = exp((1-p) * raw) and e2 = exp((2-p) * raw):
#   loss     = e2 / (2-p) - y_true * e1 / (1-p)
#   gradient = e2 - y_true * e1
#   hessian  = (2-p) * e2 - (1-p) * y_true * e1
cdef inline double closs_half_tweedie(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    cdef double e1, e2
    if power == 0.:
        return closs_half_squared_error(y_true, exp(raw_prediction))
    if power == 1.:
        return closs_half_poisson(y_true, raw_prediction)
    if power == 2.:
        return closs_half_gamma(y_true, raw_prediction)
    e2 = exp((2. - power) * raw_prediction)
    e1 = exp((1. - power) * raw_prediction)
    return e2 / (2. - power) - y_true * e1 / (1. - power)
cdef inline double cgradient_half_tweedie(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    cdef double y_pred
    if power == 0.:
        # d/draw of 0.5 * (exp(raw) - y_true)**2
        y_pred = exp(raw_prediction)
        return y_pred * (y_pred - y_true)
    if power == 1.:
        return cgradient_half_poisson(y_true, raw_prediction)
    if power == 2.:
        return cgradient_half_gamma(y_true, raw_prediction)
    return (exp((2. - power) * raw_prediction)
            - y_true * exp((1. - power) * raw_prediction))
cdef inline double_pair closs_grad_half_tweedie(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    cdef double_pair result
    cdef double e1, e2
    if power == 0.:
        e1 = exp(raw_prediction)  # y_pred
        result.val1 = closs_half_squared_error(y_true, e1)  # loss
        result.val2 = e1 * (e1 - y_true)                    # gradient
        return result
    if power == 1.:
        return closs_grad_half_poisson(y_true, raw_prediction)
    if power == 2.:
        return closs_grad_half_gamma(y_true, raw_prediction)
    e1 = exp((1. - power) * raw_prediction)
    e2 = exp((2. - power) * raw_prediction)
    result.val1 = e2 / (2. - power) - y_true * e1 / (1. - power)  # loss
    result.val2 = e2 - y_true * e1                                # gradient
    return result
cdef inline double_pair cgrad_hess_half_tweedie(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    cdef double_pair result
    cdef double e1, e2
    if power == 0.:
        e1 = exp(raw_prediction)  # y_pred
        result.val1 = e1 * (e1 - y_true)      # gradient
        result.val2 = e1 * (2 * e1 - y_true)  # hessian
        return result
    if power == 1.:
        return cgrad_hess_half_poisson(y_true, raw_prediction)
    if power == 2.:
        return cgrad_hess_half_gamma(y_true, raw_prediction)
    e1 = exp((1. - power) * raw_prediction)
    e2 = exp((2. - power) * raw_prediction)
    result.val1 = e2 - y_true * e1                                     # gradient
    result.val2 = (2. - power) * e2 - (1. - power) * y_true * e1       # hessian
    return result
# Half Tweedie Deviance with identity link, without dropping constant terms!
# Therefore, best loss value is zero.
# Here y_pred = raw_prediction itself. None of these functions validate the
# domain: divisions by raw_prediction, log() and pow() yield inf/nan outside it.
# NOTE(review): presumably callers guarantee raw_prediction > 0 unless power == 0.
# power 0, 1 and 2 are special-cased because the general formula has factors
# 1/(1-power) and 1/(2-power).
cdef inline double closs_half_tweedie_identity(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    cdef double tmp
    if power == 0.:
        return closs_half_squared_error(y_true, raw_prediction)
    elif power == 1.:
        # Half Poisson deviance; y_true * log(y_true / y_pred) -> 0 as y_true -> 0,
        # so the y_true == 0 case reduces to y_pred.
        if y_true == 0:
            return raw_prediction
        else:
            return y_true * log(y_true/raw_prediction) + raw_prediction - y_true
    elif power == 2.:
        # Half Gamma deviance.
        return log(raw_prediction/y_true) + y_true/raw_prediction - 1.
    else:
        # tmp = y_pred**(1-p)
        tmp = pow(raw_prediction, 1. - power)
        # y_pred**(2-p) / (2-p) - y_true * y_pred**(1-p) / (1-p)
        tmp = raw_prediction * tmp / (2. - power) - y_true * tmp / (1. - power)
        # Constant term max(y_true, 0)**(2-p) / ((1-p) * (2-p)); only added when
        # y_true > 0, matching the max(y_true, 0) in the class docstring.
        if y_true > 0:
            tmp += pow(y_true, 2. - power) / ((1. - power) * (2. - power))
        return tmp
cdef inline double cgradient_half_tweedie_identity(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # Gradient w.r.t. y_pred = raw_prediction:
    #   y_pred**(-p) * (y_pred - y_true)
    # with the cases p = 0, 1, 2 written out explicitly.
    if power == 0.:
        return raw_prediction - y_true
    elif power == 1.:
        return 1. - y_true / raw_prediction
    elif power == 2.:
        return (raw_prediction - y_true) / (raw_prediction * raw_prediction)
    else:
        return pow(raw_prediction, -power) * (raw_prediction - y_true)
cdef inline double_pair closs_grad_half_tweedie_identity(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # Loss (val1) and gradient (val2) in one pass; same formulas as
    # closs_half_tweedie_identity and cgradient_half_tweedie_identity, but the
    # general-power gradient is written as y_pred**(1-p) * (1 - y_true / y_pred)
    # so the pow() result already needed for the loss can be reused.
    cdef double_pair lg
    cdef double tmp
    if power == 0.:
        lg.val2 = raw_prediction - y_true  # gradient
        lg.val1 = 0.5 * lg.val2 * lg.val2  # loss
    elif power == 1.:
        if y_true == 0:
            lg.val1 = raw_prediction
        else:
            lg.val1 = (y_true * log(y_true/raw_prediction)  # loss
                       + raw_prediction - y_true)
        lg.val2 = 1. - y_true / raw_prediction  # gradient
    elif power == 2.:
        lg.val1 = log(raw_prediction/y_true) + y_true/raw_prediction - 1.  # loss
        tmp = raw_prediction * raw_prediction
        lg.val2 = (raw_prediction - y_true) / tmp  # gradient
    else:
        # tmp = y_pred**(1-p), shared by loss and gradient
        tmp = pow(raw_prediction, 1. - power)
        lg.val1 = (raw_prediction * tmp / (2. - power)  # loss
                   - y_true * tmp / (1. - power))
        if y_true > 0:
            lg.val1 += (pow(y_true, 2. - power)
                        / ((1. - power) * (2. - power)))
        lg.val2 = tmp * (1. - y_true / raw_prediction)  # gradient
    return lg
cdef inline double_pair cgrad_hess_half_tweedie_identity(
    double y_true,
    double raw_prediction,
    double power
) noexcept nogil:
    # Gradient (val1) and hessian (val2) w.r.t. y_pred = raw_prediction:
    #   gradient = y_pred**(-p) * (y_pred - y_true)
    #   hessian  = y_pred**(-p) * ((1-p) + p * y_true / y_pred)
    cdef double_pair gh
    cdef double tmp
    if power == 0.:
        gh.val1 = raw_prediction - y_true  # gradient
        gh.val2 = 1.  # hessian
    elif power == 1.:
        gh.val1 = 1. - y_true / raw_prediction  # gradient
        gh.val2 = y_true / (raw_prediction * raw_prediction)  # hessian
    elif power == 2.:
        tmp = raw_prediction * raw_prediction
        gh.val1 = (raw_prediction - y_true) / tmp  # gradient
        gh.val2 = (-1. + 2. * y_true / raw_prediction) / tmp  # hessian
    else:
        # tmp = y_pred**(-p)
        tmp = pow(raw_prediction, -power)
        gh.val1 = tmp * (raw_prediction - y_true)  # gradient
        gh.val2 = tmp * ((1. - power) + power * y_true / raw_prediction)  # hessian
    return gh
# Half Binomial deviance with logit-link, aka log-loss or binary cross entropy
# With y_pred = expit(raw_prediction) = 1 / (1 + exp(-raw_prediction)):
#   loss     = log(1 + exp(raw_prediction)) - y_true * raw_prediction
#   gradient = y_pred - y_true
#   hessian  = y_pred * (1 - y_pred)
cdef inline double closs_half_binomial(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # log1p(exp(raw_prediction)) - y_true * raw_prediction, with the stable
    # log1pexp helper in place of log1p(exp(.)).
    return log1pexp(raw_prediction) - y_true * raw_prediction
cdef inline double cgradient_half_binomial(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # gradient = y_pred - y_true = expit(raw_prediction) - y_true
    # Numerically more stable, see http://fa.bianp.net/blog/2019/evaluate_logistic/
    #     if raw_prediction < 0:
    #         exp_tmp = exp(raw_prediction)
    #         return ((1 - y_true) * exp_tmp - y_true) / (1 + exp_tmp)
    #     else:
    #         exp_tmp = exp(-raw_prediction)
    #         return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)
    # Note that optimal speed would be achieved, at the cost of precision, by
    #     return expit(raw_prediction) - y_true
    # i.e. no "if else" and an own inline implementation of expit instead of
    #     from scipy.special.cython_special cimport expit
    # The case distinction raw_prediction < 0 in the stable implementation does not
    # provide significantly better precision apart from protecting against overflow
    # of exp(..). The branch (if else), however, can incur runtime costs of up to 30%.
    # Instead, we help branch prediction by almost always ending in the first if clause
    # and making the second branch (else) a bit simpler. This has the exact same
    # precision but is faster than the stable implementation.
    # As branching criterion, we use the same cutoff as in log1pexp. Note that the
    # maximal value to get gradient = -1 with y_true = 1 is -37.439198610162731
    # (based on mpmath), and scipy.special.logit(np.finfo(float).eps) ~ -36.04365.
    cdef double exp_tmp
    if raw_prediction > -37:
        exp_tmp = exp(-raw_prediction)
        # expit(raw) - y_true = ((1 - y_true) - y_true * exp(-raw)) / (1 + exp(-raw))
        return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)
    else:
        # expit(raw_prediction) = exp(raw_prediction) for raw_prediction <= -37
        return exp(raw_prediction) - y_true
cdef inline double_pair closs_grad_half_binomial(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # Loss (val1) and gradient (val2) in one pass, sharing a single exp() call.
    cdef double_pair lg
    # Same if else conditions as in log1pexp.
    if raw_prediction <= -37:
        # log1p(exp(raw)) == exp(raw) and expit(raw) == exp(raw) to double precision.
        lg.val2 = exp(raw_prediction)  # used as temporary
        lg.val1 = lg.val2 - y_true * raw_prediction  # loss
        lg.val2 -= y_true  # gradient
    elif raw_prediction <= -2:
        lg.val2 = exp(raw_prediction)  # used as temporary
        lg.val1 = log1p(lg.val2) - y_true * raw_prediction  # loss
        lg.val2 = ((1 - y_true) * lg.val2 - y_true) / (1 + lg.val2)  # gradient
    elif raw_prediction <= 18:
        lg.val2 = exp(-raw_prediction)  # used as temporary
        # log1p(exp(x)) = log(1 + exp(x)) = x + log1p(exp(-x))
        lg.val1 = log1p(lg.val2) + (1 - y_true) * raw_prediction  # loss
        lg.val2 = ((1 - y_true) - y_true * lg.val2) / (1 + lg.val2)  # gradient
    else:
        # raw > 18: log1p(exp(-raw)) is replaced by its first-order value exp(-raw).
        lg.val2 = exp(-raw_prediction)  # used as temporary
        lg.val1 = lg.val2 + (1 - y_true) * raw_prediction  # loss
        lg.val2 = ((1 - y_true) - y_true * lg.val2) / (1 + lg.val2)  # gradient
    return lg
cdef inline double_pair cgrad_hess_half_binomial(
    double y_true,
    double raw_prediction
) noexcept nogil:
    # with y_pred = expit(raw)
    # hessian = y_pred * (1 - y_pred) = exp( raw) / (1 + exp( raw))**2
    #                                 = exp(-raw) / (1 + exp(-raw))**2
    cdef double_pair gh
    # See comment in cgradient_half_binomial.
    if raw_prediction > -37:
        gh.val2 = exp(-raw_prediction)  # used as temporary
        gh.val1 = ((1 - y_true) - y_true * gh.val2) / (1 + gh.val2)  # gradient
        gh.val2 = gh.val2 / (1 + gh.val2)**2  # hessian
    else:
        gh.val2 = exp(raw_prediction)  # = 1. order Taylor in exp(raw_prediction)
        gh.val1 = gh.val2 - y_true
    return gh
# Exponential loss with (half) logit-link, aka boosting loss
# With exp_raw = exp(raw_prediction):
#   loss     =  y_true / exp_raw + (1 - y_true) * exp_raw
#   gradient = -y_true / exp_raw + (1 - y_true) * exp_raw
#   hessian  = loss
cdef inline double closs_exponential(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double exp_raw = exp(raw_prediction)
    return y_true / exp_raw + (1 - y_true) * exp_raw
cdef inline double cgradient_exponential(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double exp_raw = exp(raw_prediction)
    return -y_true / exp_raw + (1 - y_true) * exp_raw
cdef inline double_pair closs_grad_exponential(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double exp_raw = exp(raw_prediction)
    cdef double_pair result
    result.val1 = y_true / exp_raw + (1 - y_true) * exp_raw   # loss
    result.val2 = -y_true / exp_raw + (1 - y_true) * exp_raw  # gradient
    return result
cdef inline double_pair cgrad_hess_exponential(
    double y_true,
    double raw_prediction
) noexcept nogil:
    cdef double exp_raw = exp(raw_prediction)
    cdef double_pair result
    result.val1 = -y_true / exp_raw + (1 - y_true) * exp_raw  # gradient
    result.val2 = y_true / exp_raw + (1 - y_true) * exp_raw   # hessian (equals the loss)
    return result
# ---------------------------------------------------
# Extension Types for Loss Functions of 1-dim targets
# ---------------------------------------------------
cdef class CyLossFunction:
"""Base class for convex loss functions."""
def __reduce__(self):
return (self.__class__, ())
    cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil:
        """Compute the loss for a single sample.
        Parameters
        ----------
        y_true : double
            Observed, true target value.
        raw_prediction : double
            Raw prediction value (in link space).
        Returns
        -------
        double
            The loss evaluated at `y_true` and `raw_prediction`.
        Notes
        -----
        In this base class the body is an empty placeholder (``pass``) and
        computes nothing. NOTE(review): presumably each concrete Cy* loss
        overrides it. The value this stub returns is whatever Cython emits for a
        C function that falls off its end; do not rely on it.
        """
        pass
    cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil:
        """Compute gradient of loss w.r.t. raw_prediction for a single sample.
        Parameters
        ----------
        y_true : double
            Observed, true target value.
        raw_prediction : double
            Raw prediction value (in link space).
        Returns
        -------
        double
            The derivative of the loss function w.r.t. `raw_prediction`.
        Notes
        -----
        In this base class the body is an empty placeholder (``pass``) and
        computes nothing. NOTE(review): presumably each concrete Cy* loss
        overrides it. The value this stub returns is whatever Cython emits for a
        C function that falls off its end; do not rely on it.
        """
        pass
    cdef double_pair cy_grad_hess(
        self, double y_true, double raw_prediction
    ) noexcept nogil:
        """Compute gradient and hessian.
        Gradient and hessian of loss w.r.t. raw_prediction for a single sample.
        This is usually diagonal in raw_prediction_i and raw_prediction_j.
        Therefore, we return the diagonal element i=j.
        For a loss with a non-canonical link, this might implement the diagonal
        of the Fisher matrix (=expected hessian) instead of the hessian.
        Parameters
        ----------
        y_true : double
            Observed, true target value.
        raw_prediction : double
            Raw prediction value (in link space).
        Returns
        -------
        double_pair
            Gradient (val1) and hessian (val2) of the loss function w.r.t.
            `raw_prediction`.
        Notes
        -----
        In this base class the body is an empty placeholder (``pass``) and
        computes nothing. NOTE(review): presumably each concrete Cy* loss
        overrides it. The struct this stub returns is whatever Cython emits for a
        C function that falls off its end; do not rely on it.
        """
        pass
def loss(
self,
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] loss_out, # OUT
int n_threads=1
):
"""Compute the point-wise loss value for each input.
The point-wise loss is written to `loss_out` and no array is returned.
Parameters
----------
y_true : array of shape (n_samples,)
Observed, true target values.
raw_prediction : array of shape (n_samples,)
Raw prediction values (in link space).
sample_weight : array of shape (n_samples,) or None
Sample weights.
loss_out : array of shape (n_samples,)
A location into which the result is stored.
n_threads : int
Number of threads used by OpenMP (if any).
"""
pass
def gradient(
    self,
    const floating_in[::1] y_true,          # IN
    const floating_in[::1] raw_prediction,  # IN
    const floating_in[::1] sample_weight,   # IN
    floating_out[::1] gradient_out,         # OUT
    int n_threads=1
):
    """Compute gradient of loss w.r.t. raw_prediction for each input.

    The gradient is written to `gradient_out` and no array is returned.

    Parameters
    ----------
    y_true : array of shape (n_samples,)
        Observed, true target values.
    raw_prediction : array of shape (n_samples,)
        Raw prediction values (in link space).
    sample_weight : array of shape (n_samples,) or None
        Sample weights.
    gradient_out : array of shape (n_samples,)
        A location into which the result is stored.
    n_threads : int
        Number of threads used by OpenMP (if any).

    Raises
    ------
    NotImplementedError
        Always. Concrete loss classes must override this method.
    """
    # Previously a silent `pass`, which left `gradient_out` unwritten.
    raise NotImplementedError("Concrete loss classes must implement `gradient`.")
def loss_gradient(
    self,
    const floating_in[::1] y_true,          # IN
    const floating_in[::1] raw_prediction,  # IN
    const floating_in[::1] sample_weight,   # IN
    floating_out[::1] loss_out,             # OUT
    floating_out[::1] gradient_out,         # OUT
    int n_threads=1
):
    """Compute loss and gradient of loss w.r.t. raw_prediction.

    The loss and gradient are written to `loss_out` and `gradient_out` and no
    arrays are returned.

    This generic implementation makes two passes over the data, calling
    `self.loss` and then `self.gradient`. Subclasses with a fused
    loss-and-gradient kernel may override it with a single pass.

    Parameters
    ----------
    y_true : array of shape (n_samples,)
        Observed, true target values.
    raw_prediction : array of shape (n_samples,)
        Raw prediction values (in link space).
    sample_weight : array of shape (n_samples,) or None
        Sample weights.
    loss_out : array of shape (n_samples,)
        A location into which the element-wise loss is stored. It must be
        provided: it is written to unconditionally.
    gradient_out : array of shape (n_samples,)
        A location into which the gradient is stored.
    n_threads : int
        Number of threads used by OpenMP (if any).
    """
    self.loss(y_true, raw_prediction, sample_weight, loss_out, n_threads)
    self.gradient(y_true, raw_prediction, sample_weight, gradient_out, n_threads)
def gradient_hessian(
    self,
    const floating_in[::1] y_true,          # IN
    const floating_in[::1] raw_prediction,  # IN
    const floating_in[::1] sample_weight,   # IN
    floating_out[::1] gradient_out,         # OUT
    floating_out[::1] hessian_out,          # OUT
    int n_threads=1
):
    """Compute gradient and hessian of loss w.r.t. raw_prediction.

    The gradient and hessian are written to `gradient_out` and `hessian_out`
    and no arrays are returned.

    Parameters
    ----------
    y_true : array of shape (n_samples,)
        Observed, true target values.
    raw_prediction : array of shape (n_samples,)
        Raw prediction values (in link space).
    sample_weight : array of shape (n_samples,) or None
        Sample weights.
    gradient_out : array of shape (n_samples,)
        A location into which the gradient is stored.
    hessian_out : array of shape (n_samples,)
        A location into which the hessian is stored.
    n_threads : int
        Number of threads used by OpenMP (if any).

    Raises
    ------
    NotImplementedError
        Always. Concrete loss classes must override this method.
    """
    # Previously a silent `pass`, which left both output arrays unwritten.
    raise NotImplementedError(
        "Concrete loss classes must implement `gradient_hessian`."
    )
{{for name, docstring, param, closs, closs_grad, cgrad, cgrad_hess in class_list}}
{{py:
with_param = "" if param is None else ", self." + param
}}

cdef class {{name}}(CyLossFunction):
    """{{docstring}}"""

    {{if param is not None}}
    # Parametrized loss: keep the parameter on the instance and make the
    # object picklable by handing it back to the constructor.
    def __init__(self, {{param}}):
        self.{{param}} = {{param}}

    def __reduce__(self):
        return (self.__class__, (self.{{param}},))
    {{endif}}

    # Single-sample kernels: thin wrappers around the per-sample functions
    # supplied through class_list (plus the loss parameter, if any).
    cdef inline double cy_loss(self, double y_true, double raw_prediction) noexcept nogil:
        return {{closs}}(y_true, raw_prediction{{with_param}})

    cdef inline double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil:
        return {{cgrad}}(y_true, raw_prediction{{with_param}})

    cdef inline double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil:
        return {{cgrad_hess}}(y_true, raw_prediction{{with_param}})

    def loss(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] loss_out,             # OUT
        int n_threads=1
    ):
        cdef:
            int idx
            int n_samples = y_true.shape[0]

        if sample_weight is not None:
            for idx in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_out[idx] = sample_weight[idx] * {{closs}}(y_true[idx], raw_prediction[idx]{{with_param}})
        else:
            for idx in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_out[idx] = {{closs}}(y_true[idx], raw_prediction[idx]{{with_param}})

    {{if closs_grad is not None}}
    # Single-pass override, only generated when a fused loss+gradient kernel
    # exists; otherwise the two-pass base implementation is inherited.
    def loss_gradient(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] loss_out,             # OUT
        floating_out[::1] gradient_out,         # OUT
        int n_threads=1
    ):
        cdef:
            int idx
            int n_samples = y_true.shape[0]
            double_pair loss_and_grad

        if sample_weight is not None:
            for idx in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_and_grad = {{closs_grad}}(y_true[idx], raw_prediction[idx]{{with_param}})
                loss_out[idx] = sample_weight[idx] * loss_and_grad.val1
                gradient_out[idx] = sample_weight[idx] * loss_and_grad.val2
        else:
            for idx in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_and_grad = {{closs_grad}}(y_true[idx], raw_prediction[idx]{{with_param}})
                loss_out[idx] = loss_and_grad.val1
                gradient_out[idx] = loss_and_grad.val2
    {{endif}}

    def gradient(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] gradient_out,         # OUT
        int n_threads=1
    ):
        cdef:
            int idx
            int n_samples = y_true.shape[0]

        if sample_weight is not None:
            for idx in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                gradient_out[idx] = sample_weight[idx] * {{cgrad}}(y_true[idx], raw_prediction[idx]{{with_param}})
        else:
            for idx in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                gradient_out[idx] = {{cgrad}}(y_true[idx], raw_prediction[idx]{{with_param}})

    def gradient_hessian(
        self,
        const floating_in[::1] y_true,          # IN
        const floating_in[::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,   # IN
        floating_out[::1] gradient_out,         # OUT
        floating_out[::1] hessian_out,          # OUT
        int n_threads=1
    ):
        cdef:
            int idx
            int n_samples = y_true.shape[0]
            double_pair grad_and_hess

        if sample_weight is not None:
            for idx in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                grad_and_hess = {{cgrad_hess}}(y_true[idx], raw_prediction[idx]{{with_param}})
                gradient_out[idx] = sample_weight[idx] * grad_and_hess.val1
                hessian_out[idx] = sample_weight[idx] * grad_and_hess.val2
        else:
            for idx in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                grad_and_hess = {{cgrad_hess}}(y_true[idx], raw_prediction[idx]{{with_param}})
                gradient_out[idx] = grad_and_hess.val1
                hessian_out[idx] = grad_and_hess.val2
{{endfor}}
# The multinomial deviance loss is also known as categorical cross-entropy or
# multinomial log-likelihood.
# Here, we do not inherit from CyLossFunction as its cy_gradient method deviates
# from the API.
cdef class CyHalfMultinomialLoss():
    """Half Multinomial deviance loss with multinomial logit link.

    Domain:
    y_true in {0, 1, 2, 3, .., n_classes - 1}
    y_pred in (0, 1)**n_classes, i.e. interval with boundaries excluded

    Link:
    y_pred = softmax(raw_prediction)

    Note: Label encoding is built-in, i.e. {0, 1, 2, 3, .., n_classes - 1} is
    mapped to (y_true == k) for k = 0 .. n_classes - 1 which is either 0 or 1.
    """

    # Here we deviate from the CyLossFunction API. SAG/SAGA needs direct access to
    # sample-wise gradients which we provide here.
    cdef inline void cy_gradient(
        self,
        const floating_in y_true,
        const floating_in[::1] raw_prediction,  # IN
        const floating_in sample_weight,
        floating_out[::1] gradient_out,  # OUT
    ) noexcept nogil:
        """Compute gradient of loss w.r.t. `raw_prediction` for a single sample.

        The gradient of the multinomial logistic loss with respect to a class k,
        and for one sample is:
            grad_k = sw * (p[k] - (y==k))

        where:
            p[k] = proba[k] = exp(raw_prediction[k] - logsumexp(raw_prediction))
            sw = sample_weight

        Parameters
        ----------
        y_true : double
            Observed, true target value, i.e. the label encoded class index.
        raw_prediction : array of shape (n_classes,)
            Raw prediction values (in link space).
        sample_weight : double
            Sample weight.
        gradient_out : array of shape (n_classes,)
            A location into which the gradient is stored. It is also used as
            the temporary buffer handed to `sum_exp_minus_max`.

        Returns
        -------
        None
            The result is written to `gradient_out`; nothing is returned.
        """
        cdef:
            int k
            int n_classes = raw_prediction.shape[0]
            double_pair max_value_and_sum_exps
            # sum_exp_minus_max works on rows of a 2d array: view this single
            # sample as row 0 of a (1, n_classes) array.
            const floating_in[:, :] raw = raw_prediction[None, :]

        # gradient_out doubles as the scratch buffer of sum_exp_minus_max;
        # val2 of the result is the sum used to normalize it to probabilities.
        max_value_and_sum_exps = sum_exp_minus_max(0, raw, &gradient_out[0])
        for k in range(n_classes):
            # gradient_out[k] = p_k = y_pred_k = prob of class k
            gradient_out[k] /= max_value_and_sum_exps.val2
            # gradient_k = (p_k - (y_true == k)) * sw
            gradient_out[k] = (gradient_out[k] - (y_true == k)) * sample_weight

    def _test_cy_gradient(
        self,
        const floating_in[::1] y_true,  # IN
        const floating_in[:, ::1] raw_prediction,  # IN
        const floating_in[::1] sample_weight,  # IN
    ):
        """For testing only.

        Call `cy_gradient` row by row and return the gradients as a new array
        of shape (n_samples, n_classes) with the same floating dtype as the
        inputs.
        """
        cdef:
            int i
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in [:, ::1] gradient_out

        # The buffer dtype must match the memoryview's element type: binding a
        # float64 array to a float32 memoryview raises a dtype mismatch error.
        # NOTE(review): assumes floating_in fuses exactly double and float.
        if floating_in is double:
            gradient = np.empty((n_samples, n_classes), dtype=np.float64)
        else:
            gradient = np.empty((n_samples, n_classes), dtype=np.float32)
        gradient_out = gradient

        for i in range(n_samples):
            self.cy_gradient(
                y_true=y_true[i],
                raw_prediction=raw_prediction[i, :],
                sample_weight=1.0 if sample_weight is None else sample_weight[i],
                gradient_out=gradient_out[i, :],
            )
        return gradient

    # Note that we do not assume memory alignment/contiguity of 2d arrays.
    # There seems to be little benefit in doing so. Benchmarks proving the
    # opposite are welcome.
    def loss(
        self,
        const floating_in[::1] y_true,  # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,  # IN
        floating_out[::1] loss_out,  # OUT
        int n_threads=1
    ):
        """Compute the point-wise half multinomial loss of each sample.

        For sample i with k = y_true[i]:
            loss_out[i] = sw_i * (log(sum_exps) + max_value - raw_prediction[i, k])
        with (max_value, sum_exps) as returned by `sum_exp_minus_max`, i.e.
        logsumexp of the row minus the raw prediction of the true class.
        sw_i is 1 when `sample_weight` is None. Nothing is returned.
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in max_value, sum_exps
            floating_in* p  # temporary buffer
            double_pair max_value_and_sum_exps

        # We assume n_samples > n_classes. In this case having the inner loop
        # over n_classes is a good default.
        # TODO: If every memoryview is contiguous and raw_prediction is
        #       f-contiguous, can we write a better algo (loops) to improve
        #       performance?
        # NOTE(review): the malloc results below (and in the other methods)
        # are not checked for NULL.
        if sample_weight is None:
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    max_value = max_value_and_sum_exps.val1
                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    # label encoded y_true
                    k = int(y_true[i])
                    loss_out[i] -= raw_prediction[i, k]

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    max_value = max_value_and_sum_exps.val1
                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    # label encoded y_true
                    k = int(y_true[i])
                    loss_out[i] -= raw_prediction[i, k]

                    loss_out[i] *= sample_weight[i]

                free(p)

    def loss_gradient(
        self,
        const floating_in[::1] y_true,  # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,  # IN
        floating_out[::1] loss_out,  # OUT
        floating_out[:, :] gradient_out,  # OUT
        int n_threads=1
    ):
        """Compute point-wise loss and its gradient in a single pass.

        `loss_out[i]` is computed as in `loss`. The gradient is
        `gradient_out[i, k] = sw_i * (p_ik - (y_true[i] == k))`, with `p_ik`
        the predicted probability of class k. Nothing is returned.
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in max_value, sum_exps
            floating_in* p  # temporary buffer
            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    max_value = max_value_and_sum_exps.val1
                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    for k in range(n_classes):
                        # label decode y_true
                        if y_true[i] == k:
                            loss_out[i] -= raw_prediction[i, k]
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = p_k - (y_true == k)
                        gradient_out[i, k] = p[k] - (y_true[i] == k)

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    max_value = max_value_and_sum_exps.val1
                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    for k in range(n_classes):
                        # label decode y_true
                        if y_true[i] == k:
                            loss_out[i] -= raw_prediction[i, k]
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = (p_k - (y_true == k)) * sw
                        gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]

                    loss_out[i] *= sample_weight[i]

                free(p)

    def gradient(
        self,
        const floating_in[::1] y_true,  # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,  # IN
        floating_out[:, :] gradient_out,  # OUT
        int n_threads=1
    ):
        """Compute `gradient_out[i, k] = sw_i * (p_ik - (y_true[i] == k))`.

        `p_ik` is the predicted probability of class k for sample i, and sw_i
        is 1 when `sample_weight` is None. Nothing is returned.
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = y_pred_k - (y_true == k)
                        gradient_out[i, k] = p[k] - (y_true[i] == k)

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = (p_k - (y_true == k)) * sw
                        gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]

                free(p)

    def gradient_hessian(
        self,
        const floating_in[::1] y_true,  # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,  # IN
        floating_out[:, :] gradient_out,  # OUT
        floating_out[:, :] hessian_out,  # OUT
        int n_threads=1
    ):
        """Compute the gradient and the diagonal (in classes) of the hessian.

        gradient_out[i, k] = sw_i * (p_ik - (y_true[i] == k))
        hessian_out[i, k] = sw_i * p_ik * (1 - p_ik)
        sw_i is 1 when `sample_weight` is None. Nothing is returned.
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # hessian_k = p_k * (1 - p_k)
                        # gradient_k = p_k - (y_true == k)
                        gradient_out[i, k] = p[k] - (y_true[i] == k)
                        hessian_out[i, k] = p[k] * (1. - p[k])

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
                        # gradient_k = (p_k - (y_true == k)) * sw
                        # hessian_k = p_k * (1 - p_k) * sw
                        gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]
                        hessian_out[i, k] = (p[k] * (1. - p[k])) * sample_weight[i]

                free(p)

    # This method simplifies the implementation of hessp in linear models,
    # i.e. the matrix-vector product of the full hessian, not only of the
    # diagonal (in the classes) approximation as implemented above.
    def gradient_proba(
        self,
        const floating_in[::1] y_true,  # IN
        const floating_in[:, :] raw_prediction,  # IN
        const floating_in[::1] sample_weight,  # IN
        floating_out[:, :] gradient_out,  # OUT
        floating_out[:, :] proba_out,  # OUT
        int n_threads=1
    ):
        """Compute the gradient and the predicted class probabilities.

        proba_out[i, k] = p_ik (predicted probability of class k)
        gradient_out[i, k] = sw_i * (p_ik - (y_true[i] == k))
        The probabilities are not weighted. Nothing is returned.
        """
        cdef:
            int i, k
            int n_samples = y_true.shape[0]
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        proba_out[i, k] = p[k] / sum_exps  # y_pred_k = prob of class k
                        # gradient_k = y_pred_k - (y_true == k)
                        gradient_out[i, k] = proba_out[i, k] - (y_true[i] == k)

                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        proba_out[i, k] = p[k] / sum_exps  # y_pred_k = prob of class k
                        # gradient_k = (p_k - (y_true == k)) * sw
                        gradient_out[i, k] = (proba_out[i, k] - (y_true[i] == k)) * sample_weight[i]

                free(p)
|