def double_MLP(self, inputs, n_splits=1):
    """Build a pairwise (token x token) MLP feature tensor from `inputs`.

    Two linear projections of `inputs` are broadcast against each other so
    that every pair of positions (i, j) in a sequence gets a combined
    representation, which is then passed through `self.mlp_func`.

    NOTE(review): written against pre-1.0 TensorFlow — `tf.pack` is the old
    name for `tf.stack`, and `tf.split(dim, num, value)` uses the old
    argument order.

    Args:
      inputs: rank-3 tensor; indexing below implies shape
        (batch_size, bucket_size, input_size) — the last dim must be static
        (read via `get_shape()`).
      n_splits: number of independent output tensors to produce from one
        shared linear map.

    Returns:
      A single tensor of shape (batch, bucket, bucket, mlp_size) when
      `n_splits == 1`, otherwise a list of `n_splits` such tensors.
    """
    batch_size = tf.shape(inputs)[0]
    bucket_size = tf.shape(inputs)[1]
    input_size = inputs.get_shape().as_list()[-1]
    output_size = self.mlp_size
    # NOTE(review): output_shape is computed but never used below.
    output_shape = tf.pack([batch_size, bucket_size, bucket_size, output_size])
    # Static shape hint applied to each split after the dynamic reshapes.
    shape_to_set = [tf.Dimension(None), tf.Dimension(None), tf.Dimension(None), tf.Dimension(output_size)]

    # Dropout only at train time (moving_params is None); at eval/inference
    # keep_prob is pinned to 1.
    if self.moving_params is None:
        if self.drop_gradually:
            # Anneal dropout: keep_prob = s + (1-s)*mlp_keep_prob, i.e. it
            # moves from mlp_keep_prob toward 1 as s goes 0 -> 1.
            # Presumably global_sigmoid increases over training — confirm.
            s = self.global_sigmoid
            keep_prob = s + (1-s)*self.mlp_keep_prob
        else:
            keep_prob = self.mlp_keep_prob
    else:
        keep_prob = 1
    if isinstance(keep_prob, tf.Tensor) or keep_prob < 1:
        # noise_shape (batch, 1, input_size): one mask per example, shared
        # across all sequence positions.
        noise_shape = tf.pack([batch_size, 1, input_size])
        inputs = tf.nn.dropout(inputs, keep_prob, noise_shape=noise_shape)

    # One shared linear layer producing two projections of size
    # output_size*n_splits each (project helper — semantics assumed from
    # the n_splits/add_bias kwargs; defined elsewhere in this repo).
    lin1, lin2 = linalg.linear(inputs,
                               output_size*n_splits,
                               n_splits=2,
                               add_bias=True,
                               moving_params=self.moving_params)
    # Broadcast trick: lin1 becomes (batch*feat, bucket, 1) and lin2
    # (batch*feat, 1, bucket); their sum enumerates all (i, j) pairs.
    lin1 = tf.reshape(tf.transpose(lin1, [0, 2, 1]), tf.pack([-1, bucket_size, 1]))
    lin2 = tf.reshape(tf.transpose(lin2, [0, 2, 1]), tf.pack([-1, 1, bucket_size]))
    lin = lin1 + lin2
    # Restore batch/feature axes, then move features last:
    # (batch, bucket, bucket, n_splits*output_size).
    lin = tf.reshape(lin, tf.pack([batch_size, n_splits*output_size, bucket_size, bucket_size]))
    lin = tf.transpose(lin, [0,2,3,1])
    # Nonlinearity, then split the feature axis into n_splits tensors
    # (old tf.split signature: dim first).
    top_mlps = tf.split(3, n_splits, self.mlp_func(lin))
    for top_mlp in top_mlps:
        top_mlp.set_shape(shape_to_set)

    # Train-time regularizers, accumulated into graph collections.
    if self.moving_params is None:
        with tf.variable_scope('Linear', reuse=True):
            matrix = tf.get_variable('Weights')
            I = tf.diag(tf.ones([self.mlp_size]))
            # Orthogonality penalty ||W^T W - I||^2 on each column block of
            # the shared weight matrix (2*n_splits blocks of mlp_size cols).
            for W in tf.split(1, 2*n_splits, matrix):
                WTWmI = tf.matmul(W, W, transpose_a=True) - I
                tf.add_to_collection('ortho_losses', tf.nn.l2_loss(WTWmI))
        # Per-split covariance penalty; covar_loss defined elsewhere in the
        # class — exact semantics not visible here.
        for split in top_mlps:
            tf.add_to_collection('covar_losses', self.covar_loss(split))

    if n_splits == 1:
        return top_mlps[0]
    else:
        return top_mlps
#=============================================================