def resnn(self, image_batch):
    """Build the resnn model graph.

    Wires up the full network: an initial 5x5 convolution with BN+ReLU
    (plus an optional 3x3 max pool), the stacked residual-unit groups,
    global average pooling, optional dropout, and a final fully
    connected classifier.

    Args:
        image_batch: 4-D image tensor as returned from inputs_train()
            or inputs_eval().  Assumed NHWC layout (batch, height,
            width, channels) — TODO confirm against the input pipeline.

    Returns:
        Logits tensor of shape (batch, FLAGS.num_cats); no softmax is
        applied here.
    """
    # First convolution: 5x5 kernel, stride 1, producing the channel
    # count of the first residual group, followed by batch-norm + ReLU.
    with tf.variable_scope('conv_layer1'):
        net = self.conv2d(image_batch, self.groups[0].num_ker, 5, 1)
        net = self.BN_ReLU(net)
    # Optional 3x3 max pool; stride 1 + SAME padding keeps spatial size.
    if FLAGS.max_pool:
        net = tf.nn.max_pool(net,
                             [1, 3, 3, 1],
                             strides=[1, 1, 1, 1],
                             padding='SAME')
    # Stack the Residual Units group by group, unit by unit.
    for group_i, group in enumerate(self.groups):
        for unit_i in range(group.num_units):
            net = self.residual_unit(net, group_i, unit_i)
    # An extra BN+ReLU activation before average pooling (presumably
    # used with the pre-activation residual-unit variant — confirm
    # against residual_unit()).
    if FLAGS.special_first:
        with tf.variable_scope('special_BN_ReLU'):
            net = self.BN_ReLU(net)
    # Global average pooling: the kernel covers the full spatial extent,
    # and padding must be VALID so the output is batch*1*1*channels.
    net_shape = net.get_shape().as_list()
    net = tf.nn.avg_pool(net,
                         ksize=[1, net_shape[1], net_shape[2], 1],
                         strides=[1, 1, 1, 1],
                         padding='VALID')
    # Flatten to (batch, channels) for the classifier.
    net_shape = net.get_shape().as_list()
    softmax_len = net_shape[1] * net_shape[2] * net_shape[3]
    net = tf.reshape(net, [-1, softmax_len])
    # Optional dropout.
    # NOTE(review): dropout_keep_prob is applied unconditionally, so
    # dropout is also active when this graph is built for evaluation —
    # verify callers build a separate eval graph with dropout disabled.
    if FLAGS.dropout:
        with tf.name_scope("dropout"):
            net = tf.nn.dropout(net, FLAGS.dropout_keep_prob)
    # 2D-fully connected neural network producing raw class logits.
    # NOTE(review): biases_initializer is passed the class
    # tf.zeros_initializer rather than an instance tf.zeros_initializer();
    # most TF 1.x get_variable versions accept both — confirm against
    # the pinned TensorFlow version.
    with tf.variable_scope('FC-layer'):
        net = fully_connected(
            net,
            num_outputs=FLAGS.num_cats,
            activation_fn=None,
            normalizer_fn=None,
            weights_initializer=variance_scaling_initializer(),
            weights_regularizer=l2_regularizer(FLAGS.weight_decay),
            biases_initializer=tf.zeros_initializer, )
    return net