How to call a loaded TensorFlow model with custom parameters in its call method?


I’ve been following the TensorFlow text generation tutorial, which includes two models, "MyModel" and "OneStep". "MyModel" is an RNN operating on vectorized strings; "OneStep" essentially wraps "MyModel" and operates on strings directly.

The tutorial saves and loads "OneStep", and I followed this successfully, but I now want to save and reload "MyModel". This is not done in the tutorial, and when I try to call the reloaded model with return_state=True, I get an error:

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_23/ in <module>
      1 # TODO: Loaded model gives an error
      2 for input_example_batch, target_example_batch in train_ds.take(1):
----> 3     example_batch_predictions, example_states = loaded_model(input_example_batch, False, None, return_state=True)
      4     print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
      5     print(example_states.shape, "   # (batch_size, rnn_units)")

/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/ in _call_attribute(instance, *args, **kwargs)
    663 def _call_attribute(instance, *args, **kwargs):
--> 664   return instance.__call__(*args, **kwargs)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in __call__(self, *args, **kwds)
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    887       new_tracing_count = self.experimental_get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in _initialize(self, args, kwds, add_initializers_to)
    758     self._concrete_stateful_fn = (
    759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 760             *args, **kwds))
    762     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in _maybe_define_function(self, args, kwargs)
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3306             arg_names=arg_names,
   3307             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3308             capture_by_value=self._capture_by_value),
   3309         self._function_attributes,
   3310         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out

/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/ in restored_function_body(*args, **kwargs)
    292         .format(_pretty_format_positional(args), kwargs,
    293                 len(saved_function.concrete_functions),
--> 294                 "\n\n".join(signature_descriptions)))
    296   concrete_function_objects = []

ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (4 total):
    * Tensor("inputs:0", shape=(64, 113), dtype=int64)
    * False
    * None
    * True
  Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='input_1')
    * False
    * None
    * False
  Keyword arguments: {}

Option 2:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='inputs')
    * False
    * None
    * False
  Keyword arguments: {}

Option 3:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='inputs')
    * True
    * None
    * False
  Keyword arguments: {}

Option 4:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='input_1')
    * True
    * None
    * False
  Keyword arguments: {}

I think this is due to the custom parameter in the call method. Here is a minimal example that reproduces the problem:

import tensorflow as tf

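class CustomModel(tf.keras.models.Model):
    def __init__(self):
        super().__init__()  # required before assigning layers as attributes
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs, custom_param=False):
        # custom_param is unused here; passing a non-default value is enough to trigger the error
        return self.dense(inputs)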

model = CustomModel()

sample_inputs = tf.zeros((16, 30))
print('Sample inputs:', sample_inputs)

sample_outputs = model(sample_inputs)
print('Sample outputs:', sample_outputs)

model.save('saved_model')
loaded_model = tf.keras.models.load_model('saved_model')

sample_outputs_2 = loaded_model(sample_inputs, custom_param=True)
print('Sample outputs 2:', sample_outputs_2)

Calling the reloaded model with custom_param taking any value other than the default always seems to fail.

Is this a bug or by design? How can I modify the model so that it returns just the output sequence while training, but returns the output sequence and state at inference time? This is so I can feed the state back into the model and generate additional characters during inference.

Asked By: A Kubiesa



I figured it out. The solution is in the TensorFlow docs, though it is not spelled out very clearly.

With the above code, the loaded model is of type keras.saving.saved_model.load.CustomModel, which is not the same as the original type. To get back the original type, you need to do the following.

The CustomModel class needs the get_config and from_config methods.

class CustomModel(tf.keras.models.Model):
    def __init__(self):
        super().__init__()  # required before assigning layers as attributes
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs, custom_param=False):
        return self.dense(inputs)

    def get_config(self):
        return {}  # Any parameters originally passed to __init__ should go here.

    @classmethod  # from_config is called on the class, not on an instance
    def from_config(cls, config):
        return cls(**config)
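
To illustrate the comment in get_config: if __init__ took an argument, get_config would have to return it so that from_config can rebuild an equivalent instance. Here is a minimal sketch of that round trip, assuming a hypothetical units parameter that the class above does not have. It does not save anything to disk; it only shows the config going out and coming back in.

```python
import tensorflow as tf

class CustomModel(tf.keras.Model):
    def __init__(self, units=10):  # `units` is a hypothetical parameter for illustration
        super().__init__()
        self.units = units
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs, custom_param=False):
        return self.dense(inputs)

    def get_config(self):
        # Everything __init__ needs in order to rebuild an equivalent instance.
        return {"units": self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

model = CustomModel(units=4)

# Rebuild a second instance purely from the config dictionary.
clone = CustomModel.from_config(model.get_config())
print(type(clone).__name__, clone.units)
```

The clone here is a new object with freshly initialised weights. As far as I understand, load_model does the same reconstruction from the config and then restores the saved weights as a separate step.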

When loading the model, you need to pass the custom class in the custom_objects dictionary.

loaded_model = tf.keras.models.load_model('saved_model', custom_objects={'CustomModel': CustomModel})

Then, loaded_model is of type CustomModel and calling it with custom_param=True works.
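
As for whether the original failure is a bug or by design: as far as I can tell, it is by design. A Python-valued argument such as custom_param is not a tensor, so tf.function treats each distinct value as part of the trace signature, and a model revived without its class can only replay the traces that were recorded before saving. That matches the error message, where every listed option has the last argument fixed at False while the call passed True. Below is a minimal sketch of the retracing behaviour on a standalone function; forward and trace_log are illustrative names, not part of the tutorial or of TensorFlow.

```python
import tensorflow as tf

trace_log = []  # records the Python value seen each time the function is traced

@tf.function
def forward(x, custom_param=False):
    # Python side effects run only while tracing, not on every call.
    trace_log.append(custom_param)
    return x * 2.0 if custom_param else x

x = tf.zeros((16, 30))
forward(x)                     # first trace, custom_param=False
forward(x)                     # same signature: the cached trace is reused
forward(x, custom_param=True)  # new Python value: a second trace is created

print(trace_log)
```

The log should hold one entry per distinct value of custom_param rather than one per call. With the real Python class available after loading, a new trace can be created on demand, which is presumably why passing the class through custom_objects makes custom_param=True work.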

Answered By: A Kubiesa