keras variational autoencoder loss function

Question:

I’ve read this blog by Keras on VAE implementation, where VAE loss is defined this way:

def vae_loss(x, x_decoded_mean):
    xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
    kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
    return xent_loss + kl_loss

I looked at the Keras documentation and the VAE loss function is defined this way:
if args.mse:
    reconstruction_loss = mse(inputs, outputs)
else:
    reconstruction_loss = binary_crossentropy(inputs,
                                              outputs)

reconstruction_loss *= original_dim
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)

In this implementation, the reconstruction_loss is multiplied by original_dim, which I don’t see in the first implementation!

Can somebody please explain why? Thank you!

Asked By: pnaseri


Answers:

Writing CE for the (mean) cross-entropy and d for the dimension the KL term is averaged over:

first_one = CE + mean(kl, axis=-1) = CE + sum(kl, axis=-1) / d

second_one = d * CE + sum(kl, axis=-1)

So:

first_one = second_one / d

And note that the second one returns the mean loss over all the samples, but the first one returns a vector of losses for all samples.
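A toy numeric check (made-up values) of the ratio above, assuming, as this answer's simplification does, that the cross-entropy and the KL term are taken over the same dimension d:

```python
import math

d = 4  # data dimension (assumed, for illustration, to also be the KL dimension)
x = [1.0, 0.0, 1.0, 1.0]         # targets
x_hat = [0.9, 0.2, 0.8, 0.7]     # reconstructions
kl_terms = [0.3, 0.1, 0.2, 0.4]  # per-dimension KL contributions

# Per-dimension binary cross-entropy, then the mean (the Keras convention)
bce = [-(t * math.log(p) + (1 - t) * math.log(1 - p))
       for t, p in zip(x, x_hat)]
ce_mean = sum(bce) / d

first_one = ce_mean + sum(kl_terms) / d   # blog-post version
second_one = d * ce_mean + sum(kl_terms)  # Keras-example version

# The two losses differ exactly by the factor d
assert abs(first_one - second_one / d) < 1e-12
```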

Answered By: mamaj

In VAE, the reconstruction loss function can be expressed as:

reconstruction_loss = - log(p ( x | z))

If the decoder output distribution is assumed to be Gaussian, then the loss function boils down to MSE since:

reconstruction_loss = - log p(x | z) = - log ∏ N(x(i); x_out(i), sigma**2) = − ∑ log N(x(i); x_out(i), sigma**2) ∝ ∑ (x(i) - x_out(i))**2

In contrast, the equation for the MSE loss is:

L(x,x_out) = MSE = 1/m ∑ (x(i) - x_out(i)) **2

where m is the number of output dimensions. For example, for MNIST, m = width × height × channels = 28 × 28 × 1 = 784.

Thus,

reconstruction_loss = mse(inputs, outputs)

should be multiplied by m (i.e. original dimension) to be equal to the original reconstruction loss in the VAE formulation.
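A quick pure-Python check (toy numbers) that multiplying the per-dimension MSE by m recovers the sum of squared errors from the Gaussian log-likelihood derivation above:

```python
# Toy reconstruction of a 4-dimensional input (hypothetical values)
x = [0.5, 0.1, 0.9, 0.3]
x_out = [0.4, 0.2, 0.8, 0.1]
m = len(x)

# MSE averages over the m dimensions; the VAE derivation needs the sum
mse = sum((a - b) ** 2 for a, b in zip(x, x_out)) / m
sse = sum((a - b) ** 2 for a, b in zip(x, x_out))

assert abs(m * mse - sse) < 1e-12
```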

Answered By: Lina Achaji

In a Variational Autoencoder (VAE), the loss function is the negative Evidence Lower Bound ELBO, which is a sum of two terms:

# simplified formula
VAE_loss = reconstruction_loss + B*KL_loss

The KL_loss is also known as the regularization_loss. Originally, B is set to 1.0, but it can be used as a hyperparameter, as in the beta-VAEs (source 1, source 2).

When training on images, the input tensors have a shape of (batch_size, height, width, channels). However, the VAE_loss is a scalar: it is averaged along the batch dimension, and you should take the sum of the loss across all other dimensions. That is, compute the loss for each training sample in the batch to get a vector of shape (batch_size, ), and then take its average value as the VAE_loss.

When using Mean Squared Error (MSE) or Binary Cross-Entropy (BCE) to compute the reconstruction loss, you get an average, not a sum. Therefore, you should then multiply the result by the total number of dimensions, such as np.prod(INPUT_DIM), where INPUT_DIM is the input tensor’s shape. Notice, however, that if you forget to do so and take the reconstruction loss to be just the BCE or MSE, you are effectively applying a smaller value of the hyperparameter B in the VAE loss, so it may still work.
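A toy illustration (made-up numbers) of this point: forgetting the factor m is the same, up to an overall rescaling of the loss, as shrinking the KL weight B by m.

```python
m = 784            # e.g. np.prod((28, 28, 1)) for MNIST
recon_mean = 0.25  # hypothetical averaged reconstruction loss (BCE or MSE)
kl = 3.0           # hypothetical KL term
B = 1.0

# Correct loss: reconstruction summed over dimensions (mean * m)
full_loss = m * recon_mean + B * kl
# "Forgotten factor" loss: the raw mean, with an effective weight B/m
forgotten = recon_mean + (B / m) * kl

# The two losses are proportional, so gradients point the same way
assert abs(full_loss / m - forgotten) < 1e-12
```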

For instance, when you call the TensorFlow binary cross-entropy loss function, it will compute this sum and divide by the number of items (check here for a detailed example):

BCE = -(1/n) ∑ [ y(i) · log(ŷ(i)) + (1 − y(i)) · log(1 − ŷ(i)) ]

The term n in this formula is the number of items that were summed along the specified axis, not the batch size. However, your loss should be the sum along all dimensions, averaged over the samples in the batch. Pay attention to the shape of your tensors when computing the VAE loss.
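A pure-Python sketch (toy batch, no real framework call) of why the per-axis mean that such a loss function returns must be rescaled to a sum:

```python
import math

def bce_elementwise(y, p):
    # Per-pixel binary cross-entropy
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Hypothetical mini-batch of 2 samples with n = 3 "pixels" each
batch = [
    ([1.0, 0.0, 1.0], [0.9, 0.1, 0.8]),
    ([0.0, 1.0, 0.0], [0.2, 0.7, 0.3]),
]
n = 3

# What a framework-style binary_crossentropy returns: the MEAN over the axis
per_sample_mean = [sum(bce_elementwise(y, p) for y, p in zip(ys, ps)) / n
                   for ys, ps in batch]
# What the VAE loss needs: the SUM over all non-batch dimensions
per_sample_sum = [sum(bce_elementwise(y, p) for y, p in zip(ys, ps))
                  for ys, ps in batch]

# They differ exactly by the factor n, hence the multiplication by the
# number of dimensions
for s, mn in zip(per_sample_sum, per_sample_mean):
    assert abs(s - n * mn) < 1e-12
```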

Let us see this in more detail:

The KL_loss or regularization loss measures the difference between the distribution of the latent (encoded) variables and the assumed prior distribution (usually a standard normal distribution). You can compute it with the following code:

    from keras import backend as K

    # z_mean and z_log_var have shape: (batch_size, latent_dim)
    # Regularization loss or KL loss:
    regularization_loss = (1 + z_log_var - K.square(z_mean)
                           - K.exp(z_log_var))
    # After the sum, regularization loss has shape: (batch_size, )
    regularization_loss = -0.5 * K.sum(regularization_loss, axis=-1)

Notice that the regularization_loss comes from the encoder: for each input, the encoder is computing a value of the vectors z_mean and z_log_var. The regularization loss measures how much the values differ from (mean=0, variance=1). In this loss, you take the sum along the dimension of the latent variable: the larger the dimension of the latent variables in your autoencoder, the larger this loss will be.
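A pure-Python sketch of this closed-form term for a single sample (hypothetical helper name, toy values), showing that the regularization loss vanishes exactly when the encoder matches the standard normal prior:

```python
import math

def kl_standard_normal(z_mean, z_log_var):
    # kl = -0.5 * sum(1 + log_var - mean^2 - exp(log_var)) over latent dims
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv)
                      for m, lv in zip(z_mean, z_log_var))

# Encoder output exactly equal to the prior (mean=0, variance=1): zero loss
assert kl_standard_normal([0.0, 0.0], [0.0, 0.0]) == 0.0
# Any deviation from the prior makes the loss positive
assert kl_standard_normal([1.0, 0.0], [0.0, 0.0]) > 0.0
```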

The reconstruction loss takes different forms, as it is based on the expected distribution of the output or predicted variable. From appendix C in the original variational autoencoder paper:

In variational auto-encoders, neural networks are used as probabilistic encoders and decoders. There are many possible choices of encoders and decoders, depending on the type of data and model.

You can use a variational autoencoder (VAE) with continuous variables or with binary variables. You need to make some assumption about the distribution of the data in order to select the reconstruction loss function. Let X be your input variable, and let m be its dimension (for MNIST images, m = 28*28*1 = 784). Two common assumptions are:

  • X is continuous: you can assume the output is normally distributed (each pixel is independent), and the reconstruction loss is the squared L2 norm, i.e., the sum of squared differences, which you can compute as: m*MSE

  • X is a binary variable (e.g. 0/1, according to your activation function in the output layer): you can assume that the output follows a Bernoulli distribution, and the reconstruction loss is m*BCE.

For the binary case, this code will work:

    import keras
    import numpy as np
    from keras import backend as K

    INPUT_DIM = (28, 28, 1)
    # Reconstruction loss for binary variables, shape=(batch_size, )
    reconstruction_loss = keras.losses.binary_crossentropy(inputs,
                                                           outputs,
                                                           axis=[1, 2, 3])
    reconstruction_loss *= K.constant(np.prod(INPUT_DIM))

Notice how the BCE was applied along axes 1, 2 and 3, but not along axis 0, which is the batch dimension.

For the continuous case, a possible code is:

    # Reconstruction loss for continuous variables, shape=(batch_size, )
    reconstruction_loss = K.mean(K.square(outputs - inputs), axis=[1,2,3])                                 
    reconstruction_loss *= K.constant(np.prod(INPUT_DIM))

Notice that the Keras loss class MeanSquaredError does not accept an axis parameter, so we cannot use it to compute the MSE this way. Alternatively, you could simply use:

reconstruction_loss = K.sum(K.square(outputs - inputs), axis=[1,2,3])      

The fact that the reconstruction loss is a sum along all dimensions means that the larger the dimension of your data, the larger this sum will be. That is, a 28×28 image will produce a smaller reconstruction loss than a 100×100 image. In practice, you may need to adjust the value of the hyperparameter B accordingly.

Finally, you can sum and take the average along the samples to get the VAE loss:

    # Total VAE loss (-ELBO)
    # Total VAE loss (-ELBO)
    VAE_loss = K.mean(reconstruction_loss +
                      regularization_loss * K.constant(B))

A more detailed explanation of different loss functions can be found here.

Answered By: mikeliux