How to save/load a model checkpoint with several losses in Pytorch?

Question:

Using Ubuntu 20.04, PyTorch 1.10.1.

I am trying to solve a music generation task with a transformer architecture and multiple embeddings for processing tokens that have several characteristics.

In each training iteration I calculate the loss of each token characteristic and store it in a vector. I suppose the checkpoint should then contain a vector with all of these losses (or something similar), rather than the total loss I am saving now. I would like to know how to store all the losses in the checkpoint (so that training can resume after loading it), or whether that is needed at all.

The epochs loop:

for epoch in range(0, epochs):
    
    print('Epoch: ', epoch)
    
    loss = trfrmr.train(epoch+1, model, train_loader, train_loss_func, opt, lr_scheduler, num_iters=-1)
    loss_train.append(loss)
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': loss,
        }, "model_pop909_checkpoint.pth")

The training loop:

for batch_num, batch in enumerate(dataloader):
    time_before = time.time()

    opt.zero_grad()

    x = batch[0].to(get_device())
    tgt = batch[1].to(get_device())

    # x is the input sequence (N,T,Z), that should be input into the transformer forward function as (T,N,Z)
    y = model(x.permute(1, 0, 2))

    # tgt is the target sequence, of shape (N,T,Z): T is the sequence length, N the batch size, Z the number of token types
    # y holds the output logits: a list of Z tensors of shape (T,N,C), where the vocabulary size C varies with the token type (pitch, velocity, etc.)
    losses = []
    for j in range(LEN_VOCAB):
        # call the criterion directly instead of .forward(); shapes are (N,C,T) and (N,T), see the PyTorch cross-entropy docs
        aux_loss = loss(y[j].permute(1, 2, 0), tgt[..., j])
        losses.append(aux_loss)

    losses_sum = sum(losses)  # summing here; taking the mean would also work

    losses_sum.backward()
    opt.step()

    if lr_scheduler is not None:
        lr_scheduler.step()

    lr = opt.param_groups[0]['lr']

    loss_hist.append(losses_sum.item())  # .item() stores a plain float and drops the autograd graph
    if batch_num == num_iters:
        break
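
Since the goal is to track each token characteristic's loss separately, the per-type values can also be recorded as plain floats inside the same loop. A sketch with illustrative names (loss_hist_per_type is not part of the original code):

# before the batch loop: one history list per token type
loss_hist_per_type = [[] for _ in range(LEN_VOCAB)]

# inside the batch loop, after `losses` is built:
for j, aux_loss in enumerate(losses):
    loss_hist_per_type[j].append(aux_loss.item())  # .item() detaches the value from the graph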

Answers:

As far as I can tell from your code, your loss function has no custom learnable parameters; it is simply recomputed on every iteration. Thus there is no need to save its value beyond keeping a history of it; it is not required to continue training from a checkpoint.
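
As a quick sanity check, a stateless criterion such as nn.CrossEntropyLoss exposes no learnable parameters at all, so there is nothing loss-related to restore:

import torch.nn as nn

criterion = nn.CrossEntropyLoss()
print(list(criterion.parameters()))  # prints [] -- nothing to checkpoint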

Answered By: dx2-66

The problem was that I wasn't loading the model back properly: I was restoring only the model parameters, not the optimizer state. Now, at the beginning of the loop, my code does:

if loaded:
    print('Loading model and optimizer...')
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    opt.load_state_dict(checkpoint['optimizer_state_dict'])
    print('Loaded successfully!')

And I also load the epoch:

epoch = 0
if loaded:
    print('Loading epoch value...')
    epoch = checkpoint['epoch'] 
    print('Loaded successfully!')
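
Putting both pieces together, a minimal resume sketch (the file name matches the torch.save call in the question; the scheduler line assumes a 'scheduler_state_dict' key was also saved, which the original checkpoint does not include):

import torch

checkpoint = torch.load("model_pop909_checkpoint.pth", map_location=get_device())

model.load_state_dict(checkpoint['model_state_dict'])
opt.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume from the next epoch

# Optional, assuming the scheduler state was saved under this key:
# lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])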

This answer was posted as an edit to the question How to save/load a model checkpoint with several losses in Pytorch? by the OP Enrique Vilchez Campillejo under CC BY-SA 4.0.

Answered By: vvvvv