"""Conv block and U-net model."""
import tensorflow as tf


def conv_block(x, layout, filters=None, transpose=False, rate=0.1, activation=None, is_training=True):
    """Convolutional block.
    Parameters
    ----------
    x : tensor
        Input tensor.
    layout : str
        Layout of layers. Can contain "c" for convolution, "a" for activation, "n" for batchnorm,
        "p" for maxpooling, "d" for dropout. E.g. of layout: "ccnapd".
    filters : int, list or None
        Number of filters for convolutions. Can be a single number (all convolutions will
        have the same number of filters), a list of the same length as a count of letters "c"
        in the layout, or None if the layout contains no "c".
    transpose : bool
        If true, transposed convolutions are used.
    rate : float
        Dropout rate parameter. Default to 0.1.
    activation : function
        Activation function. If not specified activation is tf.nn.elu.
    is_training: bool
        Phase of training for batchnorm.
        Default to True.
    Returns
    -------
    x : tensor
        Output tensor.
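
    Examples
    --------
    A minimal sketch; the placeholder shape, layout and filters below are illustrative:

    >>> x = tf.placeholder(tf.float32, [None, 64, 64, 1])
    >>> x = conv_block(x, 'cnacnapd', filters=16, rate=0.2)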
    """
    # Broadcast a single filters value to one entry per "c" in the layout.
    i = 0
    try:
        iter(filters)
    except TypeError:
        filters = [filters] * layout.count('c')
    for s in layout:
        if s == 'c':
            if transpose:
                x = tf.layers.conv2d_transpose(x, filters[i], (3, 3), strides=(2, 2), padding='same')
            else:
                x = tf.layers.conv2d(x, filters[i], (3, 3), padding='same')
            i += 1
        elif s == 'a':
            if activation is None:
                activation = tf.nn.elu
            x = activation(x)
        elif s == 'p':
            x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))
        elif s == 'n':
            x = tf.layers.batch_normalization(x, training=is_training, momentum=0.9)
        elif s == 'd':
            if rate is None:
                rate = 0.1
            x = tf.layers.dropout(x, rate=rate, training=is_training)
        else:
            raise KeyError('unknown letter {0}'.format(s))
    return x


def u_net(images, depth, init_filters, output=None, is_training=True, verbose=0):
    """U-net implementation.
    Parameters
    ----------
    images : 4d tensor
        Input tensor.
    depth : int
        Depth of the U-net.
    init_filters : int
        Number of filters in the first conv block.
    output : dict
        Output conv block config.
    is_training: bool
        Phase of training for batchnorm. Default to True.
    verbose : int
        Level for information messages. Default to 0.
    Returns
    -------
    up : 4d tensor
        Output tensor.
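
    Examples
    --------
    A minimal sketch; the input shape and the single-channel output block are
    illustrative assumptions:

    >>> images = tf.placeholder(tf.float32, [None, 64, 64, 1])
    >>> logits = u_net(images, depth=2, init_filters=8,
    ...                output=dict(layout='c', filters=1))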
    """
    verboseprint = print if verbose else lambda *args, **kwargs: None
    conv_d = []  # skip-connection tensors saved on the contracting path
    conv = images
    verboseprint('input', conv.get_shape())
    # Contracting path: conv blocks followed by max pooling and dropout.
    for d in range(depth):
        conv = conv_block(conv, 'caca', init_filters * (2 ** d), is_training=is_training)
        verboseprint('conv_block_{0}'.format(d), conv.get_shape())
        conv_d.append(conv)
        conv = conv_block(conv, 'pd', is_training=is_training)
        verboseprint('pool_{0}'.format(d), conv.get_shape())
    # Bottleneck block at the bottom of the U.
    conv = conv_block(conv, 'cacad', init_filters * (2 ** depth), is_training=is_training)
    verboseprint('bottom_conv_block_{0}'.format(depth), conv.get_shape())
    # Expanding path: upsample, concatenate with the skip connection, then convolve.
    for d in range(depth, 0, -1):
        conv = conv_block(conv, 'cad', init_filters * (2 ** d), transpose=True, is_training=is_training)
        verboseprint('up_{0}'.format(d - 1), conv.get_shape())
        conv = tf.concat([conv, conv_d[d - 1]], axis=-1)
        verboseprint('concat_{0}'.format(d), conv.get_shape())
        conv = conv_block(conv, 'cacad', init_filters * (2 ** (d - 1)), is_training=is_training)
        verboseprint('up_conv_block_{0}'.format(d), conv.get_shape())
    if output is not None:
        out = conv_block(conv, is_training=is_training, **output)
    else:
        out = conv
    verboseprint('output shape', out.get_shape())
    return out
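

# A minimal usage sketch (not part of the original module): builds a binary
# segmentation head on top of u_net. The input shape, placeholder names and
# the sigmoid output below are illustrative assumptions.
if __name__ == '__main__':
    images = tf.placeholder(tf.float32, [None, 128, 128, 1], name='images')
    training = tf.placeholder(tf.bool, name='training')
    logits = u_net(images, depth=3, init_filters=16,
                   output=dict(layout='c', filters=1),
                   is_training=training, verbose=1)
    probabilities = tf.sigmoid(logits)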