def weighted_average(self, inputs, moving_params=None):
    """Project weight vectors through the embedding matrix.

    Each row of ``inputs`` (after flattening to ``[-1, len(self)]``) is
    matrix-multiplied with the embedding matrix. Every output vector is
    therefore a linear combination of the embedding rows, with coefficients
    taken from ``inputs``. If those coefficients sum to 1 (e.g. a softmax
    distribution; assumed, not enforced here), the result is a weighted
    average of the embeddings.

    Args:
      inputs: Tensor whose first two dynamic dimensions are read as batch
        size and bucket size, and whose elements are regrouped into rows of
        length ``len(self)``. Presumably shaped
        ``(batch_size, bucket_size, len(self))``; TODO confirm against callers.
      moving_params: Optional object exposing ``average(var)``, presumably a
        ``tf.train.ExponentialMovingAverage``. When given, the averaged
        embedding matrix is used instead of ``self.trainable_embeddings``,
        and the output is not added to the 'Weights' collection.

    Returns:
      Tensor reshaped to ``(batch_size, bucket_size, self.embed_size)``, with
      the last dimension also set statically to ``self.embed_size``.
    """
    input_shape = tf.shape(inputs)
    batch_size = input_shape[0]
    bucket_size = input_shape[1]
    input_size = len(self)

    if moving_params is None:
        embeddings = self.trainable_embeddings
    else:
        embeddings = moving_params.average(self.trainable_embeddings)

    # Flatten to 2-D so one matmul covers every (batch, bucket) position.
    # The embedding matrix is assumed to be (len(self), self.embed_size).
    flat_inputs = tf.reshape(inputs, [-1, input_size])
    embed_input = tf.matmul(flat_inputs, embeddings)

    # `tf.pack` was renamed to `tf.stack` and the old name no longer exists
    # as of TF 1.0. Prefer the new name; fall back for older releases.
    stack = tf.stack if hasattr(tf, 'stack') else tf.pack
    embed_input = tf.reshape(
        embed_input, stack([batch_size, bucket_size, self.embed_size]))
    # The reshape target is dynamic, so static shape info is lost; restore the
    # known last dimension. A plain list avoids `tf.Dimension`, which is not
    # in the TF2 top-level namespace.
    embed_input.set_shape([None, None, self.embed_size])

    # NOTE(review): during training the output activation (not a variable) is
    # registered under 'Weights'. It is presumably consumed elsewhere, e.g. for
    # a penalty term; confirm the intent with that collection's readers.
    if moving_params is None:
        tf.add_to_collection('Weights', embed_input)
    return embed_input
#=============================================================
# Web-page labels captured along with this snippet (not code):
# "Comment list" / "Table of contents"