tf.keras.layers.BatchNormalization with trainable=False appears to not update its internal moving mean and variance

Question:

I am trying to find out exactly how the BatchNormalization layer behaves in TensorFlow. I came up with the following piece of code, which to the best of my knowledge should be a perfectly valid Keras model; however, the moving mean and variance of BatchNormalization don't appear to be updated.

From the docs (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization):

in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

I expect the model to return a different value with each subsequent predict call.
What I see, however, are the exact same values returned 10 times.
Can anyone explain to me why the BatchNormalization layer does not update its internal values?

import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(3, 5) * 5 + 0.3

    # frozen batch norm layer applied directly to a 5-feature input
    bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
    inputs = tf.keras.layers.Input([5])
    z = bn(inputs)

    model = tf.keras.Model(inputs=inputs, outputs=z)

    # predict on the same batch 10 times and compare input with output
    for i in range(10):
        print(x)
        print(model.predict(x))
        print()

I am using TensorFlow 2.1.0.

Asked By: Addy

||

Answers:

Okay, I found the mistake in my assumptions. The moving averages are updated during training, not during inference as I had thought. This makes perfect sense: updating the moving averages during inference would likely make a production model unstable. For example, a long sequence of highly pathological input samples (e.g. drawn from a distribution that differs drastically from the one the network was trained on) could bias the network and degrade its performance on valid input samples.

The trainable parameter is useful when you're fine-tuning a pretrained model and want to freeze some of the layers of the network even during training. It is not needed for inference, because when you call model.predict(x) (or even model(x) or model(x, training=False)), the layer automatically uses the moving averages instead of the batch averages.
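To make the interaction between the two flags concrete, here is a minimal NumPy sketch of the rule described above. It is not the Keras implementation: the class name SketchBatchNorm and its attributes are made up for illustration, the learned scale and offset are left out, and the update rule is the usual exponential moving average controlled by momentum.

```python
import numpy as np


class SketchBatchNorm:
    """Simplified batch norm for illustration only (hypothetical class, not Keras code)."""

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3, trainable=True):
        self.momentum = momentum
        self.epsilon = epsilon
        self.trainable = trainable
        # Same starting point as a fresh Keras layer: mean 0, variance 1.
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)

    def __call__(self, x, training=False):
        # A frozen layer stays in inference mode even if training=True is requested.
        if training and self.trainable:
            mean, var = x.mean(axis=0), x.var(axis=0)
            # Only this branch ever changes the moving statistics.
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.epsilon)


rng = np.random.default_rng(1)
x = rng.standard_normal((10, 5)) * 5 + 0.3

layer = SketchBatchNorm(5)
before = layer.moving_mean.copy()
layer(x)                     # inference call: moving statistics stay untouched
unchanged_after_inference = np.array_equal(before, layer.moving_mean)
layer(x, training=True)      # training call: moving statistics shift towards the batch
changed_after_training = not np.array_equal(before, layer.moving_mean)

frozen = SketchBatchNorm(5, trainable=False)
frozen(x, training=True)     # frozen layer: still behaves as in inference
frozen_unchanged = np.array_equal(frozen.moving_mean, np.zeros(5))

print(unchanged_after_inference, changed_after_training, frozen_unchanged)
# prints: True True True
```

In this sketch the moving statistics change only when both training=True and trainable=True hold, which is why repeated predict calls on a frozen (or merely untrained) layer keep returning the same values.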

The following TensorFlow code demonstrates this:

import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(10, 5) * 5 + 0.3

    inputs = tf.keras.layers.Input([5])
    z = tf.keras.layers.BatchNormalization(trainable=True, epsilon=1e-9, momentum=0.99)(inputs)

    model = tf.keras.Model(inputs=inputs, outputs=z)

    # a dummy loss function (element-wise squared error)
    model.compile(loss=lambda y_true, y_pred: (y_true - y_pred) ** 2)

    # a dummy fit just to update the batchnorm moving averages
    model.fit(x, x, batch_size=3, epochs=10)

    # the first call runs in inference mode and uses the moving averages from training
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()

    # outputs the same thing as the previous call, since nothing was updated
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()

    # calling the model with training=True updates the moving averages;
    # it also normalizes with the batch mean and variance, as in training,
    # so the result is very different
    pred = model(x, training=True).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()

    # the moving averages are used again, but they differ slightly after
    # the previous call, as expected
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()

In the end, I found that the documentation (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization) mentions this:

  1. When performing inference using a model containing batch normalization, it is generally (though not always) desirable to use accumulated statistics rather than mini-batch statistics. This is accomplished by passing training=False when calling the model, or using model.predict.
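This also accounts for the exact output seen in the question. The sketch below is a hand-written NumPy version of the inference-mode formula, not TensorFlow code, and it assumes the Keras default initial values (moving mean and beta at 0, moving variance and gamma at 1). With those values and no training, the accumulated statistics never move, so normalizing with them is very nearly the identity.

```python
import numpy as np

np.random.seed(1)
x = np.random.randn(3, 5) * 5 + 0.3    # same input as in the question

# Assumed state of a freshly built, never-trained BatchNormalization layer.
moving_mean = np.zeros(5)
moving_var = np.ones(5)
gamma, beta = np.ones(5), np.zeros(5)
epsilon = 1e-9                          # value used in the question

# Inference-mode normalization relies only on the stored statistics.
y = gamma * (x - moving_mean) / np.sqrt(moving_var + epsilon) + beta

print(np.allclose(y, x))
# prints: True
```

So every predict call in the question should return essentially the input itself, up to floating-point precision, and returns it identically each time.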

Hopefully this will help someone with a similar misunderstanding in the future.

Answered By: Addy