def channel_wise_fc_layer(bottom, name, bias=True):
    """Apply a channel-wise fully connected layer.

    Each feature map (last axis of ``bottom``) is flattened over the two
    middle axes into a vector of length ``d = width * height``. That vector
    is multiplied by the feature map's own ``(d, d)`` weight matrix. Channels
    are never mixed with each other.

    Args:
        bottom: rank-4 tensor whose last three dimensions must be statically
            known. NOTE(review): presumably NHWC ``(batch, H, W, C)``. The code
            calls axes 1 and 2 ``width`` and ``height``, but only their
            product matters and the output restores the same order.
        name: variable scope under which ``W`` (and optionally ``b``) are
            created.
        bias: if truthy, add a learned bias vector of length ``d``. The same
            vector is shared by every feature map.

    Returns:
        Tensor with the same static shape as ``bottom``.

    Raises:
        ValueError: if axis 1, 2 or 3 of ``bottom`` has no static size, which
            is needed to size the weight variable.
    """
    _, width, height, n_feat_map = bottom.get_shape().as_list()
    if width is None or height is None or n_feat_map is None:
        raise ValueError(
            "channel_wise_fc_layer needs static spatial and channel dims, "
            "got shape %s" % (bottom.get_shape().as_list(),))
    d = width * height

    # Flatten the two spatial axes (row-major) -> (batch, d, n_feat_map).
    # Then move channels to the front so that each channel becomes one
    # independent slice of the batched matmul.
    input_reshape = tf.reshape(bottom, [-1, d, n_feat_map])
    input_transpose = tf.transpose(input_reshape, [2, 0, 1])  # n_feat_map x batch x d

    with tf.variable_scope(name):
        W = tf.get_variable(
            "W",
            shape=[n_feat_map, d, d],  # one d x d matrix per feature map
            initializer=tf.truncated_normal_initializer(0., 0.005))
        # tf.batch_matmul was removed in TF 1.0. tf.matmul treats the
        # leading axis as a batch axis, giving the same per-channel product.
        output = tf.matmul(input_transpose, W)  # n_feat_map x batch x d
        if bias:
            b = tf.get_variable(
                "b",
                shape=d,
                initializer=tf.constant_initializer(0.))
            # bias_add broadcasts along the last axis (d).
            output = tf.nn.bias_add(output, b)

    # Undo the transpose and the spatial flattening.
    output_transpose = tf.transpose(output, [1, 2, 0])  # batch x d x n_feat_map
    return tf.reshape(output_transpose, [-1, width, height, n_feat_map])
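# 评论列表 ("Comment list") - page-navigation text left over from a web scrape; not code.
# 文章目录 ("Table of contents") - page-navigation text left over from a web scrape; not code.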