def linear(input_,
           output_size,
           weights_initializer=initializers.xavier_initializer(),
           biases_initializer=tf.zeros_initializer,
           activation_fn=None,
           trainable=True,
           name='linear'):
    """Apply a fully connected layer: ``activation_fn(input_ @ w + b)``.

    Inputs of rank greater than 2 are flattened to ``[-1, prod(dims[1:])]``
    before the matmul. The weight ``w`` and bias ``b`` are created with
    ``tf.get_variable`` inside ``tf.variable_scope(name)``.

    Args:
      input_: Input tensor of rank >= 2. Every dimension after the first
        must be statically known, because it determines the shape of ``w``.
      output_size: Number of output units (columns of ``w``, length of ``b``).
      weights_initializer: Initializer for ``w``. NOTE(review): the default
        is constructed once, when this function is defined, and is shared
        by every call.
      biases_initializer: Initializer for ``b``.
      activation_fn: Optional callable applied to the affine output. When
        ``None``, the raw affine output is returned.
      trainable: Passed to ``tf.get_variable`` for both ``w`` and ``b``.
      name: Variable scope under which ``w`` and ``b`` are created.

    Returns:
      A tuple ``(output, w, b)``, where:
        - ``w`` is a float32 variable of shape ``[input_dim, output_size]``.
        - ``b`` is a float32 variable of shape ``[output_size]``.

    Raises:
      ValueError: If ``input_`` has rank < 2 or any non-leading dimension
        is statically unknown.
    """
    # Local imports: ``reduce`` is not a builtin on Python 3, and the module
    # header is not guaranteed to import it.
    import functools
    import operator

    shape = input_.get_shape().as_list()
    if len(shape) < 2:
        raise ValueError(
            'linear() expects an input of rank >= 2, got shape %s' % (shape,))
    if any(dim is None for dim in shape[1:]):
        # The weight matrix needs a concrete input size to be created.
        raise ValueError(
            'linear() needs statically known non-batch dimensions, '
            'got shape %s' % (shape,))

    if len(shape) > 2:
        # Collapse all trailing dimensions into a single feature axis; the
        # leading axis stays dynamic via -1.
        flat_dim = functools.reduce(operator.mul, shape[1:], 1)
        input_ = tf.reshape(input_, [-1, flat_dim])
        shape = input_.get_shape().as_list()

    with tf.variable_scope(name):
        w = tf.get_variable('w', [shape[1], output_size], tf.float32,
                            initializer=weights_initializer,
                            trainable=trainable)
        # Pin the bias dtype to match ``w``. Otherwise it would inherit the
        # enclosing scope's dtype and could mismatch in bias_add.
        b = tf.get_variable('b', [output_size], tf.float32,
                            initializer=biases_initializer,
                            trainable=trainable)
        out = tf.nn.bias_add(tf.matmul(input_, w), b)

    if activation_fn is not None:
        out = activation_fn(out)
    return out, w, b
评论列表
文章目录