def __call__(self, inputs, center_state, module_state):
    """Build one step of this module's graph.

    The step reads a context vector out of ``center_state`` through a learned
    ``reading_weights`` matrix whose norm is clipped to 1.0. It then optionally
    joins that context with ``inputs`` and projects the result with a fully
    connected layer to ``self.center_output_size`` units. Finally it advances a
    GRU cell with ``self.num_gru_units`` units.

    :param inputs: external input tensor, concatenated with the context read
        along axis 1. It is ignored entirely when ``self.input_size`` is falsy.
    :param center_state: tensor right-multiplied by ``reading_weights``
        (shape ``[self.center_size, self.context_input_size]``). Presumably
        shaped (batch, center_size) -- TODO confirm against the caller.
    :param module_state: previous state fed to the GRU cell.
    :return: output, new_center_features, new_module_state

        - ``output``: the first ``self.output_size`` columns of the GRU
          output, or ``None`` when ``self.output_size`` is falsy.
        - ``new_center_features``: the next ``self.center_output_size``
          columns, or the whole GRU output when ``self.output_size`` is falsy.
        - ``new_module_state``: the GRU cell's new state.
    """
    with tf.variable_scope(self.name):
        # tf.get_variable raises if 'reading_weights' already exists in this
        # scope and reuse is not enabled. Presumably the caller builds this
        # step once per graph or opens a reusing outer scope -- confirm.
        read_matrix = tf.get_variable(
            'reading_weights',
            shape=[self.center_size, self.context_input_size],
            initializer=tf.truncated_normal_initializer(stddev=0.1),
        )
        # Clip the whole matrix (not per row/column) to norm <= 1.0 before use.
        bounded_read_matrix = tf.clip_by_norm(read_matrix, 1.0)
        context = tf.matmul(center_state, bounded_read_matrix)

        if self.input_size:
            combined = tf.concat([inputs, context], axis=1)
        else:
            # The module has no external input; only the context read is used.
            combined = context

        projected = tf.contrib.layers.fully_connected(
            combined, num_outputs=self.center_output_size)

        # A new GRUCell object is constructed on every call.
        cell = tf.nn.rnn_cell.GRUCell(self.num_gru_units)
        cell_output, next_state = cell(inputs=projected, state=module_state)

        if self.output_size:
            # tf.split needs output_size + center_output_size to equal the GRU
            # output width (num_gru_units). This is not checked here.
            output, center_features = tf.split(
                cell_output,
                [self.output_size, self.center_output_size],
                axis=1,
            )
        else:
            output, center_features = None, cell_output

    return output, center_features, next_state
评论列表
文章目录