# Module-level imports assumed by this snippet (TensorFlow 1.x, tf.contrib era):
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib import layers


def get_network(self, input_tensor, is_training, reuse=False):
    net = input_tensor

    with tf.variable_scope('GaitNN', reuse=reuse):
        with slim.arg_scope(self.get_arg_scope(is_training)):
            # Spatial down-sampling: bottleneck residual blocks that shrink the
            # 17x17 feature maps down to 4x4
            with tf.variable_scope('DownSampling'):
                with tf.variable_scope('17x17'):
                    net = layers.convolution2d(net, num_outputs=256, kernel_size=1)
                    # Reassign the repeat output to net so the residual blocks are chained
                    net = slim.repeat(net, 3, self.residual_block, ch=256, ch_inner=64)

                with tf.variable_scope('8x8'):
                    net = self.residual_block(net, ch=512, ch_inner=64, stride=2)
                    net = slim.repeat(net, 2, self.residual_block, ch=512, ch_inner=128)

                with tf.variable_scope('4x4'):
                    net = self.residual_block(net, ch=512, ch_inner=128, stride=2)
                    net = slim.repeat(net, 1, self.residual_block, ch=512, ch_inner=256)
                    net = layers.convolution2d(net, num_outputs=256, kernel_size=1)
                    net = layers.convolution2d(net, num_outputs=256, kernel_size=3)

            # One 512-dimensional feature vector per frame
            with tf.variable_scope('FullyConnected'):
                # net = tf.reduce_mean(net, [1, 2], name='GlobalPool')
                net = layers.flatten(net)
                net = layers.fully_connected(net, 512, activation_fn=None, normalizer_fn=None)

            # Temporal aggregation of the frame sequence
            with tf.variable_scope('Recurrent', initializer=tf.contrib.layers.xavier_initializer()):
                cell_type = {
                    'GRU': tf.nn.rnn_cell.GRUCell,
                    'LSTM': tf.nn.rnn_cell.LSTMCell
                }

                cell = cell_type[self.recurrent_unit](self.FEATURES)
                cell = tf.nn.rnn_cell.MultiRNNCell([cell] * self.rnn_layers, state_is_tuple=True)

                # Treat the whole clip as a single batch entry: [1, time, features]
                net = tf.expand_dims(net, 0)
                net, state = tf.nn.dynamic_rnn(cell, net, initial_state=cell.zero_state(1, dtype=tf.float32))
                net = tf.reshape(net, [-1, self.FEATURES])

                # Temporal Avg-Pooling: average the per-frame outputs into one
                # fixed-length gait signature
                gait_signature = tf.reduce_mean(net, 0)

            if is_training:
                net = tf.expand_dims(gait_signature, 0)
                net = layers.dropout(net, 0.7)  # keep_prob = 0.7

                # Identity classifier, used only during training
                with tf.variable_scope('Logits'):
                    net = layers.fully_connected(net, self.num_of_persons, activation_fn=None,
                                                 normalizer_fn=None)

            return net, gait_signature, state
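
The "Temporal Avg-Pooling" step is what turns a variable-length walking sequence into a fixed-size descriptor. A minimal standalone sketch of that idea follows (NumPy, not the TensorFlow graph above; the 512-dimensional size is only an illustrative stand-in for self.FEATURES):

import numpy as np

FEATURES = 512   # stand-in for self.FEATURES; the real value comes from the model config
T = 30           # frames in one walking clip (any length works)

# frame_outputs plays the role of the per-frame RNN outputs, shape [T, FEATURES]
frame_outputs = np.random.randn(T, FEATURES).astype(np.float32)

# Temporal average pooling: the mean over the time axis yields one
# fixed-length gait signature regardless of how many frames the clip has
gait_signature = frame_outputs.mean(axis=0)
assert gait_signature.shape == (FEATURES,)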