UNet loss is NaN + UserWarning: Warning: converting a masked element to nan

Question:

I’m training a UNet; the model class (together with the imports both snippets below rely on) looks like this:

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from time import time
from IPython.display import clear_output

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        # encoder (downsampling)
        # Each enc_conv/dec_conv block should look like this:
        # nn.Sequential(
        #     nn.Conv2d(...),
        #     ... (2 or 3 conv layers with relu and batchnorm),
        # )
        self.enc_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=False)  # 256 -> 128
        self.enc_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=False)  # 128 -> 64
        self.enc_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64 -> 32
        self.enc_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32 -> 16

        # bottleneck
        self.bottleneck_conv = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(1024),
            nn.ReLU(),

            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

        # decoder (upsampling)
        self.upsample0 = nn.UpsamplingBilinear2d(scale_factor=2)  # 16 -> 32
        self.dec_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=512*2, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2)  # 32 -> 64
        self.dec_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=256*2, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2)  # 64 -> 128
        self.dec_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=128*2, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.upsample3 = nn.UpsamplingBilinear2d(scale_factor=2)  # 128 -> 256
        self.dec_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=64*2, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1),
            nn.ReLU(),

            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1),
            nn.ReLU(),

            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(1)
        )

    def forward(self, x):
        # encoder
        e0 = self.enc_conv0(x)
        pool0 = self.pool0(e0)
        e1 = self.enc_conv1(pool0)
        pool1 = self.pool1(e1)
        e2 = self.enc_conv2(pool1)
        pool2 = self.pool2(e2)
        e3 = self.enc_conv3(pool2)
        pool3 = self.pool3(e3)

        # bottleneck
        b = self.bottleneck_conv(pool3)

        # decoder
        d0 = self.dec_conv0(torch.cat([self.upsample0(b), e3], 1))
        d1 = self.dec_conv1(torch.cat([self.upsample1(d0), e2], 1))
        d2 = self.dec_conv2(torch.cat([self.upsample2(d1), e1], 1))
        d3 = self.dec_conv3(torch.cat([self.upsample3(d2), e0], 1))  # no activation
        return d3

Train method:

def train(model, opt, loss_fn, score_fn, epochs, data_tr, data_val):

    torch.cuda.empty_cache()

    losses_train = []
    losses_val = []
    scores_train = []
    scores_val = []

    for epoch in range(epochs):
        tic = time()
        print('* Epoch %d/%d' % (epoch + 1, epochs))

        avg_loss = 0
        model.train()  # train mode
        for X_batch, Y_batch in data_tr:
            # data to device
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)

            # set parameter gradients to zero
            opt.zero_grad()

            # forward
            Y_pred = model(X_batch)
            loss = loss_fn(Y_pred, Y_batch)  # forward-pass
            loss.backward()  # backward-pass
            opt.step()  # update weights

            # calculate loss to show the user
            avg_loss += loss / len(data_tr)
        toc = time()
        print('loss: %f' % avg_loss)
        losses_train.append(avg_loss)

        avg_score_train = score_fn(model, iou_pytorch, data_tr)
        scores_train.append(avg_score_train)

        # show intermediate results
        model.eval()  # testing mode
        avg_loss_val = 0

        for X_val, Y_val in data_val:
            with torch.no_grad():
                Y_hat = model(X_val.to(device)).detach().cpu()

                loss = loss_fn(Y_hat, Y_val)
                avg_loss_val += loss / len(data_val)

        toc = time()
        print('loss_val: %f' % avg_loss_val)
        losses_val.append(avg_loss_val)

        avg_score_val = score_fn(model, iou_pytorch, data_val)
        scores_val.append(avg_score_val)

        torch.cuda.empty_cache()

        # Visualize tools
        clear_output(wait=True)
        for k in range(5):
            plt.subplot(2, 6, k + 1)
            plt.imshow(np.rollaxis(X_val[k].numpy(), 0, 3), cmap='gray')
            plt.title('Real')
            plt.axis('off')

            plt.subplot(2, 6, k + 7)
            plt.imshow(Y_hat[k, 0], cmap='gray')
            plt.title('Output')
            plt.axis('off')
        plt.suptitle('%d / %d - loss: %f' % (epoch + 1, epochs, avg_loss))
        plt.show()

    return (losses_train, losses_val, scores_train, scores_val)

However, when I run this, train_loss and val_loss both come out as nan, and I also get the warning from the title. In addition, when plotting the segmented picture next to the target one, the output picture is not shown. I tried different loss functions, but the result is the same. There is probably something wrong with my class.

Could you please help me? Thanks in advance.

Asked By: ALiCe P.


Answers:

I am not sure if this is your error, but your last convolution block (self.dec_conv3) looks odd. I would reduce to 1 channel only at the very last convolution, rather than performing two convolutions with 1 input and 1 output channel. Also, ending with a BatchNorm layer can only produce normalized outputs, which could be far from what you really want:

self.dec_conv3 = nn.Sequential(
    nn.Conv2d(in_channels=64*2, out_channels=32, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),

    nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1)
)
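
Since the loss function is not shown in the question, one assumption worth stating: with a single-channel logit output like this, the usual pairing for binary segmentation is nn.BCEWithLogitsLoss, which applies the sigmoid internally in a numerically stable way, so ending on a plain convolution with no activation is exactly what it expects:

    # sketch, assuming a binary segmentation setup with raw logits
    loss_fn = nn.BCEWithLogitsLoss()         # applies sigmoid internally, numerically stable
    loss = loss_fn(Y_pred, Y_batch.float())  # Y_batch: 0/1 mask of shape (N, 1, H, W)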

It would also be interesting to know whether your loss is NaN already at the first iteration or only after a few iterations. Maybe you are using a loss function that divides by zero? The matplotlib warning is consistent with this: an all-NaN prediction gets masked by imshow, which triggers the "converting a masked element to nan" warning and renders an empty image, which would explain the missing output plot.
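
To narrow that down, a check like this could be dropped into the training loop (a minimal sketch, reusing the names from your code):

    # minimal sketch: report the first place a NaN shows up
    torch.autograd.set_detect_anomaly(True)  # backward() will name the op that produced a NaN

    Y_pred = model(X_batch)
    loss = loss_fn(Y_pred, Y_batch)
    if torch.isnan(X_batch).any():
        print('NaN already present in the input batch')
    if torch.isnan(loss):
        print('loss is NaN in this iteration')

A typical "divides by zero" culprit in segmentation is a soft Dice/IoU loss without a smoothing term; adding an epsilon to numerator and denominator avoids the 0/0 case on empty masks (a sketch, not necessarily the loss used here):

    # soft Dice loss with a smoothing term to avoid 0/0
    def dice_loss(logits, target, eps=1.0):
        probs = torch.sigmoid(logits)
        num = 2 * (probs * target).sum() + eps
        den = probs.sum() + target.sum() + eps
        return 1 - num / den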

Answered By: MarcoM