Multi-scale segmentation mask outputs in a Keras U-Net

Question:

This is my model. It takes a single image as input and produces outputs at four scales of the image, i.e., I, I/2, I/4 and I/8: `Model(inputs=[inputs], outputs=[out6, out7, out8, out9])`

I am not sure how to create the training dataset. Suppose the masks for `y_train` form an array of shape (50, 192, 256, 3): 50 images, each 192 pixels high and 256 pixels wide, with 3 channels. How do I create a `y_train` that has 4 components, one per output? I have tried zipping the arrays and converting the result to NumPy, but that doesn't work…

Asked By: Jimut123


Answers:

If you want the model to learn to generate multi-scale masks, you can downsample the ground-truth masks to produce the scaled targets for supervised training of the U-Net. Interpolation-based methods resize an image automatically with minimal loss of detail. Here is a post where I benchmark several such methods against each other.
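For a multi-output model, Keras trains on a single scalar: the sum of the per-output losses, each multiplied by its entry in `loss_weights` (1.0 by default). The NumPy sketch below mimics that for binary masks at four scales, to show how every scaled target contributes to training. The `binary_crossentropy` helper, the random data and the weight values are assumptions for illustration only; in Keras the equivalent would be something like `model.compile(loss="binary_crossentropy", loss_weights=[1.0, 0.5, 0.25, 0.125])`.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, with predictions clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

rng = np.random.default_rng(0)
# One target shape per output: I, I/2, I/4, I/8 (batch of 4, one channel).
shapes = [(4, 192, 256, 1), (4, 96, 128, 1), (4, 48, 64, 1), (4, 24, 32, 1)]

# Hypothetical ground-truth masks and predictions at each scale.
y_true = [rng.integers(0, 2, s).astype(float) for s in shapes]
y_pred = [rng.random(s) for s in shapes]

# Hypothetical weights: the full-resolution output counts most.
loss_weights = [1.0, 0.5, 0.25, 0.125]

per_scale = [binary_crossentropy(t, p) for t, p in zip(y_true, y_pred)]
total = sum(w * l for w, l in zip(loss_weights, per_scale))
print(per_scale, total)
```

Lowering the weights of the coarse outputs is only one possible choice; equal weights are the Keras default.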

To build the target list `[masks, masks_half, masks_quarter, masks_eighth]` for `model.fit`, i.e., the original masks plus their rescaled versions, you need a downsampling method that is fast enough for the size of your dataset.
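A Keras model with `outputs=[out6, out7, out8, out9]` expects `y` to be a list of four arrays in the same order as the outputs, each with the same number of samples as `x`. Zipping the masks instead yields one tuple per sample, and each tuple holds four arrays of different shapes, so NumPy cannot turn them into one regular numeric array. The sketch below uses zero-filled stand-ins with the shapes of a 192×256 mask at each scale; the variable names and the commented `model.fit` call are illustrative.

```python
import numpy as np

# Hypothetical stand-ins for the four targets (50 samples each).
masks = np.zeros((50, 192, 256, 3), dtype=np.float32)
masks_half = np.zeros((50, 96, 128, 3), dtype=np.float32)
masks_quarter = np.zeros((50, 48, 64, 3), dtype=np.float32)
masks_eighth = np.zeros((50, 24, 32, 3), dtype=np.float32)

# zip() groups the arrays sample by sample: each element is a tuple of four
# differently shaped arrays, which is not the structure Keras expects.
per_sample = list(zip(masks, masks_half, masks_quarter, masks_eighth))
print(len(per_sample), [a.shape for a in per_sample[0]])

# Keras wants one array per output, ordered like `outputs=[...]`.
y_train = [masks, masks_half, masks_quarter, masks_eighth]

# Every target must have as many samples as x_train.
assert all(y.shape[0] == masks.shape[0] for y in y_train)

# model.fit(x_train, y_train, batch_size=8, epochs=10)
```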

Here I use `skimage.transform.pyramid_reduce` to downsample each mask to a half, a quarter and an eighth of its size. The function smooths the image and then resizes it with spline interpolation; both steps can be controlled via parameters (`sigma` for the smoothing, `order` for the spline order). Check this for more details.

```python
import numpy as np
from skimage.transform import pyramid_reduce

# Dummy masks: 50 samples, 192 x 256 pixels, 3 channels
masks = np.random.random((50, 192, 256, 3))

# channel_axis=-1 keeps the last (channel) axis intact (scikit-image >= 0.19;
# older versions used multichannel=True instead)
masks_half = np.stack([pyramid_reduce(m, 2, channel_axis=-1) for m in masks])
masks_quarter = np.stack([pyramid_reduce(m, 4, channel_axis=-1) for m in masks])
masks_eighth = np.stack([pyramid_reduce(m, 8, channel_axis=-1) for m in masks])

print('Shape of original', masks.shape)
print('Shape of half scaled', masks_half.shape)
print('Shape of quarter scaled', masks_quarter.shape)
print('Shape of eighth scaled', masks_eighth.shape)
```

Output:

```
Shape of original (50, 192, 256, 3)
Shape of half scaled (50, 96, 128, 3)
Shape of quarter scaled (50, 48, 64, 3)
Shape of eighth scaled (50, 24, 32, 3)
```
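One caveat: `pyramid_reduce` smooths the mask with a Gaussian before resizing, so a binary or integer label mask comes back with fractional values along object boundaries. If the targets must stay discrete, two simple NumPy alternatives (swapped in here, not what the answer above uses) are nearest-neighbour subsampling by slicing, and averaging non-overlapping blocks followed by a threshold. Both helpers are hypothetical and assume height and width are divisible by the factor, which holds for 192×256 with factors 2, 4 and 8.

```python
import numpy as np

def downsample_nearest(masks, factor):
    """Keep every `factor`-th pixel along height and width (labels stay discrete)."""
    return masks[:, ::factor, ::factor, :]

def downsample_block_mean(masks, factor, threshold=0.5):
    """Average factor x factor blocks, then threshold back to a binary mask."""
    n, h, w, c = masks.shape
    assert h % factor == 0 and w % factor == 0, "size must be divisible by factor"
    blocks = masks.reshape(n, h // factor, factor, w // factor, factor, c)
    return (blocks.mean(axis=(2, 4)) >= threshold).astype(masks.dtype)

# Hypothetical binary masks: 50 samples, 192 x 256, one channel.
rng = np.random.default_rng(0)
masks = (rng.random((50, 192, 256, 1)) > 0.5).astype(np.float32)

y_train = [masks] + [downsample_nearest(masks, f) for f in (2, 4, 8)]
y_train_avg = [masks] + [downsample_block_mean(masks, f) for f in (2, 4, 8)]
print([y.shape for y in y_train])
```

Slicing is the cheapest option but ignores most pixels; block averaging with a threshold takes every pixel into account at the cost of a reshape.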

Testing on a single image/mask:

```python
import matplotlib.pyplot as plt
from skimage.data import camera
from skimage.transform import pyramid_reduce

def plotit(img, h, q, e):
    """Show the original image next to its half, quarter and eighth scale versions."""
    fig, axes = plt.subplots(1, 4, figsize=(10, 15))
    for ax, im, title in zip(axes, (img, h, q, e), ('Original', 'Half', 'Quarter', 'Eighth')):
        ax.imshow(im, cmap='gray')
        ax.set_title(title)
    plt.show()

img = camera()               # grayscale, (512, 512)
h = pyramid_reduce(img, 2)   # Half:    (256, 256)
q = pyramid_reduce(img, 4)   # Quarter: (128, 128)
e = pyramid_reduce(img, 8)   # Eighth:  (64, 64)

plotit(img, h, q, e)
```

[Figure: the camera image at Original, Half, Quarter and Eighth scale]

Notice how the range of the x- and y-axes shrinks from 512 pixels to 256, 128 and 64.
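The sizes in the plot halve exactly because 512 (like 192 and 256) is divisible by 8. With other sizes, the targets and the network's outputs can disagree by one pixel: as far as I know, `pyramid_reduce` computes each reduced size as `ceil(size / downscale)`, whereas a Keras `MaxPooling2D(2)` with the default `padding='valid'` rounds down. The hypothetical helper below, assuming one 2×2 pooling per scale, compares the two so that such sizes can be caught before training.

```python
import math

def target_vs_pooled(size, levels=3):
    """Compare ceil-based target sizes with floor-based pooled sizes per level."""
    targets = [math.ceil(size / 2 ** k) for k in range(levels + 1)]
    pooled = [size]
    for _ in range(levels):
        pooled.append(pooled[-1] // 2)  # 2x2 max-pooling with padding='valid'
    return targets, pooled

for size in (192, 256, 250):
    targets, pooled = target_vs_pooled(size)
    print(size, targets, pooled, "OK" if targets == pooled else "MISMATCH")
```

If a mismatch shows up, cropping or padding the inputs to a multiple of 8 is usually simpler than changing the network.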

Answered By: Akshay Sehgal