def linear(inputs, output_size, add_bias=True, n_splits=1, initializer=None, scope=None, moving_params=None):
  """Affine (linear + optional bias) transformation of one or more tensors.

  The inputs are flattened to 2-D, concatenated along the feature axis,
  multiplied by a single weight matrix, optionally biased, and reshaped
  back so the leading dimensions of the first input are preserved.

  Args:
    inputs: a Tensor, or a list/tuple of Tensors; the last dimension of
      each is treated as the feature axis. Leading dimensions are assumed
      to match across inputs — TODO confirm with callers.
    output_size: size of the output feature axis (per split).
    add_bias: if True, add a learned zero-initialized bias vector.
    n_splits: number of equal pieces to split the result into along the
      last axis; when > 1 a list of Tensors is returned.
    initializer: weight initializer; if None (and not averaging), an
      orthonormal matrix from `orthonormal_initializer` is used, tiled
      once per split.
    scope: variable scope name; defaults to 'Linear'.
    moving_params: if not None, its `average()` of each variable is used
      instead of the raw variable (inference with moving averages).

  Returns:
    A Tensor with the same leading dims as `inputs[0]` and last dim
    `output_size * n_splits`, or a list of `n_splits` such Tensors each
    with last dim `output_size` when n_splits > 1.
  """
  # Work on a private copy: the original assigned into `inputs` in place,
  # mutating the caller's list — and raising TypeError for tuples.
  if not isinstance(inputs, (list, tuple)):
    inputs = [inputs]
  else:
    inputs = list(inputs)
  output_size *= n_splits

  with tf.variable_scope(scope or 'Linear'):
    # Total width of the concatenated feature axis.
    shapes = [a.get_shape().as_list() for a in inputs]
    total_input_size = sum(shape[-1] for shape in shapes)

    # Dynamic output shape: leading dims of the first input, with the
    # last dim replaced by output_size.
    input_shape = tf.shape(inputs[0])
    output_shape = [input_shape[i] for i in xrange(len(shapes[0]))]
    output_shape[-1] = output_size
    output_shape = tf.pack(output_shape)

    # Flatten each input to 2-D and concatenate along the feature axis
    # (TF 0.x signature: tf.concat(axis, values)).
    for i, (input_, shape) in enumerate(zip(inputs, shapes)):
      inputs[i] = tf.reshape(input_, [-1, shape[-1]])
    concatenation = tf.concat(1, inputs)

    # Weight matrix: orthonormal init by default at training time, tiled
    # across splits so each split starts from the same orthonormal block.
    if initializer is None and moving_params is None:
      mat = orthonormal_initializer(total_input_size, output_size//n_splits)
      mat = np.concatenate([mat]*n_splits, axis=1)
      initializer = tf.constant_initializer(mat)
    matrix = tf.get_variable('Weights', [total_input_size, output_size], initializer=initializer)
    if moving_params is not None:
      matrix = moving_params.average(matrix)
    else:
      tf.add_to_collection('Weights', matrix)

    # Bias: learned vector, its moving average, or the scalar 0.
    if add_bias:
      bias = tf.get_variable('Biases', [output_size], initializer=tf.zeros_initializer)
      if moving_params is not None:
        bias = moving_params.average(bias)
    else:
      bias = 0

    # Affine transform, then restore the leading dimensions and pin the
    # static shape (only the last dim is statically known).
    new = tf.matmul(concatenation, matrix) + bias
    new = tf.reshape(new, output_shape)
    new.set_shape([tf.Dimension(None) for _ in xrange(len(shapes[0])-1)] + [tf.Dimension(output_size)])

    # TF 0.x signature: tf.split(axis, num_splits, value).
    if n_splits > 1:
      return tf.split(len(new.get_shape().as_list())-1, n_splits, new)
    else:
      return new
#===============================================================
# (removed stray web-page residue pasted into the source: "评论列表" /
#  "文章目录" — "comment list" / "article table of contents")