File size: 78,236 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 |
"""Metrics to assess performance on classification task given scores.
Functions named as ``*_score`` return a scalar value to maximize: the higher
the better.
Function named as ``*_error`` or ``*_loss`` return a scalar value to minimize:
the lower the better.
"""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from functools import partial
from numbers import Integral, Real
import numpy as np
from scipy.integrate import trapezoid
from scipy.sparse import csr_matrix, issparse
from scipy.stats import rankdata
from ..exceptions import UndefinedMetricWarning
from ..preprocessing import label_binarize
from ..utils import (
assert_all_finite,
check_array,
check_consistent_length,
column_or_1d,
)
from ..utils._encode import _encode, _unique
from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
from ..utils.extmath import stable_cumsum
from ..utils.multiclass import type_of_target
from ..utils.sparsefuncs import count_nonzero
from ..utils.validation import _check_pos_label_consistency, _check_sample_weight
from ._base import _average_binary_score, _average_multiclass_ovo_score
@validate_params(
{"x": ["array-like"], "y": ["array-like"]},
prefer_skip_nested_validation=True,
)
def auc(x, y):
"""Compute Area Under the Curve (AUC) using the trapezoidal rule.
This is a general function, given points on a curve. For computing the
area under the ROC-curve, see :func:`roc_auc_score`. For an alternative
way to summarize a precision-recall curve, see
:func:`average_precision_score`.
Parameters
----------
x : array-like of shape (n,)
X coordinates. These must be either monotonic increasing or monotonic
decreasing.
y : array-like of shape (n,)
Y coordinates.
Returns
-------
auc : float
Area Under the Curve.
See Also
--------
roc_auc_score : Compute the area under the ROC curve.
average_precision_score : Compute average precision from prediction scores.
precision_recall_curve : Compute precision-recall pairs for different
probability thresholds.
Examples
--------
>>> import numpy as np
>>> from sklearn import metrics
>>> y = np.array([1, 1, 2, 2])
>>> pred = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=2)
>>> metrics.auc(fpr, tpr)
np.float64(0.75)
"""
check_consistent_length(x, y)
x = column_or_1d(x)
y = column_or_1d(y)
if x.shape[0] < 2:
raise ValueError(
"At least 2 points are needed to compute area under curve, but x.shape = %s"
% x.shape
)
direction = 1
dx = np.diff(x)
if np.any(dx < 0):
if np.all(dx <= 0):
direction = -1
else:
raise ValueError("x is neither increasing nor decreasing : {}.".format(x))
area = direction * trapezoid(y, x)
if isinstance(area, np.memmap):
# Reductions such as .sum used internally in trapezoid do not return a
# scalar by default for numpy.memmap instances contrary to
# regular numpy.ndarray instances.
area = area.dtype.type(area)
return area
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"average": [StrOptions({"micro", "samples", "weighted", "macro"}), None],
"pos_label": [Real, str, "boolean"],
"sample_weight": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def average_precision_score(
y_true, y_score, *, average="macro", pos_label=1, sample_weight=None
):
"""Compute average precision (AP) from prediction scores.
AP summarizes a precision-recall curve as the weighted mean of precisions
achieved at each threshold, with the increase in recall from the previous
threshold used as the weight:
.. math::
\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
where :math:`P_n` and :math:`R_n` are the precision and recall at the nth
threshold [1]_. This implementation is not interpolated and is different
from computing the area under the precision-recall curve with the
trapezoidal rule, which uses linear interpolation and can be too
optimistic.
Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.
Parameters
----------
y_true : array-like of shape (n_samples,) or (n_samples, n_classes)
True binary labels or binary label indicators.
y_score : array-like of shape (n_samples,) or (n_samples, n_classes)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by :term:`decision_function` on some classifiers).
For :term:`decision_function` scores, values greater than or equal to
zero should indicate the positive class.
average : {'micro', 'samples', 'weighted', 'macro'} or None, \
default='macro'
If ``None``, the scores for each class are returned. Otherwise,
this determines the type of averaging performed on the data:
``'micro'``:
Calculate metrics globally by considering each element of the label
indicator matrix as a label.
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account.
``'weighted'``:
Calculate metrics for each label, and find their average, weighted
by support (the number of true instances for each label).
``'samples'``:
Calculate metrics for each instance, and find their average.
Will be ignored when ``y_true`` is binary.
pos_label : int, float, bool or str, default=1
The label of the positive class. Only applied to binary ``y_true``.
For multilabel-indicator ``y_true``, ``pos_label`` is fixed to 1.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
-------
average_precision : float
Average precision score.
See Also
--------
roc_auc_score : Compute the area under the ROC curve.
precision_recall_curve : Compute precision-recall pairs for different
probability thresholds.
Notes
-----
.. versionchanged:: 0.19
Instead of linearly interpolating between operating points, precisions
are weighted by the change in recall since the last operating point.
References
----------
.. [1] `Wikipedia entry for the Average precision
<https://en.wikipedia.org/w/index.php?title=Information_retrieval&
oldid=793358396#Average_precision>`_
Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import average_precision_score
>>> y_true = np.array([0, 0, 1, 1])
>>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> average_precision_score(y_true, y_scores)
np.float64(0.83...)
>>> y_true = np.array([0, 0, 1, 1, 2, 2])
>>> y_scores = np.array([
... [0.7, 0.2, 0.1],
... [0.4, 0.3, 0.3],
... [0.1, 0.8, 0.1],
... [0.2, 0.3, 0.5],
... [0.4, 0.4, 0.2],
... [0.1, 0.2, 0.7],
... ])
>>> average_precision_score(y_true, y_scores)
np.float64(0.77...)
"""
def _binary_uninterpolated_average_precision(
y_true, y_score, pos_label=1, sample_weight=None
):
precision, recall, _ = precision_recall_curve(
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)
# Return the step function integral
# The following works because the last entry of precision is
# guaranteed to be 1, as returned by precision_recall_curve.
# Due to numerical error, we can get `-0.0` and we therefore clip it.
return max(0.0, -np.sum(np.diff(recall) * np.array(precision)[:-1]))
y_type = type_of_target(y_true, input_name="y_true")
# Convert to Python primitive type to avoid NumPy type / Python str
# comparison. See https://github.com/numpy/numpy/issues/6784
present_labels = np.unique(y_true).tolist()
if y_type == "binary":
if len(present_labels) == 2 and pos_label not in present_labels:
raise ValueError(
f"pos_label={pos_label} is not a valid label. It should be "
f"one of {present_labels}"
)
elif y_type == "multilabel-indicator" and pos_label != 1:
raise ValueError(
"Parameter pos_label is fixed to 1 for multilabel-indicator y_true. "
"Do not set pos_label or set pos_label to 1."
)
elif y_type == "multiclass":
if pos_label != 1:
raise ValueError(
"Parameter pos_label is fixed to 1 for multiclass y_true. "
"Do not set pos_label or set pos_label to 1."
)
y_true = label_binarize(y_true, classes=present_labels)
average_precision = partial(
_binary_uninterpolated_average_precision, pos_label=pos_label
)
return _average_binary_score(
average_precision, y_true, y_score, average, sample_weight=sample_weight
)
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"pos_label": [Real, str, "boolean", None],
"sample_weight": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
"""Compute error rates for different probability thresholds.
.. note::
This metric is used for evaluation of ranking and error tradeoffs of
a binary classification task.
Read more in the :ref:`User Guide <det_curve>`.
.. versionadded:: 0.24
Parameters
----------
y_true : ndarray of shape (n_samples,)
True binary labels. If labels are not either {-1, 1} or {0, 1}, then
pos_label should be explicitly given.
y_score : ndarray of shape of (n_samples,)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by "decision_function" on some classifiers).
For :term:`decision_function` scores, values greater than or equal to
zero should indicate the positive class.
pos_label : int, float, bool or str, default=None
The label of the positive class.
When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1},
``pos_label`` is set to 1, otherwise an error will be raised.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
-------
fpr : ndarray of shape (n_thresholds,)
False positive rate (FPR) such that element i is the false positive
rate of predictions with score >= thresholds[i]. This is occasionally
referred to as false acceptance probability or fall-out.
fnr : ndarray of shape (n_thresholds,)
False negative rate (FNR) such that element i is the false negative
rate of predictions with score >= thresholds[i]. This is occasionally
referred to as false rejection or miss rate.
thresholds : ndarray of shape (n_thresholds,)
Decreasing score values.
See Also
--------
DetCurveDisplay.from_estimator : Plot DET curve given an estimator and
some data.
DetCurveDisplay.from_predictions : Plot DET curve given the true and
predicted labels.
DetCurveDisplay : DET curve visualization.
roc_curve : Compute Receiver operating characteristic (ROC) curve.
precision_recall_curve : Compute precision-recall curve.
Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import det_curve
>>> y_true = np.array([0, 0, 1, 1])
>>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, fnr, thresholds = det_curve(y_true, y_scores)
>>> fpr
array([0.5, 0.5, 0. ])
>>> fnr
array([0. , 0.5, 0.5])
>>> thresholds
array([0.35, 0.4 , 0.8 ])
"""
fps, tps, thresholds = _binary_clf_curve(
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)
if len(np.unique(y_true)) != 2:
raise ValueError(
"Only one class is present in y_true. Detection error "
"tradeoff curve is not defined in that case."
)
fns = tps[-1] - tps
p_count = tps[-1]
n_count = fps[-1]
# start with false positives zero
first_ind = (
fps.searchsorted(fps[0], side="right") - 1
if fps.searchsorted(fps[0], side="right") > 0
else None
)
# stop with false negatives zero
last_ind = tps.searchsorted(tps[-1]) + 1
sl = slice(first_ind, last_ind)
# reverse the output such that list of false positives is decreasing
return (fps[sl][::-1] / n_count, fns[sl][::-1] / p_count, thresholds[sl][::-1])
def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):
"""Binary roc auc score."""
if len(np.unique(y_true)) != 2:
warnings.warn(
(
"Only one class is present in y_true. ROC AUC score "
"is not defined in that case."
),
UndefinedMetricWarning,
)
return np.nan
fpr, tpr, _ = roc_curve(y_true, y_score, sample_weight=sample_weight)
if max_fpr is None or max_fpr == 1:
return auc(fpr, tpr)
if max_fpr <= 0 or max_fpr > 1:
raise ValueError("Expected max_fpr in range (0, 1], got: %r" % max_fpr)
# Add a single point at max_fpr by linear interpolation
stop = np.searchsorted(fpr, max_fpr, "right")
x_interp = [fpr[stop - 1], fpr[stop]]
y_interp = [tpr[stop - 1], tpr[stop]]
tpr = np.append(tpr[:stop], np.interp(max_fpr, x_interp, y_interp))
fpr = np.append(fpr[:stop], max_fpr)
partial_auc = auc(fpr, tpr)
# McClish correction: standardize result to be 0.5 if non-discriminant
# and 1 if maximal
min_area = 0.5 * max_fpr**2
max_area = max_fpr
return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"average": [StrOptions({"micro", "macro", "samples", "weighted"}), None],
"sample_weight": ["array-like", None],
"max_fpr": [Interval(Real, 0.0, 1, closed="right"), None],
"multi_class": [StrOptions({"raise", "ovr", "ovo"})],
"labels": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def roc_auc_score(
y_true,
y_score,
*,
average="macro",
sample_weight=None,
max_fpr=None,
multi_class="raise",
labels=None,
):
"""Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) \
from prediction scores.
Note: this implementation can be used with binary, multiclass and
multilabel classification, but some restrictions apply (see Parameters).
Read more in the :ref:`User Guide <roc_metrics>`.
Parameters
----------
y_true : array-like of shape (n_samples,) or (n_samples, n_classes)
True labels or binary label indicators. The binary and multiclass cases
expect labels with shape (n_samples,) while the multilabel case expects
binary label indicators with shape (n_samples, n_classes).
y_score : array-like of shape (n_samples,) or (n_samples, n_classes)
Target scores.
* In the binary case, it corresponds to an array of shape
`(n_samples,)`. Both probability estimates and non-thresholded
decision values can be provided. The probability estimates correspond
to the **probability of the class with the greater label**,
i.e. `estimator.classes_[1]` and thus
`estimator.predict_proba(X, y)[:, 1]`. The decision values
corresponds to the output of `estimator.decision_function(X, y)`.
See more information in the :ref:`User guide <roc_auc_binary>`;
* In the multiclass case, it corresponds to an array of shape
`(n_samples, n_classes)` of probability estimates provided by the
`predict_proba` method. The probability estimates **must**
sum to 1 across the possible classes. In addition, the order of the
class scores must correspond to the order of ``labels``,
if provided, or else to the numerical or lexicographical order of
the labels in ``y_true``. See more information in the
:ref:`User guide <roc_auc_multiclass>`;
* In the multilabel case, it corresponds to an array of shape
`(n_samples, n_classes)`. Probability estimates are provided by the
`predict_proba` method and the non-thresholded decision values by
the `decision_function` method. The probability estimates correspond
to the **probability of the class with the greater label for each
output** of the classifier. See more information in the
:ref:`User guide <roc_auc_multilabel>`.
average : {'micro', 'macro', 'samples', 'weighted'} or None, \
default='macro'
If ``None``, the scores for each class are returned.
Otherwise, this determines the type of averaging performed on the data.
Note: multiclass ROC AUC currently only handles the 'macro' and
'weighted' averages. For multiclass targets, `average=None` is only
implemented for `multi_class='ovr'` and `average='micro'` is only
implemented for `multi_class='ovr'`.
``'micro'``:
Calculate metrics globally by considering each element of the label
indicator matrix as a label.
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account.
``'weighted'``:
Calculate metrics for each label, and find their average, weighted
by support (the number of true instances for each label).
``'samples'``:
Calculate metrics for each instance, and find their average.
Will be ignored when ``y_true`` is binary.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
max_fpr : float > 0 and <= 1, default=None
If not ``None``, the standardized partial AUC [2]_ over the range
[0, max_fpr] is returned. For the multiclass case, ``max_fpr``,
should be either equal to ``None`` or ``1.0`` as AUC ROC partial
computation currently is not supported for multiclass.
multi_class : {'raise', 'ovr', 'ovo'}, default='raise'
Only used for multiclass targets. Determines the type of configuration
to use. The default value raises an error, so either
``'ovr'`` or ``'ovo'`` must be passed explicitly.
``'ovr'``:
Stands for One-vs-rest. Computes the AUC of each class
against the rest [3]_ [4]_. This
treats the multiclass case in the same way as the multilabel case.
Sensitive to class imbalance even when ``average == 'macro'``,
because class imbalance affects the composition of each of the
'rest' groupings.
``'ovo'``:
Stands for One-vs-one. Computes the average AUC of all
possible pairwise combinations of classes [5]_.
Insensitive to class imbalance when
``average == 'macro'``.
labels : array-like of shape (n_classes,), default=None
Only used for multiclass targets. List of labels that index the
classes in ``y_score``. If ``None``, the numerical or lexicographical
order of the labels in ``y_true`` is used.
Returns
-------
auc : float
Area Under the Curve score.
See Also
--------
average_precision_score : Area under the precision-recall curve.
roc_curve : Compute Receiver operating characteristic (ROC) curve.
RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
(ROC) curve given an estimator and some data.
RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
(ROC) curve given the true and predicted values.
Notes
-----
The Gini Coefficient is a summary measure of the ranking ability of binary
classifiers. It is expressed using the area under of the ROC as follows:
G = 2 * AUC - 1
Where G is the Gini coefficient and AUC is the ROC-AUC score. This normalisation
will ensure that random guessing will yield a score of 0 in expectation, and it is
upper bounded by 1.
References
----------
.. [1] `Wikipedia entry for the Receiver operating characteristic
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
.. [2] `Analyzing a portion of the ROC curve. McClish, 1989
<https://www.ncbi.nlm.nih.gov/pubmed/2668680>`_
.. [3] Provost, F., Domingos, P. (2000). Well-trained PETs: Improving
probability estimation trees (Section 6.2), CeDER Working Paper
#IS-00-04, Stern School of Business, New York University.
.. [4] `Fawcett, T. (2006). An introduction to ROC analysis. Pattern
Recognition Letters, 27(8), 861-874.
<https://www.sciencedirect.com/science/article/pii/S016786550500303X>`_
.. [5] `Hand, D.J., Till, R.J. (2001). A Simple Generalisation of the Area
Under the ROC Curve for Multiple Class Classification Problems.
Machine Learning, 45(2), 171-186.
<http://link.springer.com/article/10.1023/A:1010920819831>`_
.. [6] `Wikipedia entry for the Gini coefficient
<https://en.wikipedia.org/wiki/Gini_coefficient>`_
Examples
--------
Binary case:
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.metrics import roc_auc_score
>>> X, y = load_breast_cancer(return_X_y=True)
>>> clf = LogisticRegression(solver="liblinear", random_state=0).fit(X, y)
>>> roc_auc_score(y, clf.predict_proba(X)[:, 1])
np.float64(0.99...)
>>> roc_auc_score(y, clf.decision_function(X))
np.float64(0.99...)
Multiclass case:
>>> from sklearn.datasets import load_iris
>>> X, y = load_iris(return_X_y=True)
>>> clf = LogisticRegression(solver="liblinear").fit(X, y)
>>> roc_auc_score(y, clf.predict_proba(X), multi_class='ovr')
np.float64(0.99...)
Multilabel case:
>>> import numpy as np
>>> from sklearn.datasets import make_multilabel_classification
>>> from sklearn.multioutput import MultiOutputClassifier
>>> X, y = make_multilabel_classification(random_state=0)
>>> clf = MultiOutputClassifier(clf).fit(X, y)
>>> # get a list of n_output containing probability arrays of shape
>>> # (n_samples, n_classes)
>>> y_pred = clf.predict_proba(X)
>>> # extract the positive columns for each output
>>> y_pred = np.transpose([pred[:, 1] for pred in y_pred])
>>> roc_auc_score(y, y_pred, average=None)
array([0.82..., 0.86..., 0.94..., 0.85... , 0.94...])
>>> from sklearn.linear_model import RidgeClassifierCV
>>> clf = RidgeClassifierCV().fit(X, y)
>>> roc_auc_score(y, clf.decision_function(X), average=None)
array([0.81..., 0.84... , 0.93..., 0.87..., 0.94...])
"""
y_type = type_of_target(y_true, input_name="y_true")
y_true = check_array(y_true, ensure_2d=False, dtype=None)
y_score = check_array(y_score, ensure_2d=False)
if y_type == "multiclass" or (
y_type == "binary" and y_score.ndim == 2 and y_score.shape[1] > 2
):
# do not support partial ROC computation for multiclass
if max_fpr is not None and max_fpr != 1.0:
raise ValueError(
"Partial AUC computation not available in "
"multiclass setting, 'max_fpr' must be"
" set to `None`, received `max_fpr={0}` "
"instead".format(max_fpr)
)
if multi_class == "raise":
raise ValueError("multi_class must be in ('ovo', 'ovr')")
return _multiclass_roc_auc_score(
y_true, y_score, labels, multi_class, average, sample_weight
)
elif y_type == "binary":
labels = np.unique(y_true)
y_true = label_binarize(y_true, classes=labels)[:, 0]
return _average_binary_score(
partial(_binary_roc_auc_score, max_fpr=max_fpr),
y_true,
y_score,
average,
sample_weight=sample_weight,
)
else: # multilabel-indicator
return _average_binary_score(
partial(_binary_roc_auc_score, max_fpr=max_fpr),
y_true,
y_score,
average,
sample_weight=sample_weight,
)
def _multiclass_roc_auc_score(
y_true, y_score, labels, multi_class, average, sample_weight
):
"""Multiclass roc auc score.
Parameters
----------
y_true : array-like of shape (n_samples,)
True multiclass labels.
y_score : array-like of shape (n_samples, n_classes)
Target scores corresponding to probability estimates of a sample
belonging to a particular class
labels : array-like of shape (n_classes,) or None
List of labels to index ``y_score`` used for multiclass. If ``None``,
the lexical order of ``y_true`` is used to index ``y_score``.
multi_class : {'ovr', 'ovo'}
Determines the type of multiclass configuration to use.
``'ovr'``:
Calculate metrics for the multiclass case using the one-vs-rest
approach.
``'ovo'``:
Calculate metrics for the multiclass case using the one-vs-one
approach.
average : {'micro', 'macro', 'weighted'}
Determines the type of averaging performed on the pairwise binary
metric scores
``'micro'``:
Calculate metrics for the binarized-raveled classes. Only supported
for `multi_class='ovr'`.
.. versionadded:: 1.2
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account. Classes
are assumed to be uniformly distributed.
``'weighted'``:
Calculate metrics for each label, taking into account the
prevalence of the classes.
sample_weight : array-like of shape (n_samples,) or None
Sample weights.
"""
# validation of the input y_score
if not np.allclose(1, y_score.sum(axis=1)):
raise ValueError(
"Target scores need to be probabilities for multiclass "
"roc_auc, i.e. they should sum up to 1.0 over classes"
)
# validation for multiclass parameter specifications
average_options = ("macro", "weighted", None)
if multi_class == "ovr":
average_options = ("micro",) + average_options
if average not in average_options:
raise ValueError(
"average must be one of {0} for multiclass problems".format(average_options)
)
multiclass_options = ("ovo", "ovr")
if multi_class not in multiclass_options:
raise ValueError(
"multi_class='{0}' is not supported "
"for multiclass ROC AUC, multi_class must be "
"in {1}".format(multi_class, multiclass_options)
)
if average is None and multi_class == "ovo":
raise NotImplementedError(
"average=None is not implemented for multi_class='ovo'."
)
if labels is not None:
labels = column_or_1d(labels)
classes = _unique(labels)
if len(classes) != len(labels):
raise ValueError("Parameter 'labels' must be unique")
if not np.array_equal(classes, labels):
raise ValueError("Parameter 'labels' must be ordered")
if len(classes) != y_score.shape[1]:
raise ValueError(
"Number of given labels, {0}, not equal to the number "
"of columns in 'y_score', {1}".format(len(classes), y_score.shape[1])
)
if len(np.setdiff1d(y_true, classes)):
raise ValueError("'y_true' contains labels not in parameter 'labels'")
else:
classes = _unique(y_true)
if len(classes) != y_score.shape[1]:
raise ValueError(
"Number of classes in y_true not equal to the number of "
"columns in 'y_score'"
)
if multi_class == "ovo":
if sample_weight is not None:
raise ValueError(
"sample_weight is not supported "
"for multiclass one-vs-one ROC AUC, "
"'sample_weight' must be None in this case."
)
y_true_encoded = _encode(y_true, uniques=classes)
# Hand & Till (2001) implementation (ovo)
return _average_multiclass_ovo_score(
_binary_roc_auc_score, y_true_encoded, y_score, average=average
)
else:
# ovr is same as multi-label
y_true_multilabel = label_binarize(y_true, classes=classes)
return _average_binary_score(
_binary_roc_auc_score,
y_true_multilabel,
y_score,
average,
sample_weight=sample_weight,
)
def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
"""Calculate true and false positives per binary classification threshold.
Parameters
----------
y_true : ndarray of shape (n_samples,)
True targets of binary classification.
y_score : ndarray of shape (n_samples,)
Estimated probabilities or output of a decision function.
pos_label : int, float, bool or str, default=None
The label of the positive class.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
-------
fps : ndarray of shape (n_thresholds,)
A count of false positives, at index i being the number of negative
samples assigned a score >= thresholds[i]. The total number of
negative samples is equal to fps[-1] (thus true negatives are given by
fps[-1] - fps).
tps : ndarray of shape (n_thresholds,)
An increasing count of true positives, at index i being the number
of positive samples assigned a score >= thresholds[i]. The total
number of positive samples is equal to tps[-1] (thus false negatives
are given by tps[-1] - tps).
thresholds : ndarray of shape (n_thresholds,)
Decreasing score values.
"""
# Check to make sure y_true is valid
y_type = type_of_target(y_true, input_name="y_true")
if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)):
raise ValueError("{0} format is not supported".format(y_type))
check_consistent_length(y_true, y_score, sample_weight)
y_true = column_or_1d(y_true)
y_score = column_or_1d(y_score)
assert_all_finite(y_true)
assert_all_finite(y_score)
# Filter out zero-weighted samples, as they should not impact the result
if sample_weight is not None:
sample_weight = column_or_1d(sample_weight)
sample_weight = _check_sample_weight(sample_weight, y_true)
nonzero_weight_mask = sample_weight != 0
y_true = y_true[nonzero_weight_mask]
y_score = y_score[nonzero_weight_mask]
sample_weight = sample_weight[nonzero_weight_mask]
pos_label = _check_pos_label_consistency(pos_label, y_true)
# make y_true a boolean vector
y_true = y_true == pos_label
# sort scores and corresponding truth values
desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
y_score = y_score[desc_score_indices]
y_true = y_true[desc_score_indices]
if sample_weight is not None:
weight = sample_weight[desc_score_indices]
else:
weight = 1.0
# y_score typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = np.where(np.diff(y_score))[0]
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
# accumulate the true positives with decreasing threshold
tps = stable_cumsum(y_true * weight)[threshold_idxs]
if sample_weight is not None:
# express fps as a cumsum to ensure fps is increasing even in
# the presence of floating point errors
fps = stable_cumsum((1 - y_true) * weight)[threshold_idxs]
else:
fps = 1 + threshold_idxs - tps
return fps, tps, y_score[threshold_idxs]
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like", Hidden(None)],
"pos_label": [Real, str, "boolean", None],
"sample_weight": ["array-like", None],
"drop_intermediate": ["boolean"],
"probas_pred": [
"array-like",
Hidden(StrOptions({"deprecated"})),
],
},
prefer_skip_nested_validation=True,
)
def precision_recall_curve(
y_true,
y_score=None,
*,
pos_label=None,
sample_weight=None,
drop_intermediate=False,
probas_pred="deprecated",
):
"""Compute precision-recall pairs for different probability thresholds.
Note: this implementation is restricted to the binary classification task.
The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
true positives and ``fp`` the number of false positives. The precision is
intuitively the ability of the classifier not to label as positive a sample
that is negative.
The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
true positives and ``fn`` the number of false negatives. The recall is
intuitively the ability of the classifier to find all the positive samples.
The last precision and recall values are 1. and 0. respectively and do not
have a corresponding threshold. This ensures that the graph starts on the
y axis.
The first precision and recall values are precision=class balance and recall=1.0
which corresponds to a classifier that always predicts the positive class.
Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.
Parameters
----------
y_true : array-like of shape (n_samples,)
True binary labels. If labels are not either {-1, 1} or {0, 1}, then
pos_label should be explicitly given.
y_score : array-like of shape (n_samples,)
Target scores, can either be probability estimates of the positive
class, or non-thresholded measure of decisions (as returned by
`decision_function` on some classifiers).
For :term:`decision_function` scores, values greater than or equal to
zero should indicate the positive class.
pos_label : int, float, bool or str, default=None
The label of the positive class.
When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1},
``pos_label`` is set to 1, otherwise an error will be raised.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=False
Whether to drop some suboptimal thresholds which would not appear
on a plotted precision-recall curve. This is useful in order to create
lighter precision-recall curves.
.. versionadded:: 1.3
probas_pred : array-like of shape (n_samples,)
Target scores, can either be probability estimates of the positive
class, or non-thresholded measure of decisions (as returned by
`decision_function` on some classifiers).
.. deprecated:: 1.5
`probas_pred` is deprecated and will be removed in 1.7. Use
`y_score` instead.
Returns
-------
precision : ndarray of shape (n_thresholds + 1,)
Precision values such that element i is the precision of
predictions with score >= thresholds[i] and the last element is 1.
recall : ndarray of shape (n_thresholds + 1,)
Decreasing recall values such that element i is the recall of
predictions with score >= thresholds[i] and the last element is 0.
thresholds : ndarray of shape (n_thresholds,)
Increasing thresholds on the decision function used to compute
precision and recall where `n_thresholds = len(np.unique(probas_pred))`.
See Also
--------
PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given
a binary classifier.
PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve
using predictions from a binary classifier.
average_precision_score : Compute average precision from prediction scores.
det_curve: Compute error rates for different probability thresholds.
roc_curve : Compute Receiver operating characteristic (ROC) curve.
Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import precision_recall_curve
>>> y_true = np.array([0, 0, 1, 1])
>>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> precision, recall, thresholds = precision_recall_curve(
... y_true, y_scores)
>>> precision
array([0.5 , 0.66666667, 0.5 , 1. , 1. ])
>>> recall
array([1. , 1. , 0.5, 0.5, 0. ])
>>> thresholds
array([0.1 , 0.35, 0.4 , 0.8 ])
"""
# TODO(1.7): remove in 1.7 and reset y_score to be required
# Note: validate params will raise an error if probas_pred is not array-like,
# or "deprecated"
if y_score is not None and not isinstance(probas_pred, str):
raise ValueError(
"`probas_pred` and `y_score` cannot be both specified. Please use `y_score`"
" only as `probas_pred` is deprecated in v1.5 and will be removed in v1.7."
)
if y_score is None:
warnings.warn(
(
"probas_pred was deprecated in version 1.5 and will be removed in 1.7."
"Please use ``y_score`` instead."
),
FutureWarning,
)
y_score = probas_pred
fps, tps, thresholds = _binary_clf_curve(
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)
if drop_intermediate and len(fps) > 2:
# Drop thresholds corresponding to points where true positives (tps)
# do not change from the previous or subsequent point. This will keep
# only the first and last point for each tps value. All points
# with the same tps value have the same recall and thus x coordinate.
# They appear as a vertical line on the plot.
optimal_idxs = np.where(
np.concatenate(
[[True], np.logical_or(np.diff(tps[:-1]), np.diff(tps[1:])), [True]]
)
)[0]
fps = fps[optimal_idxs]
tps = tps[optimal_idxs]
thresholds = thresholds[optimal_idxs]
ps = tps + fps
# Initialize the result array with zeros to make sure that precision[ps == 0]
# does not contain uninitialized values.
precision = np.zeros_like(tps)
np.divide(tps, ps, out=precision, where=(ps != 0))
# When no positive label in y_true, recall is set to 1 for all thresholds
# tps[-1] == 0 <=> y_true == all negative labels
if tps[-1] == 0:
warnings.warn(
"No positive class found in y_true, "
"recall is set to one for all thresholds."
)
recall = np.ones_like(tps)
else:
recall = tps / tps[-1]
# reverse the outputs so recall is decreasing
sl = slice(None, None, -1)
return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), thresholds[sl]
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"pos_label": [Real, str, "boolean", None],
"sample_weight": ["array-like", None],
"drop_intermediate": ["boolean"],
},
prefer_skip_nested_validation=True,
)
def roc_curve(
y_true, y_score, *, pos_label=None, sample_weight=None, drop_intermediate=True
):
"""Compute Receiver operating characteristic (ROC).
Note: this implementation is restricted to the binary classification task.
Read more in the :ref:`User Guide <roc_metrics>`.
Parameters
----------
y_true : array-like of shape (n_samples,)
True binary labels. If labels are not either {-1, 1} or {0, 1}, then
pos_label should be explicitly given.
y_score : array-like of shape (n_samples,)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by "decision_function" on some classifiers).
For :term:`decision_function` scores, values greater than or equal to
zero should indicate the positive class.
pos_label : int, float, bool or str, default=None
The label of the positive class.
When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1},
``pos_label`` is set to 1, otherwise an error will be raised.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
drop_intermediate : bool, default=True
Whether to drop some suboptimal thresholds which would not appear
on a plotted ROC curve. This is useful in order to create lighter
ROC curves.
.. versionadded:: 0.17
parameter *drop_intermediate*.
Returns
-------
fpr : ndarray of shape (>2,)
Increasing false positive rates such that element i is the false
positive rate of predictions with score >= `thresholds[i]`.
tpr : ndarray of shape (>2,)
Increasing true positive rates such that element `i` is the true
positive rate of predictions with score >= `thresholds[i]`.
thresholds : ndarray of shape (n_thresholds,)
Decreasing thresholds on the decision function used to compute
fpr and tpr. `thresholds[0]` represents no instances being predicted
and is arbitrarily set to `np.inf`.
See Also
--------
RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
(ROC) curve given an estimator and some data.
RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
(ROC) curve given the true and predicted values.
det_curve: Compute error rates for different probability thresholds.
roc_auc_score : Compute the area under the ROC curve.
Notes
-----
Since the thresholds are sorted from low to high values, they
are reversed upon returning them to ensure they correspond to both ``fpr``
and ``tpr``, which are sorted in reversed order during their calculation.
An arbitrary threshold is added for the case `tpr=0` and `fpr=0` to
ensure that the curve starts at `(0, 0)`. This threshold corresponds to the
`np.inf`.
References
----------
.. [1] `Wikipedia entry for the Receiver operating characteristic
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_
.. [2] Fawcett T. An introduction to ROC analysis[J]. Pattern Recognition
Letters, 2006, 27(8):861-874.
Examples
--------
>>> import numpy as np
>>> from sklearn import metrics
>>> y = np.array([1, 1, 2, 2])
>>> scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> fpr, tpr, thresholds = metrics.roc_curve(y, scores, pos_label=2)
>>> fpr
array([0. , 0. , 0.5, 0.5, 1. ])
>>> tpr
array([0. , 0.5, 0.5, 1. , 1. ])
>>> thresholds
array([ inf, 0.8 , 0.4 , 0.35, 0.1 ])
"""
fps, tps, thresholds = _binary_clf_curve(
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)
# Attempt to drop thresholds corresponding to points in between and
# collinear with other points. These are always suboptimal and do not
# appear on a plotted ROC curve (and thus do not affect the AUC).
# Here np.diff(_, 2) is used as a "second derivative" to tell if there
# is a corner at the point. Both fps and tps must be tested to handle
# thresholds with multiple data points (which are combined in
# _binary_clf_curve). This keeps all cases where the point should be kept,
# but does not drop more complicated cases like fps = [1, 3, 7],
# tps = [1, 2, 4]; there is no harm in keeping too many thresholds.
if drop_intermediate and len(fps) > 2:
optimal_idxs = np.where(
np.r_[True, np.logical_or(np.diff(fps, 2), np.diff(tps, 2)), True]
)[0]
fps = fps[optimal_idxs]
tps = tps[optimal_idxs]
thresholds = thresholds[optimal_idxs]
# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = np.r_[0, tps]
fps = np.r_[0, fps]
# get dtype of `y_score` even if it is an array-like
thresholds = np.r_[np.inf, thresholds]
if fps[-1] <= 0:
warnings.warn(
"No negative samples in y_true, false positive value should be meaningless",
UndefinedMetricWarning,
)
fpr = np.repeat(np.nan, fps.shape)
else:
fpr = fps / fps[-1]
if tps[-1] <= 0:
warnings.warn(
"No positive samples in y_true, true positive value should be meaningless",
UndefinedMetricWarning,
)
tpr = np.repeat(np.nan, tps.shape)
else:
tpr = tps / tps[-1]
return fpr, tpr, thresholds
@validate_params(
{
"y_true": ["array-like", "sparse matrix"],
"y_score": ["array-like"],
"sample_weight": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def label_ranking_average_precision_score(y_true, y_score, *, sample_weight=None):
"""Compute ranking-based average precision.
Label ranking average precision (LRAP) is the average over each ground
truth label assigned to each sample, of the ratio of true vs. total
labels with lower score.
This metric is used in multilabel ranking problem, where the goal
is to give better rank to the labels associated to each sample.
The obtained score is always strictly greater than 0 and
the best value is 1.
Read more in the :ref:`User Guide <label_ranking_average_precision>`.
Parameters
----------
y_true : {array-like, sparse matrix} of shape (n_samples, n_labels)
True binary labels in binary indicator format.
y_score : array-like of shape (n_samples, n_labels)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by "decision_function" on some classifiers).
For :term:`decision_function` scores, values greater than or equal to
zero should indicate the positive class.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
.. versionadded:: 0.20
Returns
-------
score : float
Ranking-based average precision score.
Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import label_ranking_average_precision_score
>>> y_true = np.array([[1, 0, 0], [0, 0, 1]])
>>> y_score = np.array([[0.75, 0.5, 1], [1, 0.2, 0.1]])
>>> label_ranking_average_precision_score(y_true, y_score)
np.float64(0.416...)
"""
check_consistent_length(y_true, y_score, sample_weight)
y_true = check_array(y_true, ensure_2d=False, accept_sparse="csr")
y_score = check_array(y_score, ensure_2d=False)
if y_true.shape != y_score.shape:
raise ValueError("y_true and y_score have different shape")
# Handle badly formatted array and the degenerate case with one label
y_type = type_of_target(y_true, input_name="y_true")
if y_type != "multilabel-indicator" and not (
y_type == "binary" and y_true.ndim == 2
):
raise ValueError("{0} format is not supported".format(y_type))
if not issparse(y_true):
y_true = csr_matrix(y_true)
y_score = -y_score
n_samples, n_labels = y_true.shape
out = 0.0
for i, (start, stop) in enumerate(zip(y_true.indptr, y_true.indptr[1:])):
relevant = y_true.indices[start:stop]
if relevant.size == 0 or relevant.size == n_labels:
# If all labels are relevant or unrelevant, the score is also
# equal to 1. The label ranking has no meaning.
aux = 1.0
else:
scores_i = y_score[i]
rank = rankdata(scores_i, "max")[relevant]
L = rankdata(scores_i[relevant], "max")
aux = (L / rank).mean()
if sample_weight is not None:
aux = aux * sample_weight[i]
out += aux
if sample_weight is None:
out /= n_samples
else:
out /= np.sum(sample_weight)
return out
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"sample_weight": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def coverage_error(y_true, y_score, *, sample_weight=None):
"""Coverage error measure.
Compute how far we need to go through the ranked scores to cover all
true labels. The best value is equal to the average number
of labels in ``y_true`` per sample.
Ties in ``y_scores`` are broken by giving maximal rank that would have
been assigned to all tied values.
Note: Our implementation's score is 1 greater than the one given in
Tsoumakas et al., 2010. This extends it to handle the degenerate case
in which an instance has 0 true labels.
Read more in the :ref:`User Guide <coverage_error>`.
Parameters
----------
y_true : array-like of shape (n_samples, n_labels)
True binary labels in binary indicator format.
y_score : array-like of shape (n_samples, n_labels)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by "decision_function" on some classifiers).
For :term:`decision_function` scores, values greater than or equal to
zero should indicate the positive class.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
-------
coverage_error : float
The coverage error.
References
----------
.. [1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010).
Mining multi-label data. In Data mining and knowledge discovery
handbook (pp. 667-685). Springer US.
Examples
--------
>>> from sklearn.metrics import coverage_error
>>> y_true = [[1, 0, 0], [0, 1, 1]]
>>> y_score = [[1, 0, 0], [0, 1, 1]]
>>> coverage_error(y_true, y_score)
np.float64(1.5)
"""
y_true = check_array(y_true, ensure_2d=True)
y_score = check_array(y_score, ensure_2d=True)
check_consistent_length(y_true, y_score, sample_weight)
y_type = type_of_target(y_true, input_name="y_true")
if y_type != "multilabel-indicator":
raise ValueError("{0} format is not supported".format(y_type))
if y_true.shape != y_score.shape:
raise ValueError("y_true and y_score have different shape")
y_score_mask = np.ma.masked_array(y_score, mask=np.logical_not(y_true))
y_min_relevant = y_score_mask.min(axis=1).reshape((-1, 1))
coverage = (y_score >= y_min_relevant).sum(axis=1)
coverage = coverage.filled(0)
return np.average(coverage, weights=sample_weight)
@validate_params(
{
"y_true": ["array-like", "sparse matrix"],
"y_score": ["array-like"],
"sample_weight": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def label_ranking_loss(y_true, y_score, *, sample_weight=None):
"""Compute Ranking loss measure.
Compute the average number of label pairs that are incorrectly ordered
given y_score weighted by the size of the label set and the number of
labels not in the label set.
This is similar to the error set size, but weighted by the number of
relevant and irrelevant labels. The best performance is achieved with
a ranking loss of zero.
Read more in the :ref:`User Guide <label_ranking_loss>`.
.. versionadded:: 0.17
A function *label_ranking_loss*
Parameters
----------
y_true : {array-like, sparse matrix} of shape (n_samples, n_labels)
True binary labels in binary indicator format.
y_score : array-like of shape (n_samples, n_labels)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by "decision_function" on some classifiers).
For :term:`decision_function` scores, values greater than or equal to
zero should indicate the positive class.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
-------
loss : float
Average number of label pairs that are incorrectly ordered given
y_score weighted by the size of the label set and the number of labels not
in the label set.
References
----------
.. [1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010).
Mining multi-label data. In Data mining and knowledge discovery
handbook (pp. 667-685). Springer US.
Examples
--------
>>> from sklearn.metrics import label_ranking_loss
>>> y_true = [[1, 0, 0], [0, 0, 1]]
>>> y_score = [[0.75, 0.5, 1], [1, 0.2, 0.1]]
>>> label_ranking_loss(y_true, y_score)
np.float64(0.75...)
"""
y_true = check_array(y_true, ensure_2d=False, accept_sparse="csr")
y_score = check_array(y_score, ensure_2d=False)
check_consistent_length(y_true, y_score, sample_weight)
y_type = type_of_target(y_true, input_name="y_true")
if y_type not in ("multilabel-indicator",):
raise ValueError("{0} format is not supported".format(y_type))
if y_true.shape != y_score.shape:
raise ValueError("y_true and y_score have different shape")
n_samples, n_labels = y_true.shape
y_true = csr_matrix(y_true)
loss = np.zeros(n_samples)
for i, (start, stop) in enumerate(zip(y_true.indptr, y_true.indptr[1:])):
# Sort and bin the label scores
unique_scores, unique_inverse = np.unique(y_score[i], return_inverse=True)
true_at_reversed_rank = np.bincount(
unique_inverse[y_true.indices[start:stop]], minlength=len(unique_scores)
)
all_at_reversed_rank = np.bincount(unique_inverse, minlength=len(unique_scores))
false_at_reversed_rank = all_at_reversed_rank - true_at_reversed_rank
# if the scores are ordered, it's possible to count the number of
# incorrectly ordered paires in linear time by cumulatively counting
# how many false labels of a given score have a score higher than the
# accumulated true labels with lower score.
loss[i] = np.dot(true_at_reversed_rank.cumsum(), false_at_reversed_rank)
n_positives = count_nonzero(y_true, axis=1)
with np.errstate(divide="ignore", invalid="ignore"):
loss /= (n_labels - n_positives) * n_positives
# When there is no positive or no negative labels, those values should
# be consider as correct, i.e. the ranking doesn't matter.
loss[np.logical_or(n_positives == 0, n_positives == n_labels)] = 0.0
return np.average(loss, weights=sample_weight)
def _dcg_sample_scores(y_true, y_score, k=None, log_base=2, ignore_ties=False):
"""Compute Discounted Cumulative Gain.
Sum the true scores ranked in the order induced by the predicted scores,
after applying a logarithmic discount.
This ranking metric yields a high value if true labels are ranked high by
``y_score``.
Parameters
----------
y_true : ndarray of shape (n_samples, n_labels)
True targets of multilabel classification, or true scores of entities
to be ranked.
y_score : ndarray of shape (n_samples, n_labels)
Target scores, can either be probability estimates, confidence values,
or non-thresholded measure of decisions (as returned by
"decision_function" on some classifiers).
k : int, default=None
Only consider the highest k scores in the ranking. If `None`, use all
outputs.
log_base : float, default=2
Base of the logarithm used for the discount. A low value means a
sharper discount (top results are more important).
ignore_ties : bool, default=False
Assume that there are no ties in y_score (which is likely to be the
case if y_score is continuous) for efficiency gains.
Returns
-------
discounted_cumulative_gain : ndarray of shape (n_samples,)
The DCG score for each sample.
See Also
--------
ndcg_score : The Discounted Cumulative Gain divided by the Ideal Discounted
Cumulative Gain (the DCG obtained for a perfect ranking), in order to
have a score between 0 and 1.
"""
discount = 1 / (np.log(np.arange(y_true.shape[1]) + 2) / np.log(log_base))
if k is not None:
discount[k:] = 0
if ignore_ties:
ranking = np.argsort(y_score)[:, ::-1]
ranked = y_true[np.arange(ranking.shape[0])[:, np.newaxis], ranking]
cumulative_gains = discount.dot(ranked.T)
else:
discount_cumsum = np.cumsum(discount)
cumulative_gains = [
_tie_averaged_dcg(y_t, y_s, discount_cumsum)
for y_t, y_s in zip(y_true, y_score)
]
cumulative_gains = np.asarray(cumulative_gains)
return cumulative_gains
def _tie_averaged_dcg(y_true, y_score, discount_cumsum):
"""
Compute DCG by averaging over possible permutations of ties.
The gain (`y_true`) of an index falling inside a tied group (in the order
induced by `y_score`) is replaced by the average gain within this group.
The discounted gain for a tied group is then the average `y_true` within
this group times the sum of discounts of the corresponding ranks.
This amounts to averaging scores for all possible orderings of the tied
groups.
(note in the case of dcg@k the discount is 0 after index k)
Parameters
----------
y_true : ndarray
The true relevance scores.
y_score : ndarray
Predicted scores.
discount_cumsum : ndarray
Precomputed cumulative sum of the discounts.
Returns
-------
discounted_cumulative_gain : float
The discounted cumulative gain.
References
----------
McSherry, F., & Najork, M. (2008, March). Computing information retrieval
performance measures efficiently in the presence of tied scores. In
European conference on information retrieval (pp. 414-421). Springer,
Berlin, Heidelberg.
"""
_, inv, counts = np.unique(-y_score, return_inverse=True, return_counts=True)
ranked = np.zeros(len(counts))
np.add.at(ranked, inv, y_true)
ranked /= counts
groups = np.cumsum(counts) - 1
discount_sums = np.empty(len(counts))
discount_sums[0] = discount_cumsum[groups[0]]
discount_sums[1:] = np.diff(discount_cumsum[groups])
return (ranked * discount_sums).sum()
def _check_dcg_target_type(y_true):
y_type = type_of_target(y_true, input_name="y_true")
supported_fmt = (
"multilabel-indicator",
"continuous-multioutput",
"multiclass-multioutput",
)
if y_type not in supported_fmt:
raise ValueError(
"Only {} formats are supported. Got {} instead".format(
supported_fmt, y_type
)
)
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"k": [Interval(Integral, 1, None, closed="left"), None],
"log_base": [Interval(Real, 0.0, None, closed="neither")],
"sample_weight": ["array-like", None],
"ignore_ties": ["boolean"],
},
prefer_skip_nested_validation=True,
)
def dcg_score(
y_true, y_score, *, k=None, log_base=2, sample_weight=None, ignore_ties=False
):
"""Compute Discounted Cumulative Gain.
Sum the true scores ranked in the order induced by the predicted scores,
after applying a logarithmic discount.
This ranking metric yields a high value if true labels are ranked high by
``y_score``.
Usually the Normalized Discounted Cumulative Gain (NDCG, computed by
ndcg_score) is preferred.
Parameters
----------
y_true : array-like of shape (n_samples, n_labels)
True targets of multilabel classification, or true scores of entities
to be ranked.
y_score : array-like of shape (n_samples, n_labels)
Target scores, can either be probability estimates, confidence values,
or non-thresholded measure of decisions (as returned by
"decision_function" on some classifiers).
k : int, default=None
Only consider the highest k scores in the ranking. If None, use all
outputs.
log_base : float, default=2
Base of the logarithm used for the discount. A low value means a
sharper discount (top results are more important).
sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If `None`, all samples are given the same weight.
ignore_ties : bool, default=False
Assume that there are no ties in y_score (which is likely to be the
case if y_score is continuous) for efficiency gains.
Returns
-------
discounted_cumulative_gain : float
The averaged sample DCG scores.
See Also
--------
ndcg_score : The Discounted Cumulative Gain divided by the Ideal Discounted
Cumulative Gain (the DCG obtained for a perfect ranking), in order to
have a score between 0 and 1.
References
----------
`Wikipedia entry for Discounted Cumulative Gain
<https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`_.
Jarvelin, K., & Kekalainen, J. (2002).
Cumulated gain-based evaluation of IR techniques. ACM Transactions on
Information Systems (TOIS), 20(4), 422-446.
Wang, Y., Wang, L., Li, Y., He, D., Chen, W., & Liu, T. Y. (2013, May).
A theoretical analysis of NDCG ranking measures. In Proceedings of the 26th
Annual Conference on Learning Theory (COLT 2013).
McSherry, F., & Najork, M. (2008, March). Computing information retrieval
performance measures efficiently in the presence of tied scores. In
European conference on information retrieval (pp. 414-421). Springer,
Berlin, Heidelberg.
Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import dcg_score
>>> # we have ground-truth relevance of some answers to a query:
>>> true_relevance = np.asarray([[10, 0, 0, 1, 5]])
>>> # we predict scores for the answers
>>> scores = np.asarray([[.1, .2, .3, 4, 70]])
>>> dcg_score(true_relevance, scores)
np.float64(9.49...)
>>> # we can set k to truncate the sum; only top k answers contribute
>>> dcg_score(true_relevance, scores, k=2)
np.float64(5.63...)
>>> # now we have some ties in our prediction
>>> scores = np.asarray([[1, 0, 0, 0, 1]])
>>> # by default ties are averaged, so here we get the average true
>>> # relevance of our top predictions: (10 + 5) / 2 = 7.5
>>> dcg_score(true_relevance, scores, k=1)
np.float64(7.5)
>>> # we can choose to ignore ties for faster results, but only
>>> # if we know there aren't ties in our scores, otherwise we get
>>> # wrong results:
>>> dcg_score(true_relevance,
... scores, k=1, ignore_ties=True)
np.float64(5.0)
"""
y_true = check_array(y_true, ensure_2d=False)
y_score = check_array(y_score, ensure_2d=False)
check_consistent_length(y_true, y_score, sample_weight)
_check_dcg_target_type(y_true)
return np.average(
_dcg_sample_scores(
y_true, y_score, k=k, log_base=log_base, ignore_ties=ignore_ties
),
weights=sample_weight,
)
def _ndcg_sample_scores(y_true, y_score, k=None, ignore_ties=False):
"""Compute Normalized Discounted Cumulative Gain.
Sum the true scores ranked in the order induced by the predicted scores,
after applying a logarithmic discount. Then divide by the best possible
score (Ideal DCG, obtained for a perfect ranking) to obtain a score between
0 and 1.
This ranking metric yields a high value if true labels are ranked high by
``y_score``.
Parameters
----------
y_true : ndarray of shape (n_samples, n_labels)
True targets of multilabel classification, or true scores of entities
to be ranked.
y_score : ndarray of shape (n_samples, n_labels)
Target scores, can either be probability estimates, confidence values,
or non-thresholded measure of decisions (as returned by
"decision_function" on some classifiers).
k : int, default=None
Only consider the highest k scores in the ranking. If None, use all
outputs.
ignore_ties : bool, default=False
Assume that there are no ties in y_score (which is likely to be the
case if y_score is continuous) for efficiency gains.
Returns
-------
normalized_discounted_cumulative_gain : ndarray of shape (n_samples,)
The NDCG score for each sample (float in [0., 1.]).
See Also
--------
dcg_score : Discounted Cumulative Gain (not normalized).
"""
gain = _dcg_sample_scores(y_true, y_score, k, ignore_ties=ignore_ties)
# Here we use the order induced by y_true so we can ignore ties since
# the gain associated to tied indices is the same (permuting ties doesn't
# change the value of the re-ordered y_true)
normalizing_gain = _dcg_sample_scores(y_true, y_true, k, ignore_ties=True)
all_irrelevant = normalizing_gain == 0
gain[all_irrelevant] = 0
gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant]
return gain
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"k": [Interval(Integral, 1, None, closed="left"), None],
"sample_weight": ["array-like", None],
"ignore_ties": ["boolean"],
},
prefer_skip_nested_validation=True,
)
def ndcg_score(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False):
"""Compute Normalized Discounted Cumulative Gain.
Sum the true scores ranked in the order induced by the predicted scores,
after applying a logarithmic discount. Then divide by the best possible
score (Ideal DCG, obtained for a perfect ranking) to obtain a score between
0 and 1.
This ranking metric returns a high value if true labels are ranked high by
``y_score``.
Parameters
----------
y_true : array-like of shape (n_samples, n_labels)
True targets of multilabel classification, or true scores of entities
to be ranked. Negative values in `y_true` may result in an output
that is not between 0 and 1.
y_score : array-like of shape (n_samples, n_labels)
Target scores, can either be probability estimates, confidence values,
or non-thresholded measure of decisions (as returned by
"decision_function" on some classifiers).
k : int, default=None
Only consider the highest k scores in the ranking. If `None`, use all
outputs.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If `None`, all samples are given the same weight.
ignore_ties : bool, default=False
Assume that there are no ties in y_score (which is likely to be the
case if y_score is continuous) for efficiency gains.
Returns
-------
normalized_discounted_cumulative_gain : float in [0., 1.]
The averaged NDCG scores for all samples.
See Also
--------
dcg_score : Discounted Cumulative Gain (not normalized).
References
----------
`Wikipedia entry for Discounted Cumulative Gain
<https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`_
Jarvelin, K., & Kekalainen, J. (2002).
Cumulated gain-based evaluation of IR techniques. ACM Transactions on
Information Systems (TOIS), 20(4), 422-446.
Wang, Y., Wang, L., Li, Y., He, D., Chen, W., & Liu, T. Y. (2013, May).
A theoretical analysis of NDCG ranking measures. In Proceedings of the 26th
Annual Conference on Learning Theory (COLT 2013)
McSherry, F., & Najork, M. (2008, March). Computing information retrieval
performance measures efficiently in the presence of tied scores. In
European conference on information retrieval (pp. 414-421). Springer,
Berlin, Heidelberg.
Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import ndcg_score
>>> # we have ground-truth relevance of some answers to a query:
>>> true_relevance = np.asarray([[10, 0, 0, 1, 5]])
>>> # we predict some scores (relevance) for the answers
>>> scores = np.asarray([[.1, .2, .3, 4, 70]])
>>> ndcg_score(true_relevance, scores)
np.float64(0.69...)
>>> scores = np.asarray([[.05, 1.1, 1., .5, .0]])
>>> ndcg_score(true_relevance, scores)
np.float64(0.49...)
>>> # we can set k to truncate the sum; only top k answers contribute.
>>> ndcg_score(true_relevance, scores, k=4)
np.float64(0.35...)
>>> # the normalization takes k into account so a perfect answer
>>> # would still get 1.0
>>> ndcg_score(true_relevance, true_relevance, k=4)
np.float64(1.0...)
>>> # now we have some ties in our prediction
>>> scores = np.asarray([[1, 0, 0, 0, 1]])
>>> # by default ties are averaged, so here we get the average (normalized)
>>> # true relevance of our top predictions: (10 / 10 + 5 / 10) / 2 = .75
>>> ndcg_score(true_relevance, scores, k=1)
np.float64(0.75...)
>>> # we can choose to ignore ties for faster results, but only
>>> # if we know there aren't ties in our scores, otherwise we get
>>> # wrong results:
>>> ndcg_score(true_relevance,
... scores, k=1, ignore_ties=True)
np.float64(0.5...)
"""
y_true = check_array(y_true, ensure_2d=False)
y_score = check_array(y_score, ensure_2d=False)
check_consistent_length(y_true, y_score, sample_weight)
if y_true.min() < 0:
raise ValueError("ndcg_score should not be used on negative y_true values.")
if y_true.ndim > 1 and y_true.shape[1] <= 1:
raise ValueError(
"Computing NDCG is only meaningful when there is more than 1 document. "
f"Got {y_true.shape[1]} instead."
)
_check_dcg_target_type(y_true)
gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties)
return np.average(gain, weights=sample_weight)
@validate_params(
{
"y_true": ["array-like"],
"y_score": ["array-like"],
"k": [Interval(Integral, 1, None, closed="left")],
"normalize": ["boolean"],
"sample_weight": ["array-like", None],
"labels": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
def top_k_accuracy_score(
y_true, y_score, *, k=2, normalize=True, sample_weight=None, labels=None
):
"""Top-k Accuracy classification score.
This metric computes the number of times where the correct label is among
the top `k` labels predicted (ranked by predicted scores). Note that the
multilabel case isn't covered here.
Read more in the :ref:`User Guide <top_k_accuracy_score>`
Parameters
----------
y_true : array-like of shape (n_samples,)
True labels.
y_score : array-like of shape (n_samples,) or (n_samples, n_classes)
Target scores. These can be either probability estimates or
non-thresholded decision values (as returned by
:term:`decision_function` on some classifiers).
The binary case expects scores with shape (n_samples,) while the
multiclass case expects scores with shape (n_samples, n_classes).
In the multiclass case, the order of the class scores must
correspond to the order of ``labels``, if provided, or else to
the numerical or lexicographical order of the labels in ``y_true``.
If ``y_true`` does not contain all the labels, ``labels`` must be
provided.
k : int, default=2
Number of most likely outcomes considered to find the correct label.
normalize : bool, default=True
If `True`, return the fraction of correctly classified samples.
Otherwise, return the number of correctly classified samples.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If `None`, all samples are given the same weight.
labels : array-like of shape (n_classes,), default=None
Multiclass only. List of labels that index the classes in ``y_score``.
If ``None``, the numerical or lexicographical order of the labels in
``y_true`` is used. If ``y_true`` does not contain all the labels,
``labels`` must be provided.
Returns
-------
score : float
The top-k accuracy score. The best performance is 1 with
`normalize == True` and the number of samples with
`normalize == False`.
See Also
--------
accuracy_score : Compute the accuracy score. By default, the function will
return the fraction of correct predictions divided by the total number
of predictions.
Notes
-----
In cases where two or more labels are assigned equal predicted scores,
the labels with the highest indices will be chosen first. This might
impact the result if the correct label falls after the threshold because
of that.
Examples
--------
>>> import numpy as np
>>> from sklearn.metrics import top_k_accuracy_score
>>> y_true = np.array([0, 1, 2, 2])
>>> y_score = np.array([[0.5, 0.2, 0.2], # 0 is in top 2
... [0.3, 0.4, 0.2], # 1 is in top 2
... [0.2, 0.4, 0.3], # 2 is in top 2
... [0.7, 0.2, 0.1]]) # 2 isn't in top 2
>>> top_k_accuracy_score(y_true, y_score, k=2)
np.float64(0.75)
>>> # Not normalizing gives the number of "correctly" classified samples
>>> top_k_accuracy_score(y_true, y_score, k=2, normalize=False)
np.int64(3)
"""
y_true = check_array(y_true, ensure_2d=False, dtype=None)
y_true = column_or_1d(y_true)
y_type = type_of_target(y_true, input_name="y_true")
if y_type == "binary" and labels is not None and len(labels) > 2:
y_type = "multiclass"
if y_type not in {"binary", "multiclass"}:
raise ValueError(
f"y type must be 'binary' or 'multiclass', got '{y_type}' instead."
)
y_score = check_array(y_score, ensure_2d=False)
if y_type == "binary":
if y_score.ndim == 2 and y_score.shape[1] != 1:
raise ValueError(
"`y_true` is binary while y_score is 2d with"
f" {y_score.shape[1]} classes. If `y_true` does not contain all the"
" labels, `labels` must be provided."
)
y_score = column_or_1d(y_score)
check_consistent_length(y_true, y_score, sample_weight)
y_score_n_classes = y_score.shape[1] if y_score.ndim == 2 else 2
if labels is None:
classes = _unique(y_true)
n_classes = len(classes)
if n_classes != y_score_n_classes:
raise ValueError(
f"Number of classes in 'y_true' ({n_classes}) not equal "
f"to the number of classes in 'y_score' ({y_score_n_classes})."
"You can provide a list of all known classes by assigning it "
"to the `labels` parameter."
)
else:
labels = column_or_1d(labels)
classes = _unique(labels)
n_labels = len(labels)
n_classes = len(classes)
if n_classes != n_labels:
raise ValueError("Parameter 'labels' must be unique.")
if not np.array_equal(classes, labels):
raise ValueError("Parameter 'labels' must be ordered.")
if n_classes != y_score_n_classes:
raise ValueError(
f"Number of given labels ({n_classes}) not equal to the "
f"number of classes in 'y_score' ({y_score_n_classes})."
)
if len(np.setdiff1d(y_true, classes)):
raise ValueError("'y_true' contains labels not in parameter 'labels'.")
if k >= n_classes:
warnings.warn(
(
f"'k' ({k}) greater than or equal to 'n_classes' ({n_classes}) "
"will result in a perfect score and is therefore meaningless."
),
UndefinedMetricWarning,
)
y_true_encoded = _encode(y_true, uniques=classes)
if y_type == "binary":
if k == 1:
threshold = 0.5 if y_score.min() >= 0 and y_score.max() <= 1 else 0
y_pred = (y_score > threshold).astype(np.int64)
hits = y_pred == y_true_encoded
else:
hits = np.ones_like(y_score, dtype=np.bool_)
elif y_type == "multiclass":
sorted_pred = np.argsort(y_score, axis=1, kind="mergesort")[:, ::-1]
hits = (y_true_encoded == sorted_pred[:, :k].T).any(axis=0)
if normalize:
return np.average(hits, weights=sample_weight)
elif sample_weight is None:
return np.sum(hits)
else:
return np.dot(hits, sample_weight)
|