File size: 8,126 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 |
"""
Module contains classes for invertible (and differentiable) link functions.
"""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
from scipy.special import expit, logit
from scipy.stats import gmean
from ..utils.extmath import softmax
@dataclass
class Interval:
low: float
high: float
low_inclusive: bool
high_inclusive: bool
def __post_init__(self):
"""Check that low <= high"""
if self.low > self.high:
raise ValueError(
f"One must have low <= high; got low={self.low}, high={self.high}."
)
def includes(self, x):
"""Test whether all values of x are in interval range.
Parameters
----------
x : ndarray
Array whose elements are tested to be in interval range.
Returns
-------
result : bool
"""
if self.low_inclusive:
low = np.greater_equal(x, self.low)
else:
low = np.greater(x, self.low)
if not np.all(low):
return False
if self.high_inclusive:
high = np.less_equal(x, self.high)
else:
high = np.less(x, self.high)
# Note: np.all returns numpy.bool_
return bool(np.all(high))
def _inclusive_low_high(interval, dtype=np.float64):
"""Generate values low and high to be within the interval range.
This is used in tests only.
Returns
-------
low, high : tuple
The returned values low and high lie within the interval.
"""
eps = 10 * np.finfo(dtype).eps
if interval.low == -np.inf:
low = -1e10
elif interval.low < 0:
low = interval.low * (1 - eps) + eps
else:
low = interval.low * (1 + eps) + eps
if interval.high == np.inf:
high = 1e10
elif interval.high < 0:
high = interval.high * (1 + eps) - eps
else:
high = interval.high * (1 - eps) - eps
return low, high
class BaseLink(ABC):
"""Abstract base class for differentiable, invertible link functions.
Convention:
- link function g: raw_prediction = g(y_pred)
- inverse link h: y_pred = h(raw_prediction)
For (generalized) linear models, `raw_prediction = X @ coef` is the so
called linear predictor, and `y_pred = h(raw_prediction)` is the predicted
conditional (on X) expected value of the target `y_true`.
The methods are not implemented as staticmethods in case a link function needs
parameters.
"""
is_multiclass = False # used for testing only
# Usually, raw_prediction may be any real number and y_pred is an open
# interval.
# interval_raw_prediction = Interval(-np.inf, np.inf, False, False)
interval_y_pred = Interval(-np.inf, np.inf, False, False)
@abstractmethod
def link(self, y_pred, out=None):
"""Compute the link function g(y_pred).
The link function maps (predicted) target values to raw predictions,
i.e. `g(y_pred) = raw_prediction`.
Parameters
----------
y_pred : array
Predicted target values.
out : array
A location into which the result is stored. If provided, it must
have a shape that the inputs broadcast to. If not provided or None,
a freshly-allocated array is returned.
Returns
-------
out : array
Output array, element-wise link function.
"""
@abstractmethod
def inverse(self, raw_prediction, out=None):
"""Compute the inverse link function h(raw_prediction).
The inverse link function maps raw predictions to predicted target
values, i.e. `h(raw_prediction) = y_pred`.
Parameters
----------
raw_prediction : array
Raw prediction values (in link space).
out : array
A location into which the result is stored. If provided, it must
have a shape that the inputs broadcast to. If not provided or None,
a freshly-allocated array is returned.
Returns
-------
out : array
Output array, element-wise inverse link function.
"""
class IdentityLink(BaseLink):
"""The identity link function g(x)=x."""
def link(self, y_pred, out=None):
if out is not None:
np.copyto(out, y_pred)
return out
else:
return y_pred
inverse = link
class LogLink(BaseLink):
"""The log link function g(x)=log(x)."""
interval_y_pred = Interval(0, np.inf, False, False)
def link(self, y_pred, out=None):
return np.log(y_pred, out=out)
def inverse(self, raw_prediction, out=None):
return np.exp(raw_prediction, out=out)
class LogitLink(BaseLink):
"""The logit link function g(x)=logit(x)."""
interval_y_pred = Interval(0, 1, False, False)
def link(self, y_pred, out=None):
return logit(y_pred, out=out)
def inverse(self, raw_prediction, out=None):
return expit(raw_prediction, out=out)
class HalfLogitLink(BaseLink):
"""Half the logit link function g(x)=1/2 * logit(x).
Used for the exponential loss.
"""
interval_y_pred = Interval(0, 1, False, False)
def link(self, y_pred, out=None):
out = logit(y_pred, out=out)
out *= 0.5
return out
def inverse(self, raw_prediction, out=None):
return expit(2 * raw_prediction, out)
class MultinomialLogit(BaseLink):
"""The symmetric multinomial logit function.
Convention:
- y_pred.shape = raw_prediction.shape = (n_samples, n_classes)
Notes:
- The inverse link h is the softmax function.
- The sum is over the second axis, i.e. axis=1 (n_classes).
We have to choose additional constraints in order to make
y_pred[k] = exp(raw_pred[k]) / sum(exp(raw_pred[k]), k=0..n_classes-1)
for n_classes classes identifiable and invertible.
We choose the symmetric side constraint where the geometric mean response
is set as reference category, see [2]:
The symmetric multinomial logit link function for a single data point is
then defined as
raw_prediction[k] = g(y_pred[k]) = log(y_pred[k]/gmean(y_pred))
= log(y_pred[k]) - mean(log(y_pred)).
Note that this is equivalent to the definition in [1] and implies mean
centered raw predictions:
sum(raw_prediction[k], k=0..n_classes-1) = 0.
For linear models with raw_prediction = X @ coef, this corresponds to
sum(coef[k], k=0..n_classes-1) = 0, i.e. the sum over classes for every
feature is zero.
Reference
---------
.. [1] Friedman, Jerome; Hastie, Trevor; Tibshirani, Robert. "Additive
logistic regression: a statistical view of boosting" Ann. Statist.
28 (2000), no. 2, 337--407. doi:10.1214/aos/1016218223.
https://projecteuclid.org/euclid.aos/1016218223
.. [2] Zahid, Faisal Maqbool and Gerhard Tutz. "Ridge estimation for
multinomial logit models with symmetric side constraints."
Computational Statistics 28 (2013): 1017-1034.
http://epub.ub.uni-muenchen.de/11001/1/tr067.pdf
"""
is_multiclass = True
interval_y_pred = Interval(0, 1, False, False)
def symmetrize_raw_prediction(self, raw_prediction):
return raw_prediction - np.mean(raw_prediction, axis=1)[:, np.newaxis]
def link(self, y_pred, out=None):
# geometric mean as reference category
gm = gmean(y_pred, axis=1)
return np.log(y_pred / gm[:, np.newaxis], out=out)
def inverse(self, raw_prediction, out=None):
if out is None:
return softmax(raw_prediction, copy=True)
else:
np.copyto(out, raw_prediction)
softmax(out, copy=False)
return out
_LINKS = {
"identity": IdentityLink,
"log": LogLink,
"logit": LogitLink,
"half_logit": HalfLogitLink,
"multinomial_logit": MultinomialLogit,
}
|