def MLP(self, inputs, output_size, func=None, keep_prob=None, n_splits=1):
    """Apply input dropout, a biased linear layer and an activation.

    Written against the pre-1.0 TensorFlow API (`tf.pack`, axis-first
    `tf.concat`, `tf.Dimension`).

    Args:
        inputs: Tensor to transform. Its static rank and static last
            dimension are read through `get_shape()`; the last dimension
            is assumed to be statically known -- TODO confirm with callers.
        output_size: Size forwarded to `linalg.linear` and recorded as the
            static last dimension of every returned tensor.
        func: Activation applied to each linear output. Defaults to
            `self.mlp_func`. When its `__name__` is 'gated_tanh' or
            'gated_identity', twice as many splits are requested and paired
            up by concatenation on the last axis before `func` is applied
            (presumably `func` folds the pair back down to `output_size`
            -- verify against the activation's definition).
        keep_prob: Dropout keep probability. When omitted it falls back to
            `self.mlp_keep_prob`. Ignored (forced to 1, i.e. no dropout)
            whenever `self.moving_params` is not None.
        n_splits: Number of separate outputs to produce.

    Returns:
        A single tensor when `n_splits == 1`, otherwise a list with one
        tensor per split.
    """
    static_shape = inputs.get_shape().as_list()
    n_dims = len(static_shape)
    input_size = static_shape[-1]
    batch_size = tf.shape(inputs)[0]
    # Only the last dimension of the result is known statically.
    shape_to_set = [tf.Dimension(None)]*(n_dims-1) + [tf.Dimension(output_size)]

    if func is None:
        func = self.mlp_func
    gated = func.__name__ in ('gated_tanh', 'gated_identity')

    if self.moving_params is not None:
        # NOTE(review): moving_params being set looks like the evaluation
        # path (averaged parameters) -- confirm; dropout is disabled here.
        keep_prob = 1
    elif keep_prob is None:
        keep_prob = self.mlp_keep_prob
    if keep_prob < 1:
        # Size-1 entries for the middle axes make tf.nn.dropout broadcast
        # one mask across them instead of sampling per position.
        noise_shape = tf.pack([batch_size] + [1]*(n_dims-2) + [input_size])
        inputs = tf.nn.dropout(inputs, keep_prob, noise_shape=noise_shape)

    # Assumes linalg.linear returns a bare tensor when asked for one split
    # and a list of tensors otherwise -- TODO confirm against linalg.linear.
    linear = linalg.linear(inputs,
                           output_size,
                           n_splits=n_splits * (2 if gated else 1),
                           add_bias=True,
                           moving_params=self.moving_params)
    if gated:
        # Pair the i-th output of the first half with the i-th output of
        # the second half. The result is always a list, even for a single
        # split, so it must not be wrapped a second time below.
        half = len(linear) // 2
        linear = [tf.concat(n_dims-1, [lin1, lin2])
                  for lin1, lin2 in zip(linear[:half], linear[half:])]
    elif n_splits == 1:
        linear = [linear]

    outputs = []
    for split in linear:
        split = func(split)
        split.set_shape(shape_to_set)
        outputs.append(split)
    return outputs[0] if n_splits == 1 else outputs
#=============================================================
评论列表
文章目录