How to remove training=True from the inbound nodes of a layer in an existing model?

Question:

Assume I have a model given as an h5 file, i.e., I cannot change the code that builds the model’s architecture:

from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model

inputs = Input(shape=(4,))
outputs = BatchNormalization()(inputs, training=True)
model = Model(inputs=inputs, outputs=outputs)
model.save('model.h5', include_optimizer=False)

Now I’d like to remove the training=True part, i.e., I want the BatchNormalization layer to behave as if it had been attached to the model without this flag.
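To make the difference concrete, here is a minimal sketch (not from the original post) that calls a fresh BatchNormalization layer directly on one sample. With training=True the layer normalizes with the statistics of the current batch, so a single-sample batch comes out as all zeros. With training=False it uses the moving statistics, which start at mean 0 and variance 1, so the output is roughly x / sqrt(1 + epsilon), with the default epsilon of 0.001. The inference call is made first because training-mode calls also update the moving statistics.

```python
import numpy as np
import tensorflow as tf

x = np.asarray([[1, 2, 3, 4]], dtype="float32")
bn = tf.keras.layers.BatchNormalization()  # defaults: epsilon=1e-3, momentum=0.99

# Inference mode: normalize with the moving mean (0) and moving variance (1).
inference_out = bn(x, training=False).numpy()

# Training mode: normalize with the statistics of this one-sample batch,
# so every feature equals its own batch mean and the result is zero.
# (This call also updates the moving statistics, hence it comes second.)
training_out = bn(x, training=True).numpy()

print(inference_out)
print(training_out)
```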

My current attempt looks as follows:

import numpy as np
from tensorflow.keras.models import load_model

model = load_model('model.h5')

for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]

model.predict(np.asarray([[1, 2, 3, 4]]))

But the model.predict call fails with the following error (I’m using TensorFlow 2.5.0):

ValueError: Could not pack sequence. Structure had 1 elements, but flat_sequence had 2 elements.  Structure: ((<KerasTensor: shape=(None, 4) dtype=float32 (created by layer 'input_1')>,), {}), flat_sequence: [<tf.Tensor 'model/Cast:0' shape=(None, 4) dtype=float32>, True].

How can this be fixed/worked around?

(When I use node.call_kwargs["training"] = False instead of del node.call_kwargs["training"], model.predict does not crash, but it simply behaves as if nothing had changed, i.e., the modified flag is ignored.)

Asked By: Tobias Hermann


Answers:

Have you tried:

for layer in model.layers:
    layer.trainable = False
Answered By: Gerry P
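The suggestion above works for BatchNormalization because, in TF 2.x, a BatchNormalization layer whose trainable attribute is False runs in inference mode regardless of the training argument it receives (this is documented behavior of tf.keras.layers.BatchNormalization; note that it also freezes gamma and beta if the model is trained later). A minimal check, rebuilding the question's model in memory instead of loading model.h5:

```python
import numpy as np
from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model

# Same architecture as in the question (built here instead of loading model.h5).
inputs = Input(shape=(4,))
outputs = BatchNormalization()(inputs, training=True)
model = Model(inputs=inputs, outputs=outputs)

# Freeze every layer; a frozen BatchNormalization layer ignores training=True.
for layer in model.layers:
    layer.trainable = False

x = np.asarray([[1, 2, 3, 4]], dtype="float32")
result = model.predict(x)
print(result)  # moving statistics are used, so roughly x / sqrt(1.001)
```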

I found that simply saving and reloading the model after modifying the call_kwargs fixes the problem.

import numpy as np
from tensorflow.keras.models import load_model

model = load_model('model.h5')

# Removing training=True
for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]

# The following two lines are the fix: save the edited model and load it again.
model.save('model_modified.h5', include_optimizer=False)
model = load_model('model_modified.h5')

model.predict(np.asarray([[1, 2, 3, 4]]))

And all is fine. 🙂

Answered By: Tobias Hermann
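A possible in-memory variant of the same round trip, sketched under the assumption of TF 2.x with the bundled Keras 2 (as in the question, where layer.inbound_nodes and node.call_kwargs exist): the h5 file stores the model config, and that config is generated from the nodes' current call_kwargs, so Model.from_config(model.get_config()) should presumably pick up the edit without touching the disk. Because from_config creates freshly initialized layers, the weights are copied over with set_weights. The model is again built in memory here as a stand-in for load_model('model.h5').

```python
import numpy as np
from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model

# Same architecture as in the question (stands in for load_model('model.h5')).
inputs = Input(shape=(4,))
outputs = BatchNormalization()(inputs, training=True)
model = Model(inputs=inputs, outputs=outputs)

# Remove training=True from the recorded call arguments (same loop as above).
for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]

# Rebuild the graph from the edited config instead of saving and reloading.
rebuilt = Model.from_config(model.get_config())
# from_config creates new, freshly initialized layers; copy the original weights.
rebuilt.set_weights(model.get_weights())

x = np.asarray([[1, 2, 3, 4]], dtype="float32")
result = rebuilt.predict(x)
print(result)
```

Since predict runs in inference mode once the flag is gone, the rebuilt model uses the moving statistics, just like the reloaded model in the answer above.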