buio committed
Commit cf81a3f · 1 Parent(s): 583a161

Create app.py

Files changed (1)
app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
+ import numpy as np
+ import tensorflow as tf
+ import tensorflow.keras as keras
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ from huggingface_hub import from_pretrained_keras
+
+
+ # Download the already pushed model from the Hugging Face Hub.
+ trained_models = [from_pretrained_keras("buio/attention_mil_classification")]
+
+
+ POSITIVE_CLASS = 1
+ BAG_COUNT = 1000
+ VAL_BAG_COUNT = 300
+ BAG_SIZE = 3
+ PLOT_SIZE = 1
+ ENSEMBLE_AVG_COUNT = 1
+
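+ # Build MIL "bags": each bag groups BAG_SIZE MNIST digits and is labeled
+ # positive (1) when it contains at least one instance of POSITIVE_CLASS.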
+ def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
+
+     # Set up bags.
+     bags = []
+     bag_labels = []
+
+     # Normalize input data.
+     input_data = np.divide(input_data, 255.0)
+
+     # Count positive samples.
+     count = 0
+
+     for _ in range(bag_count):
+
+         # Pick a fixed size random subset of samples.
+         index = np.random.choice(input_data.shape[0], instance_count, replace=False)
+         instances_data = input_data[index]
+         instances_labels = input_labels[index]
+
+         # By default, all bags are labeled as 0.
+         bag_label = 0
+
+         # Check if there is at least a positive class in the bag.
+         if positive_class in instances_labels:
+
+             # Positive bag will be labeled as 1.
+             bag_label = 1
+             count += 1
+
+         bags.append(instances_data)
+         bag_labels.append(np.array([bag_label]))
+
+     print(f"Positive bags: {count}")
+     print(f"Negative bags: {bag_count - count}")
+
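+     # Swap axes so the data becomes a list of `instance_count` arrays (one per
+     # instance position), matching the multi-input signature of the MIL model.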
+     return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))
+
+ # Load the MNIST dataset.
+ (x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()
+
+ # Create validation data.
+ val_data, val_labels = create_bags(
+     x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
+ )
+
+
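+ # Run each model in `trained_models` on the bags and average the predicted
+ # class probabilities and attention weights over ENSEMBLE_AVG_COUNT models
+ # (a single model here, so the average is a no-op).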
+ def predict(data, labels, trained_models):
+
+     # Collect info per model.
+     models_predictions = []
+     models_attention_weights = []
+     models_losses = []
+     models_accuracies = []
+
+     for model in trained_models:
+
+         # Predict output classes on data.
+         predictions = model.predict(data)
+         models_predictions.append(predictions)
+
+         # Create intermediate model to get MIL attention layer weights.
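+         # The pretrained Hub model is expected to expose its MIL attention
+         # softmax layer under the name "alpha".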
+         intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
+
+         # Predict MIL attention layer weights.
+         intermediate_predictions = intermediate_model.predict(data)
+
+         attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
+         models_attention_weights.append(attention_weights)
+
+         model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])
+         loss, accuracy = model.evaluate(data, labels, verbose=0)
+         models_losses.append(loss)
+         models_accuracies.append(accuracy)
+
+     print(
+         f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"
+         f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f} % resp."
+     )
+
+     return (
+         np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,
+         np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,
+     )
+
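+ # Plot one randomly selected bag of the requested class and return the
+ # matplotlib figure so Gradio can render it in the gr.Plot output.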
+ def plot(data, labels, bag_class, predictions=None, attention_weights=None):
+
+     """Utility for plotting bags and attention weights.
+
+     Args:
+       data: Input data that contains the bags of instances.
+       labels: The associated bag labels of the input data.
+       bag_class: String name of the desired bag class.
+         The options are: "positive" or "negative".
+       predictions: Class labels model predictions.
+         If you don't specify anything, ground truth labels will be used.
+       attention_weights: Attention weights for each instance within the input data.
+         If you don't specify anything, the values won't be displayed.
+     """
+
+     labels = np.array(labels).reshape(-1)
+
+     if bag_class == "positive":
+         if predictions is not None:
+             labels = np.where(predictions.argmax(1) == 1)[0]
+         else:
+             labels = np.where(labels == 1)[0]
+
+         random_labels = np.random.choice(labels, PLOT_SIZE)
+         bags = np.array(data)[:, random_labels]
+
+     elif bag_class == "negative":
+         if predictions is not None:
+             labels = np.where(predictions.argmax(1) == 0)[0]
+         else:
+             labels = np.where(labels == 0)[0]
+
+         random_labels = np.random.choice(labels, PLOT_SIZE)
+         bags = np.array(data)[:, random_labels]
+
+     else:
+         print(f"There is no class {bag_class}")
+         return
+
+     print(f"The bag class label is {bag_class}")
+     for i in range(PLOT_SIZE):
+         figure = plt.figure(figsize=(8, 8))  # one figure per plotted bag
+         print(f"Bag number: {labels[i]}")
+         for j in range(BAG_SIZE):
+             image = bags[j][i]
+             figure.add_subplot(1, BAG_SIZE, j + 1)
+             plt.grid(False)
+             plt.axis('off')
+             if attention_weights is not None:
+                 plt.title(np.around(attention_weights[random_labels[i]][j], 2))
+             plt.imshow(image)
+         plt.show()
+     return figure
+
+
+ # Evaluate and predict classes and attention scores on validation data.
+ def predict_and_plot(class_):
+     class_predictions, attention_params = predict(val_data, val_labels, trained_models)
+     return plot(val_data, val_labels, class_,
+                 predictions=class_predictions,
+                 attention_weights=attention_params)
+
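+ # Run once at start-up as a quick sanity check / warm-up before serving requests.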
+ predict_and_plot('positive')
+
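+ # Gradio UI: a radio button picks the bag class; the returned matplotlib
+ # figure is displayed in a Plot component.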
+ inputs = gr.Radio(choices=['positive','negative'])
+
+ outputs = gr.Plot(label='predicted bag')
+
+ # Interface title, named after the model repo.
+ title = "Attention-based MIL Classification"
+
+ demo = gr.Interface(fn=predict_and_plot, inputs=inputs, outputs=outputs, title=title, allow_flagging='never')
+
+ demo.launch(debug=True)