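# Gradio demo: attention-based deep multiple instance learning (MIL) classification on
# bags of MNIST digits. A bag is labeled positive if it contains at least one instance
# of POSITIVE_CLASS; the app predicts a bag's class and plots its instances together
# with their attention weights.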
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import from_pretrained_keras
# Download the trained model that was already pushed to the Hugging Face Hub.
trained_models = [from_pretrained_keras("buio/attention_mil_classification")]
POSITIVE_CLASS = 1  # a bag is "positive" if it contains at least one instance of this digit
BAG_COUNT = 1000  # number of training bags (kept for reference; training data is not rebuilt here)
VAL_BAG_COUNT = 300  # number of validation bags
BAG_SIZE = 3  # number of instances per bag
PLOT_SIZE = 1  # number of bags to plot
ENSEMBLE_AVG_COUNT = 1  # number of models averaged in the ensemble (a single model here)
def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
    # Set up bags.
    bags = []
    bag_labels = []

    # Normalize input data.
    input_data = np.divide(input_data, 255.0)

    # Count positive bags.
    count = 0

    for _ in range(bag_count):
        # Pick a fixed-size random subset of samples.
        index = np.random.choice(input_data.shape[0], instance_count, replace=False)
        instances_data = input_data[index]
        instances_labels = input_labels[index]

        # By default, all bags are labeled as 0.
        bag_label = 0

        # Check if there is at least one instance of the positive class in the bag.
        if positive_class in instances_labels:
            # A positive bag is labeled as 1.
            bag_label = 1
            count += 1

        bags.append(instances_data)
        bag_labels.append(np.array([bag_label]))

    print(f"Positive bags: {count}")
    print(f"Negative bags: {bag_count - count}")

    return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))
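# Shape note: create_bags returns a list of `instance_count` arrays, each of shape
# (bag_count, 28, 28) (one array per instance position in the bag), plus labels of
# shape (bag_count, 1). This presumably matches the multi-input layout expected by
# the downloaded MIL model.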
# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()

# Create validation data.
val_data, val_labels = create_bags(
    x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
)
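# x_train/y_train and BAG_COUNT are kept from the original training setup but are not
# used below; training bags could be rebuilt the same way if needed, e.g. (hypothetical):
#   train_data, train_labels = create_bags(x_train, y_train, POSITIVE_CLASS, BAG_COUNT, BAG_SIZE)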
def predict(data, labels, trained_models):
    # Collect info per model.
    models_predictions = []
    models_attention_weights = []
    models_losses = []
    models_accuracies = []

    for model in trained_models:
        # Predict output classes on data.
        predictions = model.predict(data)
        models_predictions.append(predictions)

        # Create intermediate model to get the MIL attention layer weights.
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)

        # Predict MIL attention layer weights.
        intermediate_predictions = intermediate_model.predict(data)

        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        models_attention_weights.append(attention_weights)

        model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])
        loss, accuracy = model.evaluate(data, labels, verbose=0)
        models_losses.append(loss)
        models_accuracies.append(accuracy)

    print(
        f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"
        f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f}%, respectively."
    )

    return (
        np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,
        np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,
    )
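# With a single model in `trained_models` and ENSEMBLE_AVG_COUNT = 1, the sums above are
# divided by 1, so the "ensemble average" is simply that model's predictions and
# attention weights.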
def plot(data, labels, bag_class, predictions=None, attention_weights=None):
    """Utility for plotting bags and attention weights.

    Args:
      data: Input data that contains the bags of instances.
      labels: The associated bag labels of the input data.
      bag_class: String name of the desired bag class.
        The options are: "positive" or "negative".
      predictions: Class label predictions of the model.
        If not specified, ground truth labels will be used.
      attention_weights: Attention weights for each instance within the input data.
        If not specified, the values won't be displayed.
    """
    labels = np.array(labels).reshape(-1)

    if bag_class == "positive":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 1)[0]
        else:
            labels = np.where(labels == 1)[0]
        random_labels = np.random.choice(labels, PLOT_SIZE)
        bags = np.array(data)[:, random_labels]

    elif bag_class == "negative":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 0)[0]
        else:
            labels = np.where(labels == 0)[0]
        random_labels = np.random.choice(labels, PLOT_SIZE)
        bags = np.array(data)[:, random_labels]

    else:
        print(f"There is no class {bag_class}")
        return

    print(f"The bag class label is {bag_class}")
    for i in range(PLOT_SIZE):
        figure = plt.figure(figsize=(8, 8))  # one figure per plotted bag
        print(f"Bag number: {labels[i]}")
        for j in range(BAG_SIZE):
            image = bags[j][i]
            figure.add_subplot(1, BAG_SIZE, j + 1)
            plt.grid(False)
            plt.axis("off")
            if attention_weights is not None:
                plt.title(np.around(attention_weights[random_labels[i]][j], 2))
            plt.imshow(image)
        plt.show()
    return figure
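# With PLOT_SIZE = 1 the loop above builds exactly one matplotlib figure, which is the
# value returned and rendered by the gr.Plot output configured below.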
# Evaluate and predict classes and attention scores on validation data.
def predict_and_plot(class_):
    class_predictions, attention_params = predict(val_data, val_labels, trained_models)
    return plot(
        val_data,
        val_labels,
        class_,
        predictions=class_predictions,
        attention_weights=attention_params,
    )
# Run once at startup (prints the evaluation metrics and plots a sample positive bag).
predict_and_plot('positive')
inputs = gr.Radio(choices=['positive','negative'])
outputs = gr.Plot(label='predicted bag')
#title = "Heart Disease Classification 🩺❤️"
#description = "Binary classification of structured data including numerical and categorical features."
#article = "Author: <a href=\"https://huggingface.co/buio\">Marco Buiani</a>. Based on the <a href=\"https://keras.io/examples/structured_data/structured_data_classification_from_scratch/\">keras example</a> by <a href=\"https://twitter.com/fchollet\">François Chollet</a> Model Link: https://huggingface.co/buio/structured-data-classification"
demo = gr.Interface(fn=predict_and_plot, inputs=inputs, outputs=outputs, allow_flagging='never')
demo.launch(debug=True)