Passing `training=True` when using TensorFlow 2's Keras Functional API

Question:

When operating in graph mode in TF1, I believe I needed to wire up training=True and training=False via feed_dict when I was using the functional-style API. What is the proper way to do this in TF2?

I believe this is automatically handled when using tf.keras.Sequential. For example, I don’t need to specify training in the following example from the docs:

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))

Can I also assume that Keras automagically handles this when training with the Functional API? Here is the same model, rewritten using the Functional API:

inputs = tf.keras.Input(shape=(28, 28, 1), name="input_image")
hid = tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1))(inputs)
hid = tf.keras.layers.MaxPooling2D()(hid)
hid = tf.keras.layers.Flatten()(hid)
hid = tf.keras.layers.Dropout(0.1)(hid)
hid = tf.keras.layers.Dense(64, activation='relu')(hid)
hid = tf.keras.layers.BatchNormalization()(hid)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hid)
model_fn = tf.keras.Model(inputs=inputs, outputs=outputs)

# Model is the full model w/o custom layers
model_fn.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])

model_fn.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model_fn.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))

I’m unsure whether hid = tf.keras.layers.BatchNormalization()(hid) needs to be hid = tf.keras.layers.BatchNormalization()(hid, training=...) instead.

A colab for these models can be found here.

Asked By: cosentiyes


Answers:

I realized that there is a bug in the BatchNormalization documentation [1]: the {{TRAINABLE_ATTRIBUTE_NOTE}} placeholder isn’t actually replaced with the intended note [2]:

About setting layer.trainable = False on a BatchNormalization layer:
The meaning of setting layer.trainable = False is to freeze the layer,
i.e. its internal state will not change during training:
its trainable weights will not be updated
during fit() or train_on_batch(), and its state updates will not be run.
Usually, this does not necessarily mean that the layer is run in inference
mode (which is normally controlled by the training argument that can
be passed when calling a layer). “Frozen state” and “inference mode”
are two separate concepts.

However, in the case of the BatchNormalization layer, setting
trainable = False on the layer means that the layer will be
subsequently run in inference mode
(meaning that it will use
the moving mean and the moving variance to normalize the current batch,
rather than using the mean and variance of the current batch).
This behavior has been introduced in TensorFlow 2.0, in order
to enable layer.trainable = False to produce the most commonly
expected behavior in the convnet fine-tuning use case.
Note that:

  • This behavior only occurs as of TensorFlow 2.0. In 1.*,
    setting layer.trainable = False would freeze the layer but would
    not switch it to inference mode.
  • Setting trainable on a model containing other layers will
    recursively set the trainable value of all inner layers.
  • If the value of the trainable
    attribute is changed after calling compile() on a model,
    the new value doesn’t take effect for this model
    until compile() is called again.

[1] https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization?version=stable

[2] https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/keras/layers/normalization_v2.py#L26-L65
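
As a concrete sketch of what this note implies in practice (assuming TF 2.x; the tiny model below is purely illustrative, not from the question), freezing a BatchNormalization layer for fine-tuning and re-compiling afterwards looks like this:

import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
bn = tf.keras.layers.BatchNormalization()
outputs = tf.keras.layers.Dense(1)(bn(inputs))
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')

# In TF2, freezing the layer also switches it to inference mode
# (it will use its moving mean/variance from now on).
bn.trainable = False

# Because trainable changed after compile(), compile again so the
# change takes effect for fit()/train_on_batch().
model.compile(optimizer='adam', loss='mse')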

Answered By: cosentiyes

As for the original, broader question of whether you have to manually pass the training flag when using the Keras Functional API, this example from the official docs suggests that you should not:

# ...

x = Dropout(0.5)(x)
outputs = Linear(10)(x)
model = tf.keras.Model(inputs, outputs)

# ...

# You can pass a `training` argument in `__call__`
# (it will get passed down to the Dropout layer).
y = model(tf.ones((2, 16)), training=True)
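
A minimal, self-contained sketch of the same idea (assuming TF 2.x; the layers are illustrative): the training flag passed to the functional model's __call__ is forwarded to layers such as Dropout and BatchNormalization, and fit() / evaluate() / predict() set it for you automatically.

import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dropout(0.5)(inputs)
model = tf.keras.Model(inputs, outputs)

ones = tf.ones((1, 4))
print(model(ones, training=True))   # some entries dropped, the rest scaled by 1 / (1 - 0.5)
print(model(ones, training=False))  # dropout disabled: output equals the input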
Answered By: Ben Usman

Batch normalization applies a transformation that maintains the mean output close to 0 and the output standard deviation close to 1. Importantly, batch normalization works differently during training and during inference. According to the Keras documentation:

During training (i.e. when using fit() or when calling the layer/model with the argument training=True), the layer normalizes its output using the mean and standard deviation of the current batch of inputs. That is to say, for each channel being normalized, the layer returns (batch - mean(batch)) / sqrt(var(batch) + epsilon) * gamma + beta, where:

  • epsilon is a small constant (configurable as part of the constructor arguments)
  • gamma is a learned scaling factor (initialized as 1), which can be disabled by passing scale=False to the constructor.
  • beta is a learned offset factor (initialized as 0), which can be disabled by passing center=False to the constructor.

During inference (i.e. when using evaluate() or predict(), or when calling the layer/model with the argument training=False, which is the default), the layer normalizes its output using a moving average of the mean and standard deviation of the batches it has seen during training. That is to say, it returns (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) * gamma + beta.

self.moving_mean and self.moving_var are non-trainable variables that are updated each time the layer is called in training mode, as follows:

  • moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
  • moving_var = moving_var * momentum + var(batch) * (1 - momentum)

As such, the layer will only normalize its inputs correctly during inference after having been trained on data whose statistics are similar to those of the inference data.
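
A short sketch of this behavior (assuming TF 2.x and a freshly created layer, so gamma=1 and beta=0): in training mode the output matches the batch-statistics formula above and the moving statistics get updated, while an inference-mode call leaves them untouched.

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(epsilon=1e-3, momentum=0.9)
x = tf.random.normal((32, 4), mean=5.0, stddev=2.0)

# Inference-mode call: uses (and does not update) the moving statistics.
_ = bn(x, training=False)
print(bn.moving_mean.numpy())        # still the initial zeros

# Training-mode call: normalizes with the batch statistics...
y = bn(x, training=True)
mean = tf.reduce_mean(x, axis=0)
var = tf.math.reduce_variance(x, axis=0)
manual = (x - mean) / tf.sqrt(var + 1e-3)
print(float(tf.reduce_max(tf.abs(y - manual))))   # ~0

# ...and nudges the moving statistics toward the batch statistics.
print(bn.moving_mean.numpy())        # ~ moving_mean * 0.9 + mean(batch) * 0.1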
Answered By: Innat