Some questions about the plotted results of a Unet
Question:
I’m studying this example about Unet.
It is about binary segmentation, and I have some questions about the code:
-
What is the meaning of these lines?
#preprocess the mask
mask[mask >= 2] = 0
mask[mask != 0 ] = 1
The dataset contains “mask” pictures composed of three colours (indeed they’re called “trimaps”). As a test, I tried to plot mask before and after this piece of code, and it seems that the role of these lines is to convert the mask pictures from three colours to two (background: purple, foreground: yellow), but I don’t know how.
-
At the bottom of the “Generators” section, there’s a picture composed of three sub-pictures. The one in the middle is a “black and white” mask.
Which lines of code convert the colours of the “mask” pictures from purple/yellow to black/white?
-
Finally, I tried to plot msk with plt.imshow(msk) instead of plt.imshow(np.concatenate([img, msk], axis = 1)) (as done in the code).
But plotting msk with plt.imshow(msk) gives a black picture. Why?
Answers:
-
UNet, at least in its original form, works with binary masks. Your masks have three regions: background, object and some kind of edge. That piece of code sets every label of 2 or higher, which includes the background (label 2), to zero, and then sets any label that is still nonzero (label 1) to one. A label that is already 0 stays 0. This way you have a binary mask to use as ground truth. You see them in purple and yellow because matplotlib’s default colormap is viridis, which happens to be purple at the low end and yellow at the high end. Note that this actually throws away useful information from those masks that could in some way be used to train a better model. But that’s okay to simplify things a bit and better understand what’s going on.
-
The last step in the mask preprocessing code converts the single-channel masks to RGB. So when you plot them next to the colour images, each mask pixel is either (0, 0, 0), which is black, or (1, 1, 1), which is white.
-
Not sure; it should work. It is probably something to do with the default normalization in plt.imshow.
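To make the three answers above easier to check, here is a minimal NumPy sketch. It is not the notebook’s actual code: the tiny trimap, the helper name binarize, the use of np.stack for the RGB step, and the uint8 dtype of the mask are all assumptions made for illustration.

```python
import numpy as np

# Hypothetical 2x3 trimap with labels 1, 2 and 3 (the label set is an assumption).
trimap = np.array([[1, 2, 3],
                   [3, 1, 2]], dtype=np.uint8)


def binarize(mask):
    """Apply the two preprocessing lines from the question to a copy of mask."""
    mask = mask.copy()
    mask[mask >= 2] = 0  # every label of 2 or more becomes 0
    mask[mask != 0] = 1  # whatever is still nonzero becomes 1
    return mask


binary = binarize(trimap)
print(binary)  # only the pixels that held label 1 are 1 now

# Assumed stand-in for the "convert to RGB" step: repeat the channel three times.
msk = np.stack((binary,) * 3, axis=-1)
print(msk.shape, msk.dtype)

# Stand-in for the photo: a float image with values in [0, 1] and the same height.
img = np.full((2, 3, 3), 0.5, dtype=np.float64)

# Concatenating a float image with an integer mask promotes the result to float.
side_by_side = np.concatenate([img, msk], axis=1)
print(side_by_side.shape, side_by_side.dtype)
```

On the third question, one possible explanation (only if msk really is an integer array with values 0 and 1 in the notebook, which is an assumption): plt.imshow reads RGB arrays as 0-1 for floats and 0-255 for integers, so an integer 1 is almost black, while in the concatenated float array the same pixel is 1.0 and shows as white. In that case plt.imshow(msk.astype(float)) or plt.imshow(msk * 255) should display the mask.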