# Citelab / context_cite/solver.py
# Uploaded by SHEN1017 ("Upload 97 files"), commit 96b6673 (verified), 1.82 kB.
import numpy as np
from numpy.typing import NDArray
from typing import Tuple
from abc import ABC, abstractmethod
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Lasso
from sklearn.pipeline import make_pipeline
class BaseSolver(ABC):
    """
    Abstract interface that every solver implements.

    Concrete subclasses must override ``fit``; the class itself cannot be
    instantiated.

    Methods:
        fit(self, masks: NDArray, outputs: NDArray, num_output_tokens: int) -> Tuple[NDArray, NDArray]:
            Fit the solver to the given data.
    """

    @abstractmethod
    def fit(
        self,
        masks: NDArray,
        outputs: NDArray,
        num_output_tokens: int,
    ) -> Tuple[NDArray, NDArray]:
        """Fit the solver to ``masks``/``outputs`` and return a pair of arrays."""
class LassoRegression(BaseSolver):
    """
    A LASSO solver using the scikit-learn library.

    Features are standardized before fitting, and the learned coefficients are
    mapped back to the scale of the raw ``masks``. The returned
    ``(weight, bias)`` therefore approximate ``outputs`` as
    ``masks @ weight.T + bias``.

    Attributes:
        lasso_alpha (float):
            The alpha (L1 penalty strength) parameter for the LASSO
            regression. Defaults to 0.01.

    Methods:
        fit(self, masks: NDArray, outputs: NDArray, num_output_tokens: int) -> Tuple[NDArray, NDArray]:
            Fit the solver to the given data.
    """

    def __init__(self, lasso_alpha: float = 0.01) -> None:
        self.lasso_alpha = lasso_alpha

    def fit(
        self, masks: NDArray, outputs: NDArray, num_output_tokens: int
    ) -> Tuple[NDArray, NDArray]:
        """
        Fit a standardized LASSO model of ``outputs`` on ``masks``.

        Args:
            masks: Design matrix with one row per sample. It is cast to
                float32 before fitting. (Presumably binary ablation masks —
                confirm with callers.)
            outputs: Regression targets, one entry (or row) per sample.
            num_output_tokens: Normalizer. The targets are divided by it
                before fitting, and the returned parameters are multiplied by
                it again, so the result is on the scale of ``outputs``. It must
                be non-zero.

        Returns:
            ``(weight, bias)`` in the original (unstandardized) feature space.
            ``weight`` has the shape of ``Lasso.coef_``, and ``bias`` has the
            shape of ``Lasso.intercept_``.

        Raises:
            ValueError: If ``num_output_tokens`` is zero.
        """
        if num_output_tokens == 0:
            # Dividing by zero would turn the targets into inf/NaN, which
            # scikit-learn's input validation would only reject later with a
            # message that does not point at the real cause.
            raise ValueError("num_output_tokens must be non-zero")
        X = masks.astype(np.float32)
        Y = outputs / num_output_tokens
        scaler = StandardScaler()
        lasso = Lasso(alpha=self.lasso_alpha, random_state=0, fit_intercept=True)
        # Pipeline is ((X - scaler.mean_) / scaler.scale_) @ lasso.coef_.T + lasso.intercept_
        pipeline = make_pipeline(scaler, lasso)
        pipeline.fit(X, Y)
        # Fold the standardization into the linear model:
        # ((X - m) / s) @ c.T + b == X @ (c / s).T + (b - (m / s) @ c.T)
        weight = lasso.coef_ / scaler.scale_
        bias = lasso.intercept_ - (scaler.mean_ / scaler.scale_) @ lasso.coef_.T
        # Undo the earlier division of the targets by num_output_tokens.
        return weight * num_output_tokens, bias * num_output_tokens