nn.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:demeter 作者: evancasey 项目源码 文件源码
def __init__(self,
             name,
             hidden_layers,
             input_dims,
             output_dims):
    """Build a feed-forward network mapping observations to logits.

    All variables are created under ``tf.variable_scope(name)``.

    Args:
        name: Variable-scope name for this network's variables.
        hidden_layers: Hidden-layer specification forwarded to
            ``self.hidden_layers_starting_at``. The empty string ``""``
            selects a purely linear model (no hidden layers).
        input_dims: Width of each observation row fed to ``self.obs``.
        output_dims: Number of output units of ``self.logits``.

    Attributes set:
        sess: Result of ``tf.get_default_session()`` at construction time
            (``None`` if no default session is active).
        obs: float32 placeholder of shape ``[None, input_dims]``.
        logits: Output of a final fully-connected layer with
            ``output_dims`` units and no activation function.
    """
    # Capture the session that is the default while the network is built;
    # this is None when constructed outside `with sess.as_default():`.
    self.sess = tf.get_default_session()

    with tf.variable_scope(name):
        self.obs = tf.placeholder(shape=[None, input_dims], dtype=tf.float32)
        flat_input_state = slim.flatten(self.obs, scope='flat')

        if hidden_layers == "":
            # Linear model with zero-initialised weights. In TF >= 1.0,
            # `tf.zeros_initializer` is a class and must be instantiated:
            # passing the class itself makes the variable creator call its
            # constructor as (shape, dtype=..., partition_info=...), which
            # raises TypeError.
            self.logits = slim.fully_connected(
                inputs=flat_input_state,
                num_outputs=output_dims,
                activation_fn=None,
                weights_initializer=tf.zeros_initializer())
        else:
            final_hidden = self.hidden_layers_starting_at(
                flat_input_state, hidden_layers)
            self.logits = slim.fully_connected(
                inputs=final_hidden,
                num_outputs=output_dims,
                activation_fn=None)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号