How to prevent Keras from computing metrics during training
Question:
I’m using Tensorflow/Keras 2.4.1 and I have an (unsupervised) custom metric that takes several of my model inputs as parameters, such as:
# Illustrative outline from the question; `build_model` and `custom_metric` are not shown.
model = build_model() # returns a tf.keras.Model object
# The metric is built from the model output and the first two model inputs.
my_metric = custom_metric(model.output, model.input[0], model.input[1])
model.add_metric(my_metric)
[...]
model.fit([...]) # training with fit
However, it happens that `custom_metric` is very expensive, so I would like it to be computed during validation only. I found this answer, but I hardly understand how I can adapt the solution to my metric, which uses several model inputs as parameters, since the `update_state` method doesn’t seem flexible.
In my context, is there a way to avoid computing my metric during training, aside from writing my own training loop ?
Also, I am very surprised we cannot natively specify to Tensorflow that some metrics should only be computed at validation time, is there a reason for that ?
In addition, since the model is trained to optimize the loss, and that the training dataset should not be used to evaluate a model, I don’t even understand why, by default, Tensorflow computes metrics during training.
Answers:
I was able to use `learning_phase`, but only in symbolic tensor (graph) mode.
So, first we need to disable eager mode (this must be done right after importing tensorflow):
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
Then you can create your metric using a symbolic if (`backend.switch`):
def metric_graph(in1, in2, out):
    """Return a symbolic metric that is only built for the non-training phase.

    Args:
        in1: first symbolic model input.
        in2: second symbolic model input.
        out: symbolic model output.

    Returns:
        The tensor produced by ``K.switch``: ``tf.zeros((1,))`` while the
        Keras learning phase is truthy, ``out * (in1 + in2)`` otherwise.
    """
    def _training_placeholder():
        return tf.zeros((1,))

    def _actual_metric():
        return out * (in1 + in2)

    # Callables (rather than already-built tensors) are handed to K.switch so
    # the metric ops are created inside the selected branch instead of being
    # unconditional inputs of the switch.
    return K.switch(K.learning_phase(), _training_placeholder, _actual_metric)
The method `add_metric` will ask for a name and an aggregation method, which you can set to "mean".
So, here is one example:
# Constant toy data: five samples, three features per input, one target each.
x1 = numpy.ones(shape=(5, 3))
x2 = numpy.ones(shape=(5, 3))
y = numpy.full((5, 1), 3.0)
# The validation arrays hold the same constants as the training arrays.
vx1 = numpy.ones(shape=(5, 3))
vx2 = numpy.ones(shape=(5, 3))
vy = numpy.full((5, 1), 3.0)
def metric_eager(in1, in2, out):
    """Pick the metric value with a plain Python branch on the learning phase.

    Returns the constant ``0`` when ``K.learning_phase()`` is truthy and
    ``out * (in1 + in2)`` otherwise.

    NOTE(review): the example below wires up ``metric_graph`` rather than this
    function; presumably it is kept only for comparison - confirm.
    """
    return 0 if K.learning_phase() else out * (in1 + in2)
def metric_graph(in1, in2, out):
    """Return a symbolic metric that is only built for the non-training phase.

    Args:
        in1: first symbolic model input.
        in2: second symbolic model input.
        out: symbolic model output.

    Returns:
        The tensor produced by ``K.switch``: ``tf.zeros((1,))`` while the
        Keras learning phase is truthy, ``out * (in1 + in2)`` otherwise.
    """
    def _training_placeholder():
        return tf.zeros((1,))

    def _actual_metric():
        return out * (in1 + in2)

    # Callables (rather than already-built tensors) are handed to K.switch so
    # the metric ops are created inside the selected branch instead of being
    # unconditional inputs of the switch.
    return K.switch(K.learning_phase(), _training_placeholder, _actual_metric)
# Two 3-feature inputs are concatenated and fed to a single-unit Dense layer.
ins1 = Input((3,))
ins2 = Input((3,))
outs = Concatenate()([ins1, ins2])
outs = Dense(1)(outs)
model = Model([ins1, ins2],outs)
# Attach the phase-dependent metric under the name 'my_metric' with 'mean' aggregation.
model.add_metric(metric_graph(ins1, ins2, outs), name='my_metric', aggregation='mean')
model.compile(loss='mse', optimizer='adam')
# Train for 3 epochs, passing the validation arrays explicitly.
model.fit([x1, x2],y, validation_data=([vx1, vx2], vy), epochs=3)
I think that the simplest solution to compute a metric only on the validation is using a custom callback.
here we define our dummy callback:
class MyCustomMetricCallback(tf.keras.callbacks.Callback):
    """Compute a custom metric at the end of every epoch on fixed datasets.

    The metric used here is the mean squared error between the stored targets
    and ``self.model.predict`` on the stored inputs; results are written into
    the epoch ``logs`` under ``'my_metric_train'`` and ``'my_metric_val'``.

    Args:
        train: optional ``(inputs, targets)`` pair; when truthy the metric is
            computed on it every epoch.
        validation: optional ``(inputs, targets)`` pair, handled the same way.
    """

    def __init__(self, train=None, validation=None):
        super(MyCustomMetricCallback, self).__init__()
        self.train = train
        self.validation = validation

    def _score(self, data):
        """Return the MSE over ``data``, as a scalar rounded to 5 decimals."""
        inputs, targets = data[0], data[1]
        predictions = self.model.predict(inputs)
        per_sample = tf.keras.losses.mean_squared_error(targets, predictions)
        # mean_squared_error only reduces the last axis, so average the
        # remaining per-sample values to log a single number.
        return np.round(float(np.mean(np.asarray(per_sample))), 5)

    def on_epoch_end(self, epoch, logs=None):
        """Add the metric values for the configured datasets to ``logs``."""
        # A fresh dict per call avoids sharing one mutable default between calls.
        logs = {} if logs is None else logs
        if self.train:
            logs['my_metric_train'] = self._score(self.train)
        if self.validation:
            logs['my_metric_val'] = self._score(self.validation)
Given this dummy model:
def build_model():
    """Build and compile the two-input dummy regression model.

    Two 5-feature inputs are concatenated and passed through a single-unit
    Dense layer; the model is compiled with the 'mse' loss and the 'adam'
    optimizer.

    Returns:
        The compiled model.
    """
    inputs = [Input((5,)), Input((5,))]
    merged = Concatenate()(inputs)
    prediction = Dense(1)(merged)
    net = Model(inputs, prediction)
    net.compile(loss='mse', optimizer='adam')
    return net
and this data:
# Random toy data drawn with np.random.uniform(0, 1): 100 samples, two
# 5-feature inputs and one target column, for both training and validation.
X_train1 = np.random.uniform(0,1, (100,5))
X_train2 = np.random.uniform(0,1, (100,5))
y_train = np.random.uniform(0,1, (100,1))
X_val1 = np.random.uniform(0,1, (100,5))
X_val2 = np.random.uniform(0,1, (100,5))
y_val = np.random.uniform(0,1, (100,1))
you can use the custom callback to compute the metric both on train and validation:
# The callback receives both datasets, so it logs the metric for each of them.
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train), validation=([X_val1, X_val2],y_val))])
only on validation:
# Only `validation` is passed, so only 'my_metric_val' is logged.
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
callbacks=[MyCustomMetricCallback(validation=([X_val1, X_val2],y_val))])
only on train:
# Only `train` is passed, so only 'my_metric_train' is logged.
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train))])
Just remember that the callback evaluates the metric in one shot on the whole data, like any metric/loss computed by default by Keras on the `validation_data`.
here is the running code.
Since the metrics are run within the `train_step` function of `keras.Model`, filtering out train-disabled metrics without altering the API requires subclassing `keras.Model`.
We define a simple metric wrapper:
class TrainDisabledMetric(Metric):
    """Wrapper that tags a metric so it can be recognised via ``isinstance``.

    The wrapper reuses the wrapped metric's name and forwards every state
    operation to it unchanged.
    """

    def __init__(self, metric: Metric):
        """Wrap ``metric`` and register this wrapper under the same name."""
        super().__init__(name=metric.name)
        self._metric = metric

    def update_state(self, *args, **kwargs):
        """Forward all arguments to the wrapped metric's ``update_state``."""
        wrapped = self._metric
        return wrapped.update_state(*args, **kwargs)

    def reset_state(self):
        """Reset the wrapped metric."""
        wrapped = self._metric
        return wrapped.reset_state()

    def result(self):
        """Return the wrapped metric's current result."""
        wrapped = self._metric
        return wrapped.result()
and subclass `keras.Model` to filter out those metrics during training:
class CustomModel(keras.Model):
    """``keras.Model`` whose ``train_step`` skips ``TrainDisabledMetric`` metrics.

    ``compile`` keeps the regular ``compiled_metrics`` container (used when
    ``compute_metrics`` is called with ``training=False``) and additionally
    builds ``on_train_compiled_metrics``, a container holding only the metrics
    that are not wrapped in ``TrainDisabledMetric``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _train_enabled_metrics(metrics):
        """Return ``metrics`` without its ``TrainDisabledMetric`` entries.

        The nesting of lists/dicts is kept; containers that end up empty are
        dropped. Returns ``None`` when no train-enabled metric remains, since
        ``tf.nest.pack_sequence_as`` cannot pack an empty sequence into the
        collapsed (scalar) structure that case produces.
        """
        keep_flags = [not isinstance(entry, TrainDisabledMetric)
                      for entry in tf.nest.flatten(metrics)]
        if not any(keep_flags):
            return None

        # Shallow tree of booleans: sub-structures without any enabled metric
        # collapse to a single False.
        keep_tree = get_traverse_shallow_structure(
            lambda sub: any(tf.nest.flatten(sub)),
            tf.nest.pack_sequence_as(metrics, keep_flags))
        candidates = flatten_up_to(keep_tree, metrics)

        def prune(tree):
            """Drop falsy entries from nested lists/dicts of keep flags."""
            if isinstance(tree, list):
                return [kept for kept in (prune(item) for item in tree) if kept]
            if isinstance(tree, dict):
                pruned = {}
                for key, value in tree.items():
                    kept = prune(value)
                    if kept:
                        pruned[key] = kept
                return pruned
            return tree

        kept_metrics = [metric for keep, metric
                        in zip(tf.nest.flatten(keep_tree), candidates) if keep]
        return tf.nest.pack_sequence_as(prune(keep_tree), kept_metrics)

    def compile(self, optimizer='rmsprop', loss=None, metrics=None,
                loss_weights=None, weighted_metrics=None, run_eagerly=None,
                steps_per_execution=None, jit_compile=None, **kwargs):
        """Compile the model and build the training-only metrics container.

        All arguments are forwarded to ``keras.Model.compile``. When
        ``metrics`` is given, ``on_train_compiled_metrics`` is rebuilt from the
        entries that are not ``TrainDisabledMetric`` instances; otherwise it
        is the same object as ``compiled_metrics``.
        """
        from_serialized = kwargs.get('from_serialized', False)
        super().compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights,
                        weighted_metrics=weighted_metrics, run_eagerly=run_eagerly,
                        steps_per_execution=steps_per_execution,
                        jit_compile=jit_compile, **kwargs)
        self.on_train_compiled_metrics = self.compiled_metrics
        if metrics is not None:
            # NOTE(review): weighted_metrics are not passed to the training
            # container, as in the original code - confirm this is intended.
            self.on_train_compiled_metrics = compile_utils.MetricsContainer(
                self._train_enabled_metrics(metrics), weighted_metrics=None,
                output_names=self.output_names, from_serialized=from_serialized)

    def train_step(self, data):
        """Run one optimisation step and return the training-time metrics."""
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # Run forward pass.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(x, y, y_pred, sample_weight)
        self._validate_target_and_loss(y, loss)
        # Run backwards pass.
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        return self.compute_metrics(x, y, y_pred, sample_weight, training=True)

    def compute_metrics(self, x, y, y_pred, sample_weight, training=False):
        """Update the metrics for the current phase and return their results.

        With ``training=True`` only ``on_train_compiled_metrics`` is updated
        and ``on_train_metrics`` is reported; otherwise the full
        ``compiled_metrics`` / ``metrics`` are used.

        Returns:
            Dict mapping metric names to results; a metric whose result is
            itself a dict is merged into the returned dict.
        """
        del x  # The default implementation does not use `x`.
        if training:
            self.on_train_compiled_metrics.update_state(y, y_pred, sample_weight)
            metrics = self.on_train_metrics
        else:
            self.compiled_metrics.update_state(y, y_pred, sample_weight)
            metrics = self.metrics
        # Collect metrics to return
        return_metrics = {}
        for metric in metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return return_metrics

    @property
    def on_train_metrics(self):
        """Metrics reported during training: loss, train-enabled and layer metrics."""
        metrics = []
        if self._is_compiled:
            # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
            # so that attr names are not load-bearing.
            if self.compiled_loss is not None:
                metrics += self.compiled_loss.metrics
            if self.on_train_compiled_metrics is not None:
                metrics += self.on_train_compiled_metrics.metrics
        for layer in self._flatten_layers():
            metrics.extend(layer._metrics)  # pylint: disable=protected-access
        return metrics
Now given a keras model, we can wrap it and compile it with train disabled metrics:
# Placeholder for any existing Keras model; it is re-wrapped through its inputs/outputs.
model: keras.Model = ...
custom_model = CustomModel(inputs=model.input, outputs=model.output)
# Metrics left unwrapped are not filtered out by CustomModel.
train_enabled_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
# wrap train disabled metrics with `TrainDisabledMetric`:
train_disabled_metrics = [
TrainDisabledMetric(tf.keras.metrics.SparseCategoricalCrossentropy())]
metrics = train_enabled_metrics + train_disabled_metrics
custom_model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True), metrics=metrics, )
# `ds_train` / `ds_test` are datasets defined outside this snippet.
custom_model.fit(ds_train, epochs=6, validation_data=ds_test, )
The metric `SparseCategoricalCrossentropy` is computed only during validation:
Epoch 1/6
469/469 [==============================] - 2s 2ms/step - loss: 0.3522 - sparse_categorical_accuracy: 0.8366 - val_loss: 0.1978 - val_sparse_categorical_accuracy: 0.9086 - val_sparse_categorical_crossentropy: 1.3197
Epoch 2/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1631 - sparse_categorical_accuracy: 0.9526 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9587 - val_sparse_categorical_crossentropy: 1.1910
Epoch 3/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1178 - sparse_categorical_accuracy: 0.9654 - val_loss: 0.1139 - val_sparse_categorical_accuracy: 0.9661 - val_sparse_categorical_crossentropy: 1.1369
Epoch 4/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0909 - sparse_categorical_accuracy: 0.9735 - val_loss: 0.0981 - val_sparse_categorical_accuracy: 0.9715 - val_sparse_categorical_crossentropy: 1.0434
Epoch 5/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0735 - sparse_categorical_accuracy: 0.9784 - val_loss: 0.0913 - val_sparse_categorical_accuracy: 0.9721 - val_sparse_categorical_crossentropy: 0.9862
Epoch 6/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0606 - sparse_categorical_accuracy: 0.9823 - val_loss: 0.0824 - val_sparse_categorical_accuracy: 0.9761 - val_sparse_categorical_crossentropy: 1.0024
I’m using Tensorflow/Keras 2.4.1 and I have an (unsupervised) custom metric that takes several of my model inputs as parameters, such as:
# Illustrative outline from the question; `build_model` and `custom_metric` are not shown.
model = build_model() # returns a tf.keras.Model object
# The metric is built from the model output and the first two model inputs.
my_metric = custom_metric(model.output, model.input[0], model.input[1])
model.add_metric(my_metric)
[...]
model.fit([...]) # training with fit
However, it happens that `custom_metric` is very expensive, so I would like it to be computed during validation only. I found this answer, but I hardly understand how I can adapt the solution to my metric, which uses several model inputs as parameters, since the `update_state` method doesn’t seem flexible.
In my context, is there a way to avoid computing my metric during training, aside from writing my own training loop ?
Also, I am very surprised we cannot natively specify to Tensorflow that some metrics should only be computed at validation time, is there a reason for that ?
In addition, since the model is trained to optimize the loss, and that the training dataset should not be used to evaluate a model, I don’t even understand why, by default, Tensorflow computes metrics during training.
I was able to use `learning_phase`, but only in symbolic tensor (graph) mode.
So, first we need to disable eager mode (this must be done right after importing tensorflow):
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
Then you can create your metric using a symbolic if (`backend.switch`):
def metric_graph(in1, in2, out):
    """Return a symbolic metric that is only built for the non-training phase.

    Args:
        in1: first symbolic model input.
        in2: second symbolic model input.
        out: symbolic model output.

    Returns:
        The tensor produced by ``K.switch``: ``tf.zeros((1,))`` while the
        Keras learning phase is truthy, ``out * (in1 + in2)`` otherwise.
    """
    def _training_placeholder():
        return tf.zeros((1,))

    def _actual_metric():
        return out * (in1 + in2)

    # Callables (rather than already-built tensors) are handed to K.switch so
    # the metric ops are created inside the selected branch instead of being
    # unconditional inputs of the switch.
    return K.switch(K.learning_phase(), _training_placeholder, _actual_metric)
The method `add_metric` will ask for a name and an aggregation method, which you can set to "mean".
So, here is one example:
# Constant toy data: five samples, three features per input, one target each.
x1 = numpy.ones(shape=(5, 3))
x2 = numpy.ones(shape=(5, 3))
y = numpy.full((5, 1), 3.0)
# The validation arrays hold the same constants as the training arrays.
vx1 = numpy.ones(shape=(5, 3))
vx2 = numpy.ones(shape=(5, 3))
vy = numpy.full((5, 1), 3.0)
def metric_eager(in1, in2, out):
    """Pick the metric value with a plain Python branch on the learning phase.

    Returns the constant ``0`` when ``K.learning_phase()`` is truthy and
    ``out * (in1 + in2)`` otherwise.

    NOTE(review): the example below wires up ``metric_graph`` rather than this
    function; presumably it is kept only for comparison - confirm.
    """
    return 0 if K.learning_phase() else out * (in1 + in2)
def metric_graph(in1, in2, out):
    """Return a symbolic metric that is only built for the non-training phase.

    Args:
        in1: first symbolic model input.
        in2: second symbolic model input.
        out: symbolic model output.

    Returns:
        The tensor produced by ``K.switch``: ``tf.zeros((1,))`` while the
        Keras learning phase is truthy, ``out * (in1 + in2)`` otherwise.
    """
    def _training_placeholder():
        return tf.zeros((1,))

    def _actual_metric():
        return out * (in1 + in2)

    # Callables (rather than already-built tensors) are handed to K.switch so
    # the metric ops are created inside the selected branch instead of being
    # unconditional inputs of the switch.
    return K.switch(K.learning_phase(), _training_placeholder, _actual_metric)
# Two 3-feature inputs are concatenated and fed to a single-unit Dense layer.
ins1 = Input((3,))
ins2 = Input((3,))
outs = Concatenate()([ins1, ins2])
outs = Dense(1)(outs)
model = Model([ins1, ins2],outs)
# Attach the phase-dependent metric under the name 'my_metric' with 'mean' aggregation.
model.add_metric(metric_graph(ins1, ins2, outs), name='my_metric', aggregation='mean')
model.compile(loss='mse', optimizer='adam')
# Train for 3 epochs, passing the validation arrays explicitly.
model.fit([x1, x2],y, validation_data=([vx1, vx2], vy), epochs=3)
I think that the simplest solution to compute a metric only on the validation is using a custom callback.
here we define our dummy callback:
class MyCustomMetricCallback(tf.keras.callbacks.Callback):
    """Compute a custom metric at the end of every epoch on fixed datasets.

    The metric used here is the mean squared error between the stored targets
    and ``self.model.predict`` on the stored inputs; results are written into
    the epoch ``logs`` under ``'my_metric_train'`` and ``'my_metric_val'``.

    Args:
        train: optional ``(inputs, targets)`` pair; when truthy the metric is
            computed on it every epoch.
        validation: optional ``(inputs, targets)`` pair, handled the same way.
    """

    def __init__(self, train=None, validation=None):
        super(MyCustomMetricCallback, self).__init__()
        self.train = train
        self.validation = validation

    def _score(self, data):
        """Return the MSE over ``data``, as a scalar rounded to 5 decimals."""
        inputs, targets = data[0], data[1]
        predictions = self.model.predict(inputs)
        per_sample = tf.keras.losses.mean_squared_error(targets, predictions)
        # mean_squared_error only reduces the last axis, so average the
        # remaining per-sample values to log a single number.
        return np.round(float(np.mean(np.asarray(per_sample))), 5)

    def on_epoch_end(self, epoch, logs=None):
        """Add the metric values for the configured datasets to ``logs``."""
        # A fresh dict per call avoids sharing one mutable default between calls.
        logs = {} if logs is None else logs
        if self.train:
            logs['my_metric_train'] = self._score(self.train)
        if self.validation:
            logs['my_metric_val'] = self._score(self.validation)
Given this dummy model:
def build_model():
    """Build and compile the two-input dummy regression model.

    Two 5-feature inputs are concatenated and passed through a single-unit
    Dense layer; the model is compiled with the 'mse' loss and the 'adam'
    optimizer.

    Returns:
        The compiled model.
    """
    inputs = [Input((5,)), Input((5,))]
    merged = Concatenate()(inputs)
    prediction = Dense(1)(merged)
    net = Model(inputs, prediction)
    net.compile(loss='mse', optimizer='adam')
    return net
and this data:
# Random toy data drawn with np.random.uniform(0, 1): 100 samples, two
# 5-feature inputs and one target column, for both training and validation.
X_train1 = np.random.uniform(0,1, (100,5))
X_train2 = np.random.uniform(0,1, (100,5))
y_train = np.random.uniform(0,1, (100,1))
X_val1 = np.random.uniform(0,1, (100,5))
X_val2 = np.random.uniform(0,1, (100,5))
y_val = np.random.uniform(0,1, (100,1))
you can use the custom callback to compute the metric both on train and validation:
# The callback receives both datasets, so it logs the metric for each of them.
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train), validation=([X_val1, X_val2],y_val))])
only on validation:
# Only `validation` is passed, so only 'my_metric_val' is logged.
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
callbacks=[MyCustomMetricCallback(validation=([X_val1, X_val2],y_val))])
only on train:
# Only `train` is passed, so only 'my_metric_train' is logged.
model = build_model()
model.fit([X_train1, X_train2], y_train, epochs=10,
callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train))])
Just remember that the callback evaluates the metric in one shot on the whole data, like any metric/loss computed by default by Keras on the `validation_data`.
here is the running code.
Since the metrics are run within the `train_step` function of `keras.Model`, filtering out train-disabled metrics without altering the API requires subclassing `keras.Model`.
We define a simple metric wrapper:
class TrainDisabledMetric(Metric):
    """Wrapper that tags a metric so it can be recognised via ``isinstance``.

    The wrapper reuses the wrapped metric's name and forwards every state
    operation to it unchanged.
    """

    def __init__(self, metric: Metric):
        """Wrap ``metric`` and register this wrapper under the same name."""
        super().__init__(name=metric.name)
        self._metric = metric

    def update_state(self, *args, **kwargs):
        """Forward all arguments to the wrapped metric's ``update_state``."""
        wrapped = self._metric
        return wrapped.update_state(*args, **kwargs)

    def reset_state(self):
        """Reset the wrapped metric."""
        wrapped = self._metric
        return wrapped.reset_state()

    def result(self):
        """Return the wrapped metric's current result."""
        wrapped = self._metric
        return wrapped.result()
and subclass `keras.Model` to filter out those metrics during training:
class CustomModel(keras.Model):
    """``keras.Model`` whose ``train_step`` skips ``TrainDisabledMetric`` metrics.

    ``compile`` keeps the regular ``compiled_metrics`` container (used when
    ``compute_metrics`` is called with ``training=False``) and additionally
    builds ``on_train_compiled_metrics``, a container holding only the metrics
    that are not wrapped in ``TrainDisabledMetric``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _train_enabled_metrics(metrics):
        """Return ``metrics`` without its ``TrainDisabledMetric`` entries.

        The nesting of lists/dicts is kept; containers that end up empty are
        dropped. Returns ``None`` when no train-enabled metric remains, since
        ``tf.nest.pack_sequence_as`` cannot pack an empty sequence into the
        collapsed (scalar) structure that case produces.
        """
        keep_flags = [not isinstance(entry, TrainDisabledMetric)
                      for entry in tf.nest.flatten(metrics)]
        if not any(keep_flags):
            return None

        # Shallow tree of booleans: sub-structures without any enabled metric
        # collapse to a single False.
        keep_tree = get_traverse_shallow_structure(
            lambda sub: any(tf.nest.flatten(sub)),
            tf.nest.pack_sequence_as(metrics, keep_flags))
        candidates = flatten_up_to(keep_tree, metrics)

        def prune(tree):
            """Drop falsy entries from nested lists/dicts of keep flags."""
            if isinstance(tree, list):
                return [kept for kept in (prune(item) for item in tree) if kept]
            if isinstance(tree, dict):
                pruned = {}
                for key, value in tree.items():
                    kept = prune(value)
                    if kept:
                        pruned[key] = kept
                return pruned
            return tree

        kept_metrics = [metric for keep, metric
                        in zip(tf.nest.flatten(keep_tree), candidates) if keep]
        return tf.nest.pack_sequence_as(prune(keep_tree), kept_metrics)

    def compile(self, optimizer='rmsprop', loss=None, metrics=None,
                loss_weights=None, weighted_metrics=None, run_eagerly=None,
                steps_per_execution=None, jit_compile=None, **kwargs):
        """Compile the model and build the training-only metrics container.

        All arguments are forwarded to ``keras.Model.compile``. When
        ``metrics`` is given, ``on_train_compiled_metrics`` is rebuilt from the
        entries that are not ``TrainDisabledMetric`` instances; otherwise it
        is the same object as ``compiled_metrics``.
        """
        from_serialized = kwargs.get('from_serialized', False)
        super().compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights,
                        weighted_metrics=weighted_metrics, run_eagerly=run_eagerly,
                        steps_per_execution=steps_per_execution,
                        jit_compile=jit_compile, **kwargs)
        self.on_train_compiled_metrics = self.compiled_metrics
        if metrics is not None:
            # NOTE(review): weighted_metrics are not passed to the training
            # container, as in the original code - confirm this is intended.
            self.on_train_compiled_metrics = compile_utils.MetricsContainer(
                self._train_enabled_metrics(metrics), weighted_metrics=None,
                output_names=self.output_names, from_serialized=from_serialized)

    def train_step(self, data):
        """Run one optimisation step and return the training-time metrics."""
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # Run forward pass.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(x, y, y_pred, sample_weight)
        self._validate_target_and_loss(y, loss)
        # Run backwards pass.
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        return self.compute_metrics(x, y, y_pred, sample_weight, training=True)

    def compute_metrics(self, x, y, y_pred, sample_weight, training=False):
        """Update the metrics for the current phase and return their results.

        With ``training=True`` only ``on_train_compiled_metrics`` is updated
        and ``on_train_metrics`` is reported; otherwise the full
        ``compiled_metrics`` / ``metrics`` are used.

        Returns:
            Dict mapping metric names to results; a metric whose result is
            itself a dict is merged into the returned dict.
        """
        del x  # The default implementation does not use `x`.
        if training:
            self.on_train_compiled_metrics.update_state(y, y_pred, sample_weight)
            metrics = self.on_train_metrics
        else:
            self.compiled_metrics.update_state(y, y_pred, sample_weight)
            metrics = self.metrics
        # Collect metrics to return
        return_metrics = {}
        for metric in metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return return_metrics

    @property
    def on_train_metrics(self):
        """Metrics reported during training: loss, train-enabled and layer metrics."""
        metrics = []
        if self._is_compiled:
            # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
            # so that attr names are not load-bearing.
            if self.compiled_loss is not None:
                metrics += self.compiled_loss.metrics
            if self.on_train_compiled_metrics is not None:
                metrics += self.on_train_compiled_metrics.metrics
        for layer in self._flatten_layers():
            metrics.extend(layer._metrics)  # pylint: disable=protected-access
        return metrics
Now given a keras model, we can wrap it and compile it with train disabled metrics:
# Placeholder for any existing Keras model; it is re-wrapped through its inputs/outputs.
model: keras.Model = ...
custom_model = CustomModel(inputs=model.input, outputs=model.output)
# Metrics left unwrapped are not filtered out by CustomModel.
train_enabled_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
# wrap train disabled metrics with `TrainDisabledMetric`:
train_disabled_metrics = [
TrainDisabledMetric(tf.keras.metrics.SparseCategoricalCrossentropy())]
metrics = train_enabled_metrics + train_disabled_metrics
custom_model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True), metrics=metrics, )
# `ds_train` / `ds_test` are datasets defined outside this snippet.
custom_model.fit(ds_train, epochs=6, validation_data=ds_test, )
The metric `SparseCategoricalCrossentropy` is computed only during validation:
Epoch 1/6
469/469 [==============================] - 2s 2ms/step - loss: 0.3522 - sparse_categorical_accuracy: 0.8366 - val_loss: 0.1978 - val_sparse_categorical_accuracy: 0.9086 - val_sparse_categorical_crossentropy: 1.3197
Epoch 2/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1631 - sparse_categorical_accuracy: 0.9526 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9587 - val_sparse_categorical_crossentropy: 1.1910
Epoch 3/6
469/469 [==============================] - 1s 1ms/step - loss: 0.1178 - sparse_categorical_accuracy: 0.9654 - val_loss: 0.1139 - val_sparse_categorical_accuracy: 0.9661 - val_sparse_categorical_crossentropy: 1.1369
Epoch 4/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0909 - sparse_categorical_accuracy: 0.9735 - val_loss: 0.0981 - val_sparse_categorical_accuracy: 0.9715 - val_sparse_categorical_crossentropy: 1.0434
Epoch 5/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0735 - sparse_categorical_accuracy: 0.9784 - val_loss: 0.0913 - val_sparse_categorical_accuracy: 0.9721 - val_sparse_categorical_crossentropy: 0.9862
Epoch 6/6
469/469 [==============================] - 1s 1ms/step - loss: 0.0606 - sparse_categorical_accuracy: 0.9823 - val_loss: 0.0824 - val_sparse_categorical_accuracy: 0.9761 - val_sparse_categorical_crossentropy: 1.0024