Is there a function to create a U-Net of custom depth in Python/Keras?

Question:

I know one can be created manually, but I wanted to know whether someone has written a function (similar to MATLAB’s unet) that lets you choose the number of steps (depth) along the encoder/decoder paths.

Asked By: R0bots

||

Answers:

I wrote the following code, which builds a parametric U-Net:

class UNET:
    """Builder for a 2-D U-Net of configurable depth using ``tf.keras``.

    The encoder has ``n_level`` stages. Each stage applies ``n_block[l]``
    Conv2D -> BatchNormalization -> ReLU units with ``n_filter[l]`` filters
    (at least one unit per encoder stage). Every stage except the deepest is
    followed by 2x2 max pooling. The decoder mirrors the encoder:
      - upsample x2 and convolve;
      - concatenate with the matching encoder output (skip connection);
      - apply ``n_block[l]`` conv units.

    NOTE: this code expects ``tensorflow`` to be imported as ``tf`` at module
    level. Each spatial input dimension should be divisible by
    ``2 ** (n_level - 1)``. Otherwise the skip-connection concatenation fails
    on mismatched shapes.
    """

    # Strings interpreted as "enabled" when dropout_bool is given as text.
    _TRUE_STRINGS = ('true', '1', 'yes', 'y', 'on')

    def __init__(self, n_class, n_level, n_filter, n_block, input_shape, loss, dropout_bool=False):
        """Store the network configuration.

        Args:
            n_class: number of output channels of the final 1x1 convolution.
            n_level: number of encoder (and decoder) stages; must be >= 1.
            n_filter: sequence of at least ``n_level`` filter counts, one per level.
            n_block: sequence of at least ``n_level`` conv-unit counts, one per level.
            input_shape: either a shape tuple/list without the batch dimension
                (wrapped in ``tf.keras.Input``) or an existing Keras input tensor.
            loss: loss name. Any value other than ``"mse"`` adds a sigmoid to the
                output; ``"mse"`` keeps a linear output.
            dropout_bool: enable ``SpatialDropout2D(0.5)`` at the two deepest
                levels. Strings such as ``'False'`` / ``'True'`` are interpreted
                by their meaning. (The former default was the string
                ``'False'``, which is truthy and silently enabled dropout.)

        Raises:
            ValueError: if ``n_level`` < 1 or ``n_filter`` / ``n_block`` have
                fewer than ``n_level`` entries.
        """
        super().__init__()
        if n_level < 1:
            raise ValueError('n_level must be >= 1, got {}'.format(n_level))
        if len(n_filter) < n_level or len(n_block) < n_level:
            raise ValueError('n_filter and n_block need at least n_level={} entries '
                             '(got {} and {})'.format(n_level, len(n_filter), len(n_block)))
        self.n_class = n_class
        self.input_shape = input_shape
        self.n_level = n_level
        self.n_filter = n_filter
        self.loss = loss
        self.n_block = n_block
        self.dropout_bool = self._as_bool(dropout_bool)

    @staticmethod
    def _as_bool(value):
        """Interpret *value* as a boolean, parsing common string spellings."""
        if isinstance(value, str):
            return value.strip().lower() in UNET._TRUE_STRINGS
        return bool(value)

    @staticmethod
    def conv2d_bn_relu(x, filters, kernel_size=3, strides=1):
        """Apply Conv2D ('same' padding, he_normal init) -> BatchNorm -> ReLU to *x*."""
        x_conv = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same',
                                        kernel_initializer='he_normal')(x)
        x_bn = tf.keras.layers.BatchNormalization()(x_conv)
        x_relu = tf.keras.layers.Activation('relu')(x_bn)
        return x_relu

    @staticmethod
    def conv2d_transpose_bn_relu(x, filters, kernel_size=3, strides=1):
        """Upsample *x* by 2 (UpSampling2D), then Conv2D -> BatchNorm -> ReLU.

        Despite the name, this is not a ``Conv2DTranspose``. It is
        nearest-neighbour upsampling followed by a regular convolution.
        """
        x_up = tf.keras.layers.UpSampling2D(size=(2, 2))(x)
        x_conv = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same',
                                        kernel_initializer='he_normal')(x_up)
        x_bn = tf.keras.layers.BatchNormalization()(x_conv)
        x_relu = tf.keras.layers.Activation('relu')(x_bn)
        return x_relu

    def _make_input(self):
        """Return a Keras input tensor, wrapping a shape tuple/list if needed."""
        if isinstance(self.input_shape, (tuple, list)):
            return tf.keras.Input(shape=tuple(self.input_shape))
        # Assume the caller already passed a Keras input tensor.
        return self.input_shape

    def call(self, ckpt_name, predicting=False, retrain=False):
        """Build the U-Net ``tf.keras.Model``, optionally loading weights.

        Args:
            ckpt_name: path passed to ``model.load_weights``. It is only used
                when ``predicting`` or ``retrain`` is true.
            predicting: load weights from ``ckpt_name`` after building.
            retrain: load weights from ``ckpt_name`` after building.

        Returns:
            The constructed ``tf.keras.Model``.
        """
        inputs = self._make_input()
        net = {}
        last = self.n_level - 1

        # Encoder (contracting path).
        x = inputs
        for l in range(self.n_level):
            # Every encoder level gets at least one conv unit, even if n_block[l] == 0.
            for _ in range(max(1, self.n_block[l])):
                x = self.conv2d_bn_relu(x, filters=self.n_filter[l], kernel_size=3)
            if l == last:
                # Bottleneck: apply dropout BEFORE storing, so the decoder actually
                # consumes it. Previously it was applied after storing, and the
                # layer was dead.
                if self.dropout_bool:
                    x = tf.keras.layers.SpatialDropout2D(0.5)(x)
                net['conv{}'.format(l)] = x
            else:
                # Skip connections use the pre-pooling, pre-dropout features.
                net['conv{}'.format(l)] = x
                x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
                if l == last - 1 and self.dropout_bool:
                    x = tf.keras.layers.SpatialDropout2D(0.5)(x)

        # Decoder (expanding path).
        net['conv{}_up'.format(last)] = net['conv{}'.format(last)]
        for l in range(last - 1, -1, -1):
            x = self.conv2d_transpose_bn_relu(net['conv{}_up'.format(l + 1)], filters=self.n_filter[l],
                                              kernel_size=2, strides=1)
            x = tf.keras.layers.concatenate([net['conv{}'.format(l)], x], axis=-1)
            for _ in range(self.n_block[l]):
                x = self.conv2d_bn_relu(x, filters=self.n_filter[l], kernel_size=3)
            net['conv{}_up'.format(l)] = x

        logits = tf.keras.layers.Conv2D(filters=self.n_class, kernel_size=1)(net['conv0_up'])
        if self.loss != "mse":
            logits = tf.keras.layers.Activation('sigmoid')(logits)
        model = tf.keras.Model(inputs=inputs, outputs=logits)

        if predicting or retrain:
            # Dropout layers hold no weights, so checkpoints stay compatible.
            model.load_weights(ckpt_name)

        return model

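You have to give the UNET the number of output classes (n_class), the number of levels (n_level), and the number of filters at each level (n_filter, a list such as [16, 32, 64, 128], typically doubling at each level). You also give it the number of conv blocks per level (n_block), the input shape (input_shape), and which loss to use. With mse the output activation is linear; with anything else (e.g. bce) a sigmoid is applied. The code assumes TensorFlow is imported as `import tensorflow as tf`.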

In your main script, first compute n_filter and n_block, then call the `call` method:

# Filters double at each level, starting from args.num_filter (e.g. 16, 32, 64, ...).
n_filter = [args.num_filter * pow(2, level) for level in range(args.num_level)]
# Print once, after the list is complete (it used to print on every iteration).
print('Number of filters at each level = ', n_filter)
n_block = [2] * args.num_level
# dropout_bool must be a real boolean: the string 'False' is truthy and would enable dropout.
dl_model = UNET(1, args.num_level, n_filter, n_block, input_shape, args.loss,
                dropout_bool=False).call(ckpt_name, predicting=args.predicting)
Answered By: Orphee Faucoz