model.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:human-rl 作者: gsastry 项目源码 文件源码
def __init__(self, ob_space, ac_space, size=256, **kwargs):
        """Build the convolutional + GRU policy/value graph.

        The observation placeholder is passed through four ELU-activated
        ``conv2d`` layers, flattened, run through a single GRU cell over the
        leading (time) dimension, and finally projected to action logits and
        a scalar value estimate.

        Args:
            ob_space: Per-observation shape; it is converted with ``list()``
                and appended to ``[None]`` to form the input placeholder
                shape. (Presumably ``(height, width, channels)`` -- confirm
                against the ``conv2d`` helper.)
            ac_space: Output width of the "action" linear layer, also passed
                to ``categorical_sample``. (Presumably the number of discrete
                actions -- confirm against the caller.)
            size: Number of units in the GRU cell, and therefore the width of
                the recurrent state.
            **kwargs: Accepted and ignored.

        Attributes set:
            x: Float32 observation placeholder, shape ``[None] + ob_space``.
            state_init: ``[h_init]``, a zero numpy array of shape
                ``(1, size)`` to use as the initial recurrent state.
            state_in: ``[h_in]``, the placeholder fed as the GRU initial state.
            state_out: One-element list holding the final GRU state.
            logits: Output of the "action" linear layer.
            vf: Output of the "value" linear layer, reshaped to rank 1.
            sample: First row of ``categorical_sample(self.logits, ac_space)``.
            var_list: Trainable variables collected under the current
                variable scope name.
        """
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

        # Four identical conv layers named "l1".."l4".
        # NOTE(review): presumably 32 filters, 3x3 kernel, stride 2 -- the
        # argument meaning is defined by the conv2d helper, confirm there.
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # introduce a "fake" batch dimension of 1 after flatten so that we can do GRU over time dim
        x = tf.expand_dims(flatten(x), 1)

        gru = rnn.GRUCell(size)

        # With time_major=True the leading dimension of the input is time, so
        # the number of valid steps for the single batch element is the
        # dynamic first dimension of the observation placeholder. The
        # previous value, [size], was the GRU hidden width: rollouts longer
        # than `size` steps had their outputs zeroed and state frozen.
        step_size = tf.shape(self.x)[:1]

        h_init = np.zeros((1, size), np.float32)
        self.state_init = [h_init]
        h_in = tf.placeholder(tf.float32, [1, size])
        self.state_in = [h_in]

        gru_outputs, gru_state = tf.nn.dynamic_rnn(
            gru, x, initial_state=h_in, sequence_length=step_size, time_major=True)
        # Drop the fake batch dimension again: one row per time step.
        x = tf.reshape(gru_outputs, [-1, size])
        self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
        self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
        self.state_out = [gru_state[:1]]
        self.sample = categorical_sample(self.logits, ac_space)[0, :]
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号