fradinho commited on
Commit
44cb6dd
·
1 Parent(s): bd79c0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -2
app.py CHANGED
@@ -14,6 +14,24 @@ from tensorflow import keras
14
  import segmentation_models as sm
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def jacard(y_true, y_pred):
18
  y_true_c = K.flatten(y_true)
19
  y_pred_c = K.flatten(y_pred)
@@ -203,11 +221,11 @@ def weighted_categorical_crossentropy(weights):
203
  from tensorflow.python.keras.utils import generic_utils
204
 
205
  # Load the model
206
- model = tf.keras.models.load_model("model.h5", custom_objects={"jacard":jacard, "wcce":weighted_categorical_crossentropy})
207
  #model = tf.keras.models.load_model("model_2.h5", custom_objects={"jacard":jacard, "bce_dice":bce_dice})
208
  ###model = tf.keras.models.load_model("model_2_A.h5", custom_objects={"jacard":jacard, "bce_dice":bce_dice})
209
  #model = tf.keras.models.load_model("model_2_A_0.h5", custom_objects={"jacard":jacard, "bce_dice":bce_dice})
210
-
211
 
212
  # Create a user interface for the model
213
  my_app = gr.Blocks()
 
14
  import segmentation_models as sm
15
 
16
 
17
+ def dice_metric(y_pred, y_true):
18
+ intersection = K.sum(K.sum(K.abs(y_true * y_pred), axis=-1))
19
+ union = K.sum(K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1))
20
+ # if y_pred.sum() == 0 and y_pred.sum() == 0:
21
+ # return 1.0
22
+
23
+ return (2*intersection) / union
24
+
25
+
26
+ def focal_loss(predict, true):
27
+ error = tf.keras.losses.binary_crossentropy(predict, true)
28
+ pt = tf.exp(error)
29
+ focal_loss = (1 - pt) ** 2 * error
30
+ return dice_metric(predict, true) + (1*tf.reduce_mean(focal_loss))
31
+
32
+ def focal_iou(y_true, y_pred):
33
+ return focal_loss(y_true, y_pred) - K.log(jacard(y_true, y_pred))
34
+
35
  def jacard(y_true, y_pred):
36
  y_true_c = K.flatten(y_true)
37
  y_pred_c = K.flatten(y_pred)
 
221
  from tensorflow.python.keras.utils import generic_utils
222
 
223
  # Load the model
224
+ #model = tf.keras.models.load_model("model.h5", custom_objects={"jacard":jacard, "wcce":weighted_categorical_crossentropy})
225
  #model = tf.keras.models.load_model("model_2.h5", custom_objects={"jacard":jacard, "bce_dice":bce_dice})
226
  ###model = tf.keras.models.load_model("model_2_A.h5", custom_objects={"jacard":jacard, "bce_dice":bce_dice})
227
  #model = tf.keras.models.load_model("model_2_A_0.h5", custom_objects={"jacard":jacard, "bce_dice":bce_dice})
228
+ model = tf.keras.models.load_model("model_3_A.h5", custom_objects={"jacard":jacard, "bce_dice":bce_dice, "focal_iou":focal_iou})
229
 
230
  # Create a user interface for the model
231
  my_app = gr.Blocks()