TypeError: __init__() got an unexpected keyword argument 'name' when loading a model with Custom Layer

Question:

I made a custom layer in keras for reshaping the outputs of a CNN before feeding to ConvLSTM2D layer

class TemporalReshape(Layer):
    def __init__(self,batch_size,num_patches):
        super(TemporalReshape,self).__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches

    def call(self,inputs):
        nshape = (self.batch_size,self.num_patches)+inputs.shape[1:]
        return tf.reshape(inputs, nshape)

    def get_config(self):
        config = super().get_config().copy()
        config.update({'batch_size':self.batch_size,'num_patches':self.num_patches})
        return config

When I try to load the best model using

model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})

I get the error

TypeError                                 Traceback (most recent call last)
<ipython-input-83-40b46da33e91> in <module>()
----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})


/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
    180     if (h5py is not None and (
    181         isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 182       return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
    183 
    184     filepath = path_to_string(filepath)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
    176     model_config = json.loads(model_config.decode('utf-8'))
    177     model = model_config_lib.model_from_config(model_config,
--> 178                                                custom_objects=custom_objects)
    179 
    180     # set weights

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
     53                     '`Sequential.from_config(config)`?')
     54   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 55   return deserialize(config, custom_objects=custom_objects)
     56 
     57 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    356             custom_objects=dict(
    357                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
    360         return cls.from_config(cls_config)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects)
    615     """
    616     input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 617         config, custom_objects)
    618     model = cls(inputs=input_tensors, outputs=output_tensors,
    619                 name=config.get('name'))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1202   # First, we create all layers and enqueue nodes to be processed
   1203   for layer_data in config['layers']:
-> 1204     process_layer(layer_data)
   1205   # Then we process nodes in order of layer depth.
   1206   # Nodes that cannot yet be processed (if the inbound node

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data)
   1184       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   1185 
-> 1186       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
   1187       created_layers[layer_name] = layer
   1188 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
--> 360         return cls.from_config(cls_config)
    361     else:
    362       # Then `cls` may be a function returning a class.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
    695         A layer instance.
    696     """
--> 697     return cls(**config)
    698 
    699   def compute_output_shape(self, input_shape):

TypeError: __init__() got an unexpected keyword argument 'name'

When building the model, I used the custom layer like the following :

x = TemporalReshape(batch_size = 8, num_patches = 16)(x)

What is causing the error and how to load the model without this error?

Asked By: Siladittya

||

Answers:

Based on the error message only, I would suggest putting **kwargs in __init__. This object will then accept any other keyword argument that you haven’t included.

def __init__(self, batch_size, num_patches, **kwargs):
        super(TemporalReshape, self).__init__(**kwargs) # <--- must, thanks https://stackoverflow.com/users/349130/dr-snoopy
        self.batch_size = batch_size
        self.num_patches = num_patches
Answered By: Nicolas Gervais

Insert **kwargs to __init__() function.

Error message: "TypeError: __init__() missing 3 required positional arguments: 'batch_size', 'num_patches'"

Answered By: HongKiem