Some questions about the plotted results of a Unet

Question:

I’m studying this example about Unet.
It is about binary segmentation, and I have some questions about the code:

  1. What is the meaning of this code:

    # preprocess the mask
    mask[mask >= 2] = 0
    mask[mask != 0] = 1
    

    The dataset contains “mask” pictures composed of three colours (indeed they’re called “trimaps”). As a test, I plotted mask before and after this piece of code, and these lines seem to convert the mask pictures from three colours to two (background: purple, foreground: yellow), but I don’t understand how.

  2. At the bottom of the “Generators” section, there’s a picture composed of three sub-images. The sub-image in the middle is a “black and white” mask.
    Which lines of code convert the colours of the “mask” pictures from purple/yellow to black/white?

  3. Finally, I tried to plot msk on its own with plt.imshow(msk), instead of plotting it with plt.imshow(np.concatenate([img, msk], axis=1)) as the code does.
    But plotting msk this way gives a completely black picture. Why?

Asked By: rainbow


Answers:

  1. UNet, at least in its original form, works with binary masks. Your masks have three regions: background, object and some kind of edge. The first line sets every pixel with a value of 2 or more to zero, and the second sets every pixel that is still nonzero to one. If these are the Oxford-IIIT Pet trimaps, where 1 marks the pet, 2 the background and 3 the unclassified border, the background and the edge both end up as zero and only the pet is one. This way you have a binary mask to use as ground truth. You see them in purple and yellow because matplotlib's default colormap is viridis, which is purple at the low end (0) and yellow at the high end (1). Note that this throws away information from those masks that could somehow be used to train a better model, but it is fine for simplifying things a bit and understanding what's going on.

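  2. The last step of the mask preprocessing converts the single-channel masks to RGB by repeating the channel three times. plt.imshow draws RGB arrays as colours directly, without the viridis colormap, so when you plot a mask next to the colour image each mask pixel is either (0, 0, 0), which is black, or (1, 1, 1), which is white.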

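  3. Most likely this is a dtype issue rather than normalization: plt.imshow does not normalize RGB arrays, and it reads integer RGB values on a 0 to 255 scale but float RGB values on a 0 to 1 scale. If msk is an integer array (for example uint8, which is what np.array returns for a typical PIL mask), a value of 1 is only 1/255 of full brightness, so the whole mask looks black. Concatenating it with a float image (such as one divided by 255) promotes the result to float, where 1 means white. Plotting msk.astype(float) or msk * 255 should fix it.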
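
The three points can be checked on a tiny array. This is a minimal NumPy-only sketch, not the notebook's code: it assumes the Oxford-IIIT Pet trimap labels (1 = pet, 2 = background, 3 = unclassified border), and the helper names binarize_trimap and to_rgb, the 3x3 trimap and the grey stand-in image are hypothetical. The notebook's exact stacking call may differ.

```python
import numpy as np


def binarize_trimap(trimap: np.ndarray) -> np.ndarray:
    """Apply the question's two preprocessing lines to a copy of the trimap."""
    mask = trimap.copy()
    mask[mask >= 2] = 0  # every value >= 2 (here 2 and 3) becomes 0
    mask[mask != 0] = 1  # whatever is still nonzero becomes 1
    return mask


def to_rgb(mask: np.ndarray) -> np.ndarray:
    """Repeat an (H, W) mask into an (H, W, 3) array; the dtype is unchanged."""
    return np.stack([mask] * 3, axis=-1)


# Hypothetical 3x3 trimap, assuming the Oxford-IIIT Pet labels:
# 1 = pet, 2 = background, 3 = unclassified border.
trimap = np.array([[2, 2, 2],
                   [3, 1, 3],
                   [2, 3, 2]], dtype=np.uint8)

# Point 1: only the pet pixel survives as 1.
mask = binarize_trimap(trimap)
print(mask.tolist())  # [[0, 0, 0], [0, 1, 0], [0, 0, 0]]

# Point 2: a three-channel mask, where 0 -> (0, 0, 0) and 1 -> (1, 1, 1).
msk = to_rgb(mask)
print(msk.shape, msk.dtype)  # (3, 3, 3) uint8

# Point 3: hypothetical RGB image already scaled to floats in [0, 1].
img = np.full((3, 3, 3), 0.5)
both = np.concatenate([img, msk], axis=1)
print(both.shape, both.dtype)  # (3, 6, 3) float64

# imshow reads integer RGB on a 0..255 scale and float RGB on a 0..1 scale,
# so the same 1 is almost black in msk but pure white in both.
print(round(float(msk.max()) / 255, 4))  # 0.0039
print(float(both.max()))  # 1.0
```

With matplotlib, plt.imshow(msk) on this array should look black, while plt.imshow(msk.astype(float)) or plt.imshow(msk * 255) should show the pet pixel in white.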

Answered By: filippo