explain-ViT / explain.py
WwYc's picture
Update explain.py
c32bf50 verified
raw
history blame
463 Bytes
import matplotlib.pyplot as plt
from visualization import generate_visualization
def do_explain(transform, image, class_index=None, use_threshold=False):
    """Draw ``image`` beside its explanation in a single two-panel figure.

    The left panel shows ``image`` as given. The right panel shows the result
    of ``generate_visualization`` applied to ``transform(image)``. Both panels
    have their axes hidden.

    Args:
        transform: Callable applied to ``image`` to produce the input expected
            by ``generate_visualization`` (presumably a model preprocessing
            pipeline -- confirm against the caller).
        image: The raw input; it must be displayable by ``Axes.imshow`` and
            accepted by ``transform``.
        class_index: Forwarded unchanged to ``generate_visualization``.
        use_threshold: Forwarded unchanged to ``generate_visualization``.

    Returns:
        The matplotlib ``Figure`` created via ``plt.subplots``. It is
        registered with pyplot's global figure manager; closing it is left
        to the caller.
    """
    figure, (input_ax, explanation_ax) = plt.subplots(1, 2)

    # Left panel: the untouched input.
    input_ax.imshow(image)
    input_ax.axis("off")

    # Right panel: the explanation, computed on the transformed input.
    model_input = transform(image)
    explanation = generate_visualization(
        model_input,
        class_index=class_index,
        use_threshold=use_threshold,
    )
    explanation_ax.imshow(explanation)
    explanation_ax.axis("off")

    return figure