import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import from_pretrained_keras
# Download the already-pushed model from the Hugging Face Hub.
trained_models = [from_pretrained_keras("buio/attention_mil_classification")]
POSITIVE_CLASS = 1
BAG_COUNT = 1000
VAL_BAG_COUNT = 300
BAG_SIZE = 3
PLOT_SIZE = 1
ENSEMBLE_AVG_COUNT = 1
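# Multiple-instance learning (MIL) setup: every bag holds BAG_SIZE MNIST digits and is
# labeled positive when at least one of them is the POSITIVE_CLASS digit ("1").
# VAL_BAG_COUNT validation bags are built below; PLOT_SIZE controls how many bags are
# plotted and ENSEMBLE_AVG_COUNT how many models' outputs are averaged in predict().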
def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
    # Set up bags.
    bags = []
    bag_labels = []

    # Normalize input data.
    input_data = np.divide(input_data, 255.0)

    # Count positive samples.
    count = 0

    for _ in range(bag_count):
        # Pick a fixed size random subset of samples.
        index = np.random.choice(input_data.shape[0], instance_count, replace=False)
        instances_data = input_data[index]
        instances_labels = input_labels[index]

        # By default, all bags are labeled as 0.
        bag_label = 0

        # Check if there is at least a positive class in the bag.
        if positive_class in instances_labels:
            # Positive bag will be labeled as 1.
            bag_label = 1
            count += 1

        bags.append(instances_data)
        bag_labels.append(np.array([bag_label]))

    print(f"Positive bags: {count}")
    print(f"Negative bags: {bag_count - count}")

    return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))
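# Note: np.swapaxes turns the (bag_count, BAG_SIZE, 28, 28) stack of bags into a list of
# BAG_SIZE arrays of shape (bag_count, 28, 28), i.e. one input array per instance
# position, which is the layout fed to model.predict() below.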
# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()

# Create validation data.
val_data, val_labels = create_bags(
    x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
)
def predict(data, labels, trained_models):
    # Collect info per model.
    models_predictions = []
    models_attention_weights = []
    models_losses = []
    models_accuracies = []

    for model in trained_models:
        # Predict output classes on data.
        predictions = model.predict(data)
        models_predictions.append(predictions)

        # Create intermediate model to get MIL attention layer weights.
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)

        # Predict MIL attention layer weights.
        intermediate_predictions = intermediate_model.predict(data)
        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        models_attention_weights.append(attention_weights)

        model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])
        loss, accuracy = model.evaluate(data, labels, verbose=0)
        models_losses.append(loss)
        models_accuracies.append(accuracy)

    print(
        f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"
        f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f} % resp."
    )

    return (
        np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,
        np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,
    )
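# The averaged attention_weights array has shape (num_bags, BAG_SIZE): one attention
# score per instance in each bag, read from the model's "alpha" attention layer.
# It is indexed as attention_weights[bag_index][instance_index] when plotting below.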
def plot(data, labels, bag_class, predictions=None, attention_weights=None):
    """Utility for plotting bags and attention weights.

    Args:
        data: Input data that contains the bags of instances.
        labels: The associated bag labels of the input data.
        bag_class: String name of the desired bag class.
            The options are: "positive" or "negative".
        predictions: Class label predictions from the model.
            If you don't specify anything, ground truth labels will be used.
        attention_weights: Attention weights for each instance within the input data.
            If you don't specify anything, the values won't be displayed.
    """
    labels = np.array(labels).reshape(-1)

    if bag_class == "positive":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 1)[0]
        else:
            labels = np.where(labels == 1)[0]
        random_labels = np.random.choice(labels, PLOT_SIZE)
        bags = np.array(data)[:, random_labels]

    elif bag_class == "negative":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 0)[0]
        else:
            labels = np.where(labels == 0)[0]
        random_labels = np.random.choice(labels, PLOT_SIZE)
        bags = np.array(data)[:, random_labels]

    else:
        print(f"There is no class {bag_class}")
        return

    print(f"The bag class label is {bag_class}")
    for i in range(PLOT_SIZE):
        figure = plt.figure(figsize=(8, 8))  # One figure per plotted bag.
        print(f"Bag number: {random_labels[i]}")
        for j in range(BAG_SIZE):
            image = bags[j][i]
            figure.add_subplot(1, BAG_SIZE, j + 1)
            plt.grid(False)
            plt.axis("off")
            if attention_weights is not None:
                plt.title(np.around(attention_weights[random_labels[i]][j], 2))
            plt.imshow(image)
        plt.show()
    return figure
# Evaluate and predict classes and attention scores on validation data.
def predict_and_plot(class_):
    class_predictions, attention_params = predict(val_data, val_labels, trained_models)
    return plot(
        val_data,
        val_labels,
        class_,
        predictions=class_predictions,
        attention_weights=attention_params,
    )
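# Run one prediction/plot pass at import time, before the Gradio demo is built.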
predict_and_plot('positive')
inputs = gr.Radio(choices=['positive','negative'])
outputs = gr.Plot(label='predicted bag')
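# The radio button selects which bag class to visualize; the matplotlib figure returned
# by predict_and_plot() (PLOT_SIZE bags of BAG_SIZE digits, each instance titled with
# its attention score) is rendered by the gr.Plot output component.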
#title = "Heart Disease Classification 🩺❤️"
#description = "Binary classification of structured data including numerical and categorical features."
#article = "Author: <a href=\"https://huggingface.co/buio\">Marco Buiani</a>. Based on the <a href=\"https://keras.io/examples/structured_data/structured_data_classification_from_scratch/\">keras example</a> by <a href=\"https://twitter.com/fchollet\">François Chollet</a> Model Link: https://huggingface.co/buio/structured-data-classification"
demo = gr.Interface(fn=predict_and_plot, inputs=inputs, outputs=outputs, allow_flagging='never')
demo.launch(debug=True)