"""Conv block and U-net model."""
import tensorflow as tf
def conv_block(x, layout, filters=None, transpose=False, rate=0.1, activation=None, is_training=True):
    """Convolutional block.

    Apply one layer per letter of ``layout`` to ``x``, in order.

    Parameters
    ----------
    x : tensor
        Input tensor.
    layout : str
        Layout of layers. Can contain "c" for convolution, "a" for activation,
        "n" for batchnorm, "p" for maxpooling, "d" for dropout.
        E.g. of layout: "ccnapd".
    filters : int, iterable or None
        Number of filters for convolutions. Can be a single number (all
        convolutions will have the same number of filters), an iterable with
        one entry per letter "c" in the layout (entries are consumed in order;
        extra entries are ignored), or None if the layout contains no "c".
    transpose : bool
        If True, every "c" is a 3x3 transposed convolution with stride 2;
        otherwise a 3x3 convolution with stride 1. Both use 'same' padding.
    rate : float or None
        Dropout rate. Default to 0.1; None is also treated as 0.1.
    activation : callable or None
        Activation function. If None, tf.nn.elu is used.
    is_training : bool
        Phase of training for batchnorm and dropout. Default to True.
        NOTE(review): in TF1, batch-norm moving-average updates are usually
        placed in tf.GraphKeys.UPDATE_OPS and must be run by the training op;
        confirm the caller does this.

    Returns
    -------
    x : tensor
        Output tensor.

    Raises
    ------
    KeyError
        If ``layout`` contains an unknown letter. Raised before any layer is
        created, so no partial block is left in the graph.
    ValueError
        If ``layout`` contains "c" but fewer filter counts than convolutions
        were supplied (this includes ``filters=None``).
    """
    # Validate the whole layout first so a typo does not leave dangling
    # layers in the graph.
    for s in layout:
        if s not in 'capnd':
            raise KeyError('unknown letter {0}'.format(s))

    n_conv = layout.count('c')
    try:
        # Materialise so tuples, generators, etc. can all be indexed below.
        filters = list(filters)
    except TypeError:
        # A scalar (or None) is broadcast to every convolution.
        filters = [filters] * n_conv
    if len(filters) < n_conv or any(f is None for f in filters[:n_conv]):
        raise ValueError('layout {0!r} has {1} convolution(s) but got filters={2!r}'
                         .format(layout, n_conv, filters))

    if rate is None:
        rate = 0.1

    i = 0
    for s in layout:
        if s == 'c':
            if transpose:
                x = tf.layers.conv2d_transpose(x, filters[i], (3, 3), strides=(2, 2), padding='same')
            else:
                x = tf.layers.conv2d(x, filters[i], (3, 3), padding='same')
            i += 1
        elif s == 'a':
            # Resolved lazily so layouts without "a" never touch tf.nn.
            if activation is None:
                activation = tf.nn.elu
            x = activation(x)
        elif s == 'p':
            x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))
        elif s == 'n':
            x = tf.layers.batch_normalization(x, training=is_training, momentum=0.9)
        else:  # s == 'd' (the layout was validated above)
            x = tf.layers.dropout(x, rate=rate, training=is_training)
    return x
def u_net(images, depth, init_filters, output=None, is_training=True, verbose=0):
    """U-net implementation.

    The network is built from three parts:

    * ``depth`` down-sampling stages, each a conv block followed by
      max-pooling and dropout;
    * a bottom conv block;
    * ``depth`` up-sampling stages, each a transposed convolution, then
      concatenation with the matching down-stage output, then a conv block.

    Parameters
    ----------
    images : 4d tensor
        Input tensor. NOTE(review): assumes channels-last layout, since skip
        connections are concatenated on axis -1. The spatial dims presumably
        must be divisible by ``2 ** depth`` for the skip shapes to match.
        Confirm both.
    depth : int
        Number of down-/up-sampling stages.
    init_filters : int
        Number of filters in the first conv block. Down stage ``d`` uses
        ``init_filters * 2 ** d`` filters.
    output : dict or None
        Keyword arguments for a final ``conv_block`` (e.g. ``layout``,
        ``filters``). ``is_training`` is taken from this function's argument
        unless the dict sets it explicitly. If None, no output block is added.
    is_training : bool
        Phase of training for batchnorm and dropout. Default to True.
    verbose : int
        Level for information messages. Default to 0 (silent). Any nonzero
        value prints the shape after each stage.

    Returns
    -------
    out : 4d tensor
        Output tensor.
    """
    # Print only when asked. The previous test was inverted: it printed by
    # default and was silent for verbose > 0.
    verboseprint = print if verbose else (lambda *args, **kwargs: None)

    conv = images
    verboseprint('input', conv.get_shape())

    skips = []
    for d in range(depth):
        conv = conv_block(conv, 'caca', init_filters * (2 ** d), is_training=is_training)
        verboseprint('conv_block_{0}'.format(d), conv.get_shape())
        skips.append(conv)
        # is_training must be forwarded. Otherwise this dropout defaults to
        # training mode and keeps dropping activations at inference time.
        conv = conv_block(conv, 'pd', is_training=is_training)
        verboseprint('pool_{0}'.format(d), conv.get_shape())

    conv = conv_block(conv, 'cacad', init_filters * (2 ** depth), is_training=is_training)
    verboseprint('bottom_conv_block_{0}'.format(depth), conv.get_shape())

    for d in range(depth, 0, -1):
        conv = conv_block(conv, 'cad', init_filters * (2 ** d), transpose=True, is_training=is_training)
        verboseprint('up_{0}'.format(d - 1), conv.get_shape())
        conv = tf.concat([conv, skips[d - 1]], axis=-1)
        verboseprint('concat_{0}'.format(d), conv.get_shape())
        conv = conv_block(conv, 'cacad', init_filters * (2 ** (d - 1)), is_training=is_training)
        verboseprint('up_conv_block_{0}'.format(d), conv.get_shape())

    if output is not None:
        # Copy so the caller's dict is untouched. setdefault avoids the
        # duplicate-keyword TypeError when the dict already sets is_training.
        params = dict(output)
        params.setdefault('is_training', is_training)
        out = conv_block(conv, **params)
    else:
        out = conv
    verboseprint('output shape', out.get_shape())
    return out