model.py source code (Python)

Project: thalnet (author: JuliusKunze)
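import tensorflow as tf  # TensorFlow 1.x API (tf.contrib was removed in TF 2.x)


# Excerpt: a method of a module class whose constructor sets the self.* sizes used below.
def __call__(self, inputs, center_state, module_state):
    """
    :return: output, new_center_features, new_module_state
    """
    with tf.variable_scope(self.name):
        # Weights used to read a context vector from the shared center state.
        reading_weights = tf.get_variable(
            'reading_weights',
            shape=[self.center_size, self.context_input_size],
            initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Read from the center; the weights' global L2 norm is clipped to 1.0.
        context_input = tf.matmul(center_state, tf.clip_by_norm(reading_weights, 1.0))

        # Append the context to the external input if this module has one.
        inputs = tf.concat([inputs, context_input], axis=1) if self.input_size else context_input

        inputs = tf.contrib.layers.fully_connected(inputs, num_outputs=self.center_output_size)

        gru = tf.nn.rnn_cell.GRUCell(self.num_gru_units)
        gru_output, new_module_state = gru(inputs=inputs, state=module_state)

        # Split the GRU output into the module's output and the features sent back
        # to the center; the two split sizes must sum to num_gru_units.
        if self.output_size:
            output, center_feature_output = tf.split(
                gru_output, [self.output_size, self.center_output_size], axis=1)
        else:
            output, center_feature_output = None, gru_output

    return output, center_feature_output, new_module_state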
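
The read step in the method multiplies the center state by a reading matrix whose global L2 norm has been clipped to 1.0: `tf.clip_by_norm` with no `axes` argument leaves the tensor alone if its norm is at most `clip_norm` and otherwise rescales it by `clip_norm / l2norm(t)`. The sketch below reproduces that step and the final `tf.split` on plain Python lists so it runs without TensorFlow. The helper names (`clip_by_norm`, `matmul`, `read_context`, `split_gru_output`) are hypothetical illustrations, not part of thalnet, and the split helper assumes, as `tf.split` requires, that the GRU width equals `output_size + center_output_size`.

```python
import math


def clip_by_norm(matrix, clip_norm):
    """Hypothetical helper: scale `matrix` so its global L2 norm is at most `clip_norm`.

    Mirrors tf.clip_by_norm(t, clip_norm) with axes=None: if the norm already
    fits, the values are returned unchanged (as a copy).
    """
    norm = math.sqrt(sum(x * x for row in matrix for x in row))
    if norm <= clip_norm:
        return [list(row) for row in matrix]
    return [[x * clip_norm / norm for x in row] for row in matrix]


def matmul(a, b):
    """Hypothetical helper: product of an [n, k] and a [k, m] list-of-lists matrix."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def read_context(center_state, reading_weights):
    """Hypothetical helper: center_state @ clip_by_norm(reading_weights, 1.0)."""
    return matmul(center_state, clip_by_norm(reading_weights, 1.0))


def split_gru_output(gru_output, output_size, center_output_size):
    """Hypothetical helper: split rows into (module output, center features) along axis 1.

    A module without its own output (falsy output_size) sends everything to the center.
    """
    if not output_size:
        return None, gru_output
    width = len(gru_output[0])
    if width != output_size + center_output_size:
        # tf.split likewise rejects split sizes that do not sum to the axis length.
        raise ValueError(f"GRU width {width} != {output_size} + {center_output_size}")
    output = [row[:output_size] for row in gru_output]
    center_features = [row[output_size:] for row in gru_output]
    return output, center_features


if __name__ == "__main__":
    center_state = [[1.0, 2.0]]                    # batch of 1, center_size = 2
    reading_weights = [[3.0, 0.0], [0.0, 4.0]]     # global norm 5.0, so scaled by 1/5
    print(read_context(center_state, reading_weights))  # [[0.6, 1.6]]

    gru_output = [[0.1, 0.2, 0.3, 0.4, 0.5]]       # num_gru_units = 2 + 3
    print(split_gru_output(gru_output, 2, 3))      # ([[0.1, 0.2]], [[0.3, 0.4, 0.5]])
```

Without the clipping, the same center state would read `[[3.0, 8.0]]`; capping the weight norm at 1.0 means the read can never amplify the center state's norm, whatever the reading weights learn.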