base_controller.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:EAS 作者: han-cai 项目源码 文件源码
def build_forward(self, _input):
        """Build the actor head: per-step probabilities and a sampled decision.

        ``_input`` is treated as a rank-3 tensor whose last axis is the
        feature size (``[batch_size, num_steps, rnn_units]``). It is
        flattened to two dimensions, passed through the fully connected
        layers listed in ``self.net_config`` (each followed by its
        activation, ``'relu'`` when the entry names none), and projected to
        ``self.out_dim`` logits inside the ``'wider_actor'`` variable scope.

        When ``self.out_dim == 1`` the single sigmoid output ``p`` is
        expanded to the two-column distribution ``[1 - p, p]``; otherwise
        softmax is applied to the logits.

        Sets:
            self.decision: sample drawn with ``tf.multinomial`` from the log
                of the probabilities, reshaped to ``[-1, self.num_steps]``.
            self.probs: probabilities reshaped to
                ``[-1, self.num_steps, probs_dim]`` where ``probs_dim`` is 2
                if ``self.out_dim == 1`` and ``self.out_dim`` otherwise.

        Raises:
            ValueError: if ``self.net_type`` is anything but ``'simple'``.
        """
        rnn_units = int(_input.get_shape()[2])
        # Collapse batch and time so every step goes through the same layers.
        flat = tf.reshape(_input, [-1, rnn_units])  # [batch_size * num_steps, rnn_units]

        if self.out_dim == 1:
            final_activation = 'sigmoid'
        else:
            final_activation = 'softmax'

        if self.net_type != 'simple':
            raise ValueError('Do not support %s' % self.net_type)

        layer_specs = self.net_config
        if layer_specs is None:
            layer_specs = []

        with tf.variable_scope('wider_actor'):
            hidden = flat
            for spec in layer_specs:
                units = spec.get('units')
                activation = spec.get('activation', 'relu')
                hidden = BasicModel.fc_layer(hidden, units, use_bias=True)
                hidden = BasicModel.activation(hidden, activation)
            # [batch_size * num_steps, out_dim]
            logits = BasicModel.fc_layer(hidden, self.out_dim, use_bias=True)

        # The final activation is applied outside the variable scope.
        probs = BasicModel.activation(logits, final_activation)
        if self.out_dim == 1:
            # Turn the lone sigmoid output into an explicit two-way distribution.
            probs = tf.concat([1 - probs, probs], axis=1)
            probs_dim = 2
        else:
            probs_dim = self.out_dim

        # tf.multinomial takes log-probabilities, hence the tf.log.
        self.decision = tf.multinomial(tf.log(probs), 1)  # [batch_size * num_steps, 1]
        self.decision = tf.reshape(self.decision, [-1, self.num_steps])  # [batch_size, num_steps]
        self.probs = tf.reshape(probs, [-1, self.num_steps, probs_dim])  # [batch_size, num_steps, probs_dim]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号