import tensorflow as tf


def linear(inputs, output_size, bias, concat=False, dtype=None, scope=None):
    """
    Linear layer
    Args:
        inputs: A Tensor or a list of Tensors with shape [batch, input_size]
        output_size: An integer specifying the output size
        bias: a boolean value indicating whether to use a bias term
        concat: a boolean value indicating whether to concatenate all inputs
            before the matrix multiplication
        dtype: an instance of tf.DType, defaults to tf.float32
        scope: the scope of this layer, defaults to "linear"
    Returns:
        a Tensor with shape [batch, output_size]
    Raises:
        RuntimeError: raised when the input sizes are not compatible with
            each other
    """
    with tf.variable_scope(scope, default_name="linear", values=[inputs]):
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        input_size = [item.get_shape()[-1].value for item in inputs]

        # The weight matrices are created from the static input sizes, so
        # every input must have a known last dimension
        if any(size is None for size in input_size):
            raise RuntimeError("the last dimension of every input must be "
                               "statically known")

        # Keep all leading dimensions of the first input and replace the
        # last dimension with output_size
        output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]],
                                 axis=0)

        # Flatten each input to 2D: [batch, input_size]
        inputs = [tf.reshape(inp, [-1, inp.shape[-1].value])
                  for inp in inputs]

        results = []

        if concat:
            # Concatenate all inputs and multiply by a single stacked matrix
            input_size = sum(input_size)
            inputs = tf.concat(inputs, 1)

            shape = [input_size, output_size]
            matrix = tf.get_variable("matrix", shape, dtype=dtype)
            results.append(tf.matmul(inputs, matrix))
        else:
            # One matrix per input; summing the products below is equivalent
            # to the concatenated multiplication above
            for i in range(len(input_size)):
                shape = [input_size[i], output_size]
                name = "matrix_%d" % i
                matrix = tf.get_variable(name, shape, dtype=dtype)
                results.append(tf.matmul(inputs[i], matrix))

        output = tf.add_n(results)

        if bias:
            shape = [output_size]
            bias = tf.get_variable("bias", shape, dtype=dtype)
            output = tf.nn.bias_add(output, bias)

        # Restore the leading dimensions of the input
        output = tf.reshape(output, output_shape)

        return output
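The concat=True and concat=False branches compute the same function: multiplying the concatenated inputs by one vertically stacked matrix equals summing the per-input products. A quick NumPy check of this identity, with arbitrary example shapes:

import numpy as np

a = np.random.randn(4, 3)   # input 0: [batch, 3]
b = np.random.randn(4, 5)   # input 1: [batch, 5]
w0 = np.random.randn(3, 2)  # matrix_0
w1 = np.random.randn(5, 2)  # matrix_1

summed = a @ w0 + b @ w1                                   # concat=False path
stacked = np.concatenate([a, b], 1) @ np.vstack([w0, w1])  # concat=True path
assert np.allclose(summed, stacked)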
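For context, a minimal usage sketch in a TF1 graph. The placeholder names and sizes here (x, y, 128, 256, 512) are illustrative assumptions, not from the source:

# two hypothetical inputs with different feature sizes
x = tf.placeholder(tf.float32, [None, 128], name="x")
y = tf.placeholder(tf.float32, [None, 256], name="y")

# project the pair to a 512-dimensional space with a shared bias
out = linear([x, y], 512, bias=True, concat=False, scope="proj")
# out has shape [None, 512]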