rainbow_network.py source code

python

Project: SRLF  Author: Fritz449
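import tensorflow as tf  # TF1 graph API (tf.placeholder, tf.get_variable); in TF2 these live under tf.compat.v1


def create_networks(self):
    state_dim = int(self.state_input.get_shape()[1])

    # Per-feature observation statistics. They are non-trainable, so the optimizer
    # never updates them, but they are graph variables and are checkpointed with the weights.
    self.mean = tf.get_variable("means", shape=(1, state_dim),
                                initializer=tf.constant_initializer(0),
                                trainable=False)
    self.std = tf.get_variable("stds", shape=(1, state_dim),
                               initializer=tf.constant_initializer(1),
                               trainable=False)

    # Placeholders plus assign ops let the training loop push externally
    # computed mean/std values into the graph via sess.run(self.norm_set_op, ...).
    mean_ph = tf.placeholder(tf.float32, shape=self.mean.get_shape())
    std_ph = tf.placeholder(tf.float32, shape=self.std.get_shape())
    self.norm_set_op = [self.mean.assign(mean_ph), self.std.assign(std_ph)]
    self.norm_phs = [mean_ph, std_ph]

    # Standardize current and next states with the same statistics; the 1e-5 term
    # avoids division by zero and clipping bounds outliers to [-50, 50].
    self.good_input = tf.clip_by_value((self.state_input - self.mean) / (self.std + 1e-5), -50, 50)
    self.good_next_input = tf.clip_by_value((self.next_state_input - self.mean) / (self.std + 1e-5), -50, 50)

    # Online network on current states, target network on next states
    # (create_network is defined elsewhere in the class).
    self.atom_probs, self.weights, self.weights_phs = self.create_network("network", self.good_input)
    self.target_atom_probs, self.target_weights, self.target_weights_phs = self.create_network(
        "target", self.good_next_input)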
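The normalization above can be mirrored in NumPy to check what the graph computes. This is a minimal sketch, assuming the statistics are a plain per-feature mean and standard deviation over a batch of collected states; this file does not show how SRLF actually computes them, and `compute_norm_stats` and `normalize` are hypothetical helper names, not part of the project. It reuses the graph's `1e-5` epsilon and `[-50, 50]` clipping bound.

```python
import numpy as np

EPS = 1e-5   # same epsilon as (self.std + 1e-5) in the graph
CLIP = 50.0  # same bound as tf.clip_by_value(..., -50, 50)


def compute_norm_stats(states):
    """Hypothetical helper: per-feature mean/std of a batch of states.

    keepdims=True gives shape (1, state_dim), matching the "means" and
    "stds" variables, so the results can be fed to norm_phs directly.
    """
    states = np.asarray(states, dtype=np.float32)
    mean = states.mean(axis=0, keepdims=True)
    std = states.std(axis=0, keepdims=True)
    return mean, std


def normalize(states, mean, std):
    """Hypothetical NumPy mirror of good_input: standardize, then clip."""
    return np.clip((np.asarray(states) - mean) / (std + EPS), -CLIP, CLIP)


if __name__ == "__main__":
    # Feature 0 varies; feature 1 is constant (std == 0).
    batch = np.array([[0.0, 10.0], [2.0, 10.0], [4.0, 10.0]], dtype=np.float32)
    mean, std = compute_norm_stats(batch)
    print(normalize(batch, mean, std))
    # A far-out-of-range state is clipped to the +/-50 bound.
    print(normalize([[1000.0, 10.0]], mean, std))
```

In the TF1 graph, statistics like these would presumably be loaded with `sess.run(self.norm_set_op, feed_dict={self.norm_phs[0]: mean, self.norm_phs[1]: std})`. Note that a constant feature has zero standard deviation and, because of the epsilon, normalizes to 0 instead of producing NaN.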