"""Contains classes and methods related to interpretation for components in Gradio."""

from __future__ import annotations

import copy
import math
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import numpy as np
from gradio_client import utils as client_utils

from gradio import components

if TYPE_CHECKING:
    from gradio import Interface


class Interpretable(ABC):
    def __init__(self) -> None:
        self.set_interpret_parameters()

    def set_interpret_parameters(self):
        """
        Set any parameters for interpretation. Properties can be set here to be
        used in get_interpretation_neighbors and get_interpretation_scores.
        """
        pass

    def get_interpretation_scores(
        self, x: Any, neighbors: list[Any] | None, scores: list[float], **kwargs
    ) -> list:
        """
        Arrange the output values from the neighbors into interpretation scores for the interface to render.
        Parameters:
            x: Input to interface
            neighbors: Neighboring values to input x used for interpretation.
            scores: Output value corresponding to each neighbor in neighbors
        Returns:
            Arrangement of interpretation scores for interfaces to render.
        """
        return scores


class TokenInterpretable(Interpretable, ABC):
    @abstractmethod
    def tokenize(self, x: Any) -> tuple[list, list, None]:
        """
        Interprets an input data point x by splitting it into a list of tokens (e.g.
        a string into words or an image into super-pixels).
        """
        return [], [], None

    @abstractmethod
    def get_masked_inputs(self, tokens: list, binary_mask_matrix: list[list]) -> list:
        """
        Returns one masked input per row of binary_mask_matrix: tokens whose mask
        entry is 1 are kept, the rest are removed.
        """
        return []
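
# A minimal sketch (illustrative only, not the actual implementation of any Gradio
# component) of what a text-like TokenInterpretable might look like:
#
#     class MyTextbox(components.Textbox, TokenInterpretable):
#         def tokenize(self, x: str) -> tuple[list, list, None]:
#             tokens = x.split()
#             # Each "neighbor" is the input with one token left out.
#             leave_one_out = [
#                 " ".join(tokens[:i] + tokens[i + 1 :]) for i in range(len(tokens))
#             ]
#             return tokens, leave_one_out, None
#
#         def get_masked_inputs(self, tokens: list, binary_mask_matrix: list[list]) -> list:
#             return [
#                 " ".join(t for t, keep in zip(tokens, row) if keep)
#                 for row in binary_mask_matrix
#             ]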


class NeighborInterpretable(Interpretable, ABC):
    @abstractmethod
    def get_interpretation_neighbors(self, x: Any) -> tuple[list, dict]:
        """
        Generates values similar to the input to be used to interpret the significance of the input in the final output.
        Parameters:
            x: Input to interface
        Returns: (neighbor_values, interpret_kwargs)
            neighbor_values: Neighboring values to input x to compute for interpretation
            interpret_kwargs: Keyword arguments to be passed to get_interpretation_scores
        """
        return [], {}
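
# Illustrative sketch (not Gradio's actual Slider/Number code): a numeric component
# could return evenly spaced values around x as its interpretation neighbors, e.g.
#
#     def get_interpretation_neighbors(self, x: float) -> tuple[list, dict]:
#         delta = 1.0  # hypothetical step size
#         return [x - 2 * delta, x - delta, x + delta, x + 2 * delta], {}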


async def run_interpret(interface: Interface, raw_input: list):
    """
    Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
    interpretation for a certain set of UI component types, as well as the custom interpretation case.
    Parameters:
        interface: the Interface whose prediction function is being interpreted.
        raw_input: a list of raw inputs to apply the interpretation(s) on.
    """
    if isinstance(interface.interpretation, list):
        # A list of per-component interpretation methods was provided
        # (e.g. "default", "shap", or None for each input component).
        processed_input = [
            input_component.preprocess(raw_input[i])
            for i, input_component in enumerate(interface.input_components)
        ]
        # Baseline prediction on the unperturbed input.
        original_output = await interface.call_function(0, processed_input)
        original_output = original_output["prediction"]

        if len(interface.output_components) == 1:
            original_output = [original_output]

        scores, alternative_outputs = [], []

        for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
            if interp == "default":
                input_component = interface.input_components[i]
                neighbor_raw_input = list(raw_input)
                if isinstance(input_component, TokenInterpretable):
                    # Split the input into tokens and re-run the prediction for each
                    # perturbed (neighbor) version, measuring the change in the output.
                    tokens, neighbor_values, masks = input_component.tokenize(x)
                    interface_scores = []
                    alternative_output = []
                    for neighbor_input in neighbor_values:
                        neighbor_raw_input[i] = neighbor_input
                        # (i and input_component inside this comprehension are scoped
                        # to the comprehension and do not clobber the loop variables.)
                        processed_neighbor_input = [
                            input_component.preprocess(neighbor_raw_input[i])
                            for i, input_component in enumerate(
                                interface.input_components
                            )
                        ]

                        neighbor_output = await interface.call_function(
                            0, processed_neighbor_input
                        )
                        neighbor_output = neighbor_output["prediction"]
                        if len(interface.output_components) == 1:
                            neighbor_output = [neighbor_output]
                        processed_neighbor_output = [
                            output_component.postprocess(neighbor_output[i])
                            for i, output_component in enumerate(
                                interface.output_components
                            )
                        ]

                        alternative_output.append(processed_neighbor_output)
                        interface_scores.append(
                            quantify_difference_in_label(
                                interface, original_output, neighbor_output
                            )
                        )
                    alternative_outputs.append(alternative_output)
                    scores.append(
                        input_component.get_interpretation_scores(
                            raw_input[i],
                            neighbor_values,
                            interface_scores,
                            masks=masks,
                            tokens=tokens,
                        )
                    )
                elif isinstance(input_component, NeighborInterpretable):
                    # Ask the component for nearby input values and re-run the
                    # prediction for each one.
                    (
                        neighbor_values,
                        interpret_kwargs,
                    ) = input_component.get_interpretation_neighbors(x)
                    interface_scores = []
                    alternative_output = []
                    for neighbor_input in neighbor_values:
                        neighbor_raw_input[i] = neighbor_input
                        processed_neighbor_input = [
                            input_component.preprocess(neighbor_raw_input[i])
                            for i, input_component in enumerate(
                                interface.input_components
                            )
                        ]
                        neighbor_output = await interface.call_function(
                            0, processed_neighbor_input
                        )
                        neighbor_output = neighbor_output["prediction"]
                        if len(interface.output_components) == 1:
                            neighbor_output = [neighbor_output]
                        processed_neighbor_output = [
                            output_component.postprocess(neighbor_output[i])
                            for i, output_component in enumerate(
                                interface.output_components
                            )
                        ]

                        alternative_output.append(processed_neighbor_output)
                        interface_scores.append(
                            quantify_difference_in_label(
                                interface, original_output, neighbor_output
                            )
                        )
                    alternative_outputs.append(alternative_output)
                    # quantify_difference_in_label returns (original - perturbed), so
                    # negate the scores: a positive score now marks a neighbor that
                    # raises the original prediction.
                    interface_scores = [-score for score in interface_scores]
                    scores.append(
                        input_component.get_interpretation_scores(
                            raw_input[i],
                            neighbor_values,
                            interface_scores,
                            **interpret_kwargs,
                        )
                    )
                else:
                    raise ValueError(
                        f"Component {input_component} does not support interpretation"
                    )
            elif interp == "shap" or interp == "shapley":
                try:
                    import shap
                except (ImportError, ModuleNotFoundError) as err:
                    raise ValueError(
                        "The package `shap` is required for this interpretation method. Try: `pip install shap`"
                    ) from err
                input_component = interface.input_components[i]
                if not isinstance(input_component, TokenInterpretable):
                    raise ValueError(
                        f"Input component {input_component} does not support `shap` interpretation"
                    )

                tokens, _, masks = input_component.tokenize(x)

                # Prediction over a batch of binary token masks, in the form expected
                # by shap.KernelExplainer: each row of binary_mask selects which tokens
                # of the original input to keep.
                def get_masked_prediction(binary_mask):
                    assert isinstance(input_component, TokenInterpretable)
                    masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
                    preds = []
                    for masked_x in masked_xs:
                        processed_masked_input = copy.deepcopy(processed_input)
                        processed_masked_input[i] = input_component.preprocess(masked_x)
                        new_output = client_utils.synchronize_async(
                            interface.call_function, 0, processed_masked_input
                        )
                        new_output = new_output["prediction"]
                        if len(interface.output_components) == 1:
                            new_output = [new_output]
                        pred = get_regression_or_classification_value(
                            interface, original_output, new_output
                        )
                        preds.append(pred)
                    return np.array(preds)

                num_total_segments = len(tokens)
                # Kernel SHAP: the background is a single all-zeros mask (every token
                # removed) and the sample being explained is the all-ones mask (the
                # full input).
                explainer = shap.KernelExplainer(
                    get_masked_prediction, np.zeros((1, num_total_segments))
                )
                shap_values = explainer.shap_values(
                    np.ones((1, num_total_segments)),
                    nsamples=int(interface.num_shap * num_total_segments),
                    silent=True,
                )
                assert shap_values is not None, "SHAP values could not be calculated"
                scores.append(
                    input_component.get_interpretation_scores(
                        raw_input[i],
                        None,
                        shap_values[0].tolist(),
                        masks=masks,
                        tokens=tokens,
                    )
                )
                alternative_outputs.append([])
            elif interp is None:
                scores.append(None)
                alternative_outputs.append([])
            else:
                raise ValueError(f"Unknown interpretation method: {interp}")
        return scores, alternative_outputs
    elif interface.interpretation:
        # A single custom interpretation function was provided; call it directly on
        # the preprocessed inputs.
        processed_input = [
            input_component.preprocess(raw_input[i])
            for i, input_component in enumerate(interface.input_components)
        ]
        interpreter = interface.interpretation
        interpretation = interpreter(*processed_input)
        if len(raw_input) == 1:
            interpretation = [interpretation]
        return interpretation, []
    else:
        raise ValueError("No interpretation method specified.")


def diff(original: Any, perturbed: Any) -> int | float:
    """Quantifies the difference between two output values: the numeric difference
    when both can be cast to float, otherwise 0/1 for identical/different values."""
    try:
        score = float(original) - float(perturbed)
    except ValueError:
        score = int(original != perturbed)
    return score
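
# For example (illustrative): diff("3", "5") == -2.0, while diff("cat", "dog") == 1
# and diff("cat", "cat") == 0.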


def quantify_difference_in_label(
    interface: Interface, original_output: list, perturbed_output: list
) -> int | float:
    """Quantifies how much the first output of the interface changed between the
    original and perturbed predictions."""
    output_component = interface.output_components[0]
    post_original_output = output_component.postprocess(original_output[0])
    post_perturbed_output = output_component.postprocess(perturbed_output[0])

    if isinstance(output_component, components.Label):
        original_label = post_original_output["label"]
        perturbed_label = post_perturbed_output["label"]

        if "confidences" in post_original_output:
            # Compare the confidence assigned to the original top label.
            original_confidence = original_output[0][original_label]
            perturbed_confidence = perturbed_output[0][original_label]
            score = original_confidence - perturbed_confidence
        else:
            score = diff(original_label, perturbed_label)
        return score

    elif isinstance(output_component, components.Number):
        score = diff(post_original_output, post_perturbed_output)
        return score

    else:
        raise ValueError(
            f"This interpretation method doesn't support the Output component: {output_component}"
        )


def get_regression_or_classification_value(
    interface: Interface, original_output: list, perturbed_output: list
) -> int | float:
    """Used to handle both regression and classification output values for the SHAP interpretation method."""
    output_component = interface.output_components[0]
    post_original_output = output_component.postprocess(original_output[0])
    post_perturbed_output = output_component.postprocess(perturbed_output[0])

    if isinstance(output_component, components.Label):
        original_label = post_original_output["label"]
        perturbed_label = post_perturbed_output["label"]

        if "confidences" in post_original_output:
            if math.isnan(perturbed_output[0][original_label]):
                return 0
            return perturbed_output[0][original_label]
        else:
            # Note: the argument order here is the reverse of the call in
            # quantify_difference_in_label.
            score = diff(perturbed_label, original_label)
            return score

    else:
        raise ValueError(
            f"This interpretation method doesn't support the Output component: {output_component}"
        )