Tensorflow: External calculation of dice coef on validation set different than my Unet's validation dice coef with same data set

Question:

So I am training a variation of a Unet style network in Tensorflow for a problem I am trying to solve. I have noticed an interesting pattern / error that I am unable to comprehend or fix.

As I have been training this network, TensorBoard shows the training loss as greater than the validation loss, yet the validation Dice metric is very low (below).

[screenshot: TensorBoard training/validation loss and validation Dice metric curves]

But when I look at the output data from the network, honestly, the predictions don't appear "half bad", and certainly not like something with a Dice of 0.25-0.30.

[screenshot: example network output]

So when I externally validate the Dice by reloading the model and predicting on the validation set, I get a high Dice score of > 0.90.

[screenshot: externally computed Dice score on the validation set]

I have a feeling this is due to the loss and metrics I am using, but I am unsure how to proceed. My loss class, metric class, and external validation code are posted below.

Loss Class

    class sce_dsc(losses.Loss):
        def __init__(self, scale_sce=1.0, scale_dsc=1.0, sample_weight=None, epsilon=0.01, name=None):
            super(sce_dsc, self).__init__()
            self.sce = losses.SparseCategoricalCrossentropy(from_logits=False)  # the last layer activation is sigmoid, so from_logits needs to be False
            self.epsilon = epsilon
            self.scale_a = scale_sce
            self.scale_b = scale_dsc
            self.cls = 1
            self.weights = sample_weight

        def dsc(self, y_true, y_pred, sample_weight=None):
            true = tf.cast(y_true[..., 0] == self.cls, tf.int64)
            pred = tf.nn.softmax(y_pred, axis=-1)[..., self.cls]
            if self.weights is not None:
                #true = true * (sample_weight[...])
                true = true & (sample_weight[...] != 0)
                #pred = pred * (sample_weight[...])
                pred = pred & (sample_weight[...] != 0)
            A = tf.math.reduce_sum(tf.cast(true, tf.float32) * tf.cast(pred, tf.float32)) * 2
            B = tf.cast(tf.math.reduce_sum(true), tf.float32) + tf.cast(tf.math.reduce_sum(pred), tf.float32) + self.epsilon

            return (1.0 - A / B)

        def call(self, y_true, y_pred):
            sce_loss = self.sce(y_true=y_true, y_pred=y_pred, sample_weight=self.weights) * self.scale_a
            dsc_loss = self.dsc(y_true=y_true, y_pred=y_pred, sample_weight=self.weights) * self.scale_b
            loss = tf.cast(sce_loss, tf.float32) + tf.cast(dsc_loss, tf.float32)
            #self.add_loss(loss)
            return loss
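
As a quick sanity check of the soft-Dice term, here is a small toy example of calling the `dsc` method directly. The tensors are made up for illustration only, and it assumes the `sce_dsc` class above plus the usual `import tensorflow as tf` / `from tensorflow.keras import losses` imports:

    # Toy batch: 4 pixels, 2 classes; label channel last, strong logits for clarity.
    y_true = tf.constant([[[1.0], [1.0], [0.0], [0.0]]])            # shape (1, 4, 1)
    y_pred = tf.constant([[[-5.0, 5.0], [-5.0, 5.0],                # confidently class 1
                           [5.0, -5.0], [5.0, -5.0]]])              # confidently class 0

    loss_fn = sce_dsc(scale_sce=1.0, scale_dsc=1.0, epsilon=0.01)
    print(float(loss_fn.dsc(y_true, y_pred)))  # ~0.01: soft-Dice loss near 0 for a good prediction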

Metric Class

    class custom_dice(keras.metrics.Metric):

        def __init__(self, name="dsc", **kwargs):
            super(custom_dice, self).__init__(**kwargs)
            self.dice = self.add_weight(name='dice_coef', initializer='zeros')

        def update_state(self, y_true, y_pred, sample_weight=None):
            true = tf.cast(y_true[..., 0] == 1, tf.int64)
            pred = tf.math.argmax(y_pred == 1, axis=-1)
            if sample_weight is not None:
                true = true * (sample_weight[...])
                pred = pred * (sample_weight[...])

            A = tf.math.count_nonzero(true & pred) * 2
            B = tf.math.count_nonzero(true) + tf.math.count_nonzero(pred)
            value = tf.math.divide_no_nan(tf.cast(A, tf.float32), tf.cast(B, tf.float32))
            self.dice.assign(value)

        def result(self):
            return self.dice

        def reset_state(self):
            self.dice.assign(0.0)
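
For reference, this is roughly how the loss and metric above are wired into training. It is a minimal sketch in which `unet`, `train_ds`, and `valid_ds` are placeholder names rather than my exact training script, and it assumes the classes and imports shown above:

    unet.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4),   # optimizer and LR are placeholders
        loss=sce_dsc(scale_sce=1.0, scale_dsc=1.0, epsilon=0.01),
        metrics=[custom_dice(name="dsc")],
    )
    unet.fit(train_ds, validation_data=valid_ds, epochs=100)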

External Validation Dice

    def dsc(y_true, y_pred, sample_weight=None, c=1):
        print(y_true.shape, y_pred.shape)
        true = tf.cast(y_true[..., 0] == 1, tf.int64)
        pred = tf.math.argmax(y_pred == c, axis=-1)
        print(true.shape, pred.shape)
        if sample_weight is not None:
            true = true * (sample_weight[...])
            pred = pred * (sample_weight[...])

        A = tf.math.count_nonzero(true & pred) * 2
        B = tf.math.count_nonzero(true) + tf.math.count_nonzero(pred)
        return A / B
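
The external check mentioned above (reloading the model and predicting on the validation set) looks roughly like this. The checkpoint path, `custom_objects` mapping, and the `x_valid` / `y_valid` array names are simplified placeholders rather than my exact script:

    model = tf.keras.models.load_model(
        "unet_checkpoint.h5",                                        # placeholder path
        custom_objects={"sce_dsc": sce_dsc, "custom_dice": custom_dice},
        compile=False,
    )

    y_pred = model.predict(x_valid, batch_size=8)  # validation images
    # y_valid has the label channel last, matching y_true[..., 0] above
    print("external Dice:", float(dsc(y_valid, y_pred)))
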
Asked By: zhilothebest


Answers:

The metric above runs into an issue of producing NaNs, or effectively 0, on slices that contain none of the positive class and where the network predicts nothing. The rewritten metric below fixes the issue:

    # K is assumed to be tensorflow.keras.backend; p is a hyper-parameter dict.
    def dice(self, y_true, y_pred, epsilon=p['epsilon']):
        # collapse the softmax probabilities to hard class labels
        y_pred_arg = tf.math.argmax(y_pred, axis=-1)
        y_true_f = tf.cast(K.flatten(y_true), tf.int64)
        y_pred_f = tf.cast(K.flatten(y_pred_arg), tf.int64)
        intersection = tf.cast(K.sum(y_true_f * y_pred_f), tf.float32)
        # epsilon smooths numerator and denominator so empty slices score 1 instead of NaN
        dice = (2 * intersection + epsilon) / (tf.cast(K.sum(y_true_f), tf.float32)
                                               + tf.cast(K.sum(y_pred_f), tf.float32) + epsilon)
        return tf.cast(dice, tf.float32)

The epsilon is a smoothing factor that prevents division by zero. I personally found epsilon = 1e-2 to give the best results on my current network, but this is definitely a hyper-parameter that should be tuned for your own training.
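
To make the effect concrete, here is a small worked example: on an empty slice (no positive voxels in the label and none predicted) the unsmoothed ratio is 0/0, while the smoothed version evaluates to (0 + epsilon) / (0 + 0 + epsilon) = 1, so such slices are scored as perfect matches instead of dragging the metric toward 0 or NaN. The tensors below are illustrative, with K assumed to be tensorflow.keras.backend:

    import tensorflow as tf
    from tensorflow.keras import backend as K

    epsilon = 1e-2
    y_true = tf.zeros((1, 4, 4), dtype=tf.int64)    # slice with no positive class present
    y_pred = tf.zeros((1, 4, 4), dtype=tf.int64)    # argmax output, all background

    y_true_f = tf.cast(K.flatten(y_true), tf.float32)
    y_pred_f = tf.cast(K.flatten(y_pred), tf.float32)
    intersection = K.sum(y_true_f * y_pred_f)

    dice = (2 * intersection + epsilon) / (K.sum(y_true_f) + K.sum(y_pred_f) + epsilon)
    print(float(dice))  # 1.0: the empty slice is scored as a perfect match rather than NaN/0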

Answered By: zhilothebest