Single loss function with multi-input multi-output model in Keras

Question:

I am trying to train a multi-input (3), multi-output (4) model with Keras, and I need a SINGLE loss function that takes in all of the output predictions. Two of these outputs are the true model outputs that I care about and that have corresponding labels, while the other two outputs are learnable parameters from within my model that I want to use to dynamically update the loss weights of my true model outputs.
I need something like this:

model.compile(optimizer=optimizer, loss=unified_loss)

where the unified loss should have access to all of my model outputs and the corresponding labels. I am using tf.data.Dataset.from_tensor_slices(...) to build the training data.
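For context, the pipeline is built roughly like this (a sketch; the arrays are just placeholders for my real data):

import tensorflow as tf

# Placeholder arrays standing in for the real features and the two sets of labels.
input1 = tf.random.normal((100, 1))
input2 = tf.random.normal((100, 1))
input3 = tf.random.normal((100, 1))
label1 = tf.random.normal((100, 1))
label2 = tf.random.normal((100, 1))

dataset = tf.data.Dataset.from_tensor_slices(
    ((input1, input2, input3), (label1, label2))).batch(32)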

The only workaround I have found is to use a custom training loop, which allows this. But, I lose a lot of functionality and callbacks become trickier to implement.
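The loop I am referring to looks roughly like this (a sketch; the output ordering and the unified_loss signature are illustrative):

optimizer = tf.keras.optimizers.Adam()

for (x1, x2, x3), (y1, y2) in dataset:
    with tf.GradientTape() as tape:
        # Two real predictions plus the two learnable weight outputs.
        out1, out2, w1, w2 = model((x1, x2, x3), training=True)
        loss = unified_loss(y1, y2, out1, out2, w1, w2)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))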

Is there a way to solve this using the regular model.compile(...) and model.fit(...)?

Apart from a custom training loop, which is not preferred, I did try the standard approach of:

model.compile(optimizer=optimizer, loss=[loss1, loss2], loss_weights=[alpha, beta])

where I tried to make alpha and beta learnable parameters, but this is not what I want because my custom equation is more involved than a simple weighted sum.

Asked By: samuraikmc


Answers:

Add a layer to your model that concatenates the four outputs into a single tensor. Have your custom loss parse out each of the four values and run the necessary math on them. During inference, run the model without the extra layer.

The pattern of having a slightly different model for training and inference is a common one.

Here is an example of the basic idea:

import tensorflow as tf

inp1 = tf.keras.Input((1,))
inp2 = tf.keras.Input((1,))
inp3 = tf.keras.Input((1,))

inputs = tf.keras.layers.Concatenate()([inp1, inp2, inp3])
out1 = tf.keras.layers.Dense(1)(inputs)
out2 = tf.keras.layers.Dense(1)(inputs)
out3 = tf.keras.layers.Dense(1)(inputs)
out4 = tf.keras.layers.Dense(1)(inputs)

# Inference model: keeps the four outputs separate.
model = tf.keras.Model([inp1, inp2, inp3], [out1, out2, out3, out4])

x1 = tf.convert_to_tensor([[1.0]])
x2 = tf.convert_to_tensor([[1.0]])
x3 = tf.convert_to_tensor([[1.0]])

model((x1, x2, x3))

# Training-time output: concatenate all four outputs into one (batch, 4) tensor
# so a single loss function can see them together.
outs = tf.keras.layers.Concatenate()([out1, out2, out3, out4])

training_model = tf.keras.Model([inp1, inp2, inp3], outs)

training_model((x1, x2, x3))

def exotic_loss(y_true, y_pred):
  # y_pred holds all four concatenated outputs; here the last two are treated
  # as the learnable loss weights. y_true holds the two real labels.
  pred1, pred2, w1, w2 = tf.unstack(y_pred, num=4, axis=-1)
  true1, true2 = tf.unstack(y_true, num=2, axis=-1)
  # Dummy combination -- substitute the real custom weighting equation here.
  return w1 * tf.square(true1 - pred1) + w2 * tf.square(true2 - pred2)

training_model.compile(loss=exotic_loss)
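
With the training model compiled this way, the two real labels just need to be packed into a single tensor whose last axis matches what exotic_loss unpacks. A rough sketch with made-up data (all names and shapes here are illustrative):

# Fake data: three inputs and the two labelled outputs, 100 samples each.
x1_data = tf.random.normal((100, 1))
x2_data = tf.random.normal((100, 1))
x3_data = tf.random.normal((100, 1))
y1_data = tf.random.normal((100, 1))
y2_data = tf.random.normal((100, 1))

# Pack the two labels into one (100, 2) tensor to match the single concatenated output.
y_data = tf.concat([y1_data, y2_data], axis=-1)

dataset = tf.data.Dataset.from_tensor_slices(
    ((x1_data, x2_data, x3_data), y_data)).batch(32)

training_model.fit(dataset, epochs=5)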
Answered By: Yaoshiang