def linear(self, inputs, output_size, n_splits=1, add_bias=False):
    """Apply optional dropout to `inputs`, then a linear projection.

    If `self.moving_params` is None, dropout with keep probability
    `self.info_keep_prob` is applied to `inputs` before projecting.
    Otherwise the keep probability is 1 and dropout is skipped. The
    projection itself is delegated to `linalg.linear`. Each resulting
    tensor then gets the static shape `[None, ..., None, output_size]`,
    with the same rank as `inputs`.

    Args:
      inputs: Tensor to project. When dropout is active, its last
        dimension is read statically to build the dropout noise shape,
        so it should be statically known. Presumably laid out as
        (batch, ..., input_size); TODO confirm against callers.
      output_size: Size of the last dimension of each output tensor.
      n_splits: Number of projections requested from `linalg.linear`.
      add_bias: Forwarded unchanged to `linalg.linear`.

    Returns:
      The single tensor from `linalg.linear` when `n_splits == 1`.
      Otherwise, the iterable of tensors it returned, each with its
      static shape set.
    """
    static_shape = inputs.get_shape().as_list()
    n_dims = len(static_shape)
    input_size = static_shape[-1]
    # Only the output width is pinned; all leading dims stay unknown.
    shape_to_set = [tf.Dimension(None)] * (n_dims - 1) + [tf.Dimension(output_size)]

    # No dropout when moving params are supplied (presumably the
    # evaluation path; confirm with the caller's conventions).
    keep_prob = self.info_keep_prob if self.moving_params is None else 1
    if keep_prob < 1:
        batch_size = tf.shape(inputs)[0]
        # The mask varies over the first and last axes only, so a single
        # mask is broadcast (shared) across every intermediate axis.
        # NOTE: tf.pack was renamed tf.stack in TensorFlow 1.0.
        noise_shape = tf.pack([batch_size] + [1] * (n_dims - 2) + [input_size])
        inputs = tf.nn.dropout(inputs, keep_prob, noise_shape=noise_shape)

    lin = linalg.linear(inputs,
                        output_size,
                        n_splits=n_splits,
                        add_bias=add_bias,
                        moving_params=self.moving_params)

    # linalg.linear yields one tensor for n_splits == 1 and an iterable
    # of tensors otherwise; normalise so each gets its static shape set.
    splits = [lin] if n_splits == 1 else lin
    for split in splits:
        split.set_shape(shape_to_set)
    return lin
#=============================================================
评论列表
文章目录