nem_model.py source code

python

Project: Neural-EM    Author: sjoerdvansteenkiste
import tensorflow as tf  # TensorFlow 1.x API (tf.random_normal, keep_dims)


def init_state(self, batch_size, K, dtype):
    # inner RNN hidden state init: one state per (batch element, component) pair
    with tf.name_scope('inner_RNN_init'):
        h = self.cell.zero_state(batch_size * K, dtype)

    # initial prediction (B, K, W, H, C), filled with the constant self.pred_init
    with tf.name_scope('pred_init'):
        pred_shape = tf.stack([batch_size, K] + self.input_shape.as_list())
        pred = tf.ones(shape=pred_shape, dtype=dtype) * self.pred_init

    # initial gamma (B, K, W, H, 1)
    with tf.name_scope('gamma_init'):
        gamma_shape = self.gamma_shape.as_list()
        shape = tf.stack([batch_size, K] + gamma_shape)

        # init with the absolute value of Gaussian samples, then normalize
        # over the K axis so each pixel's responsibilities sum to 1
        gamma = tf.abs(tf.random_normal(shape, dtype=dtype))
        gamma /= tf.reduce_sum(gamma, 1, keep_dims=True)

        # init with all 1 if K = 1 (K must be a Python int for this check)
        if K == 1:
            gamma = tf.ones_like(gamma)

    return h, pred, gamma
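The method returns three things: the inner RNN state with one row per (batch element, component) pair, a constant initial prediction, and random responsibilities gamma that sum to 1 over the K components at every pixel. A minimal NumPy sketch of the same shapes and normalization, handy for checking the logic without a TensorFlow session, is below. The function name `init_state_np` and the parameters `hidden_size`, `seed`, and the default `pred_init=0.0` are hypothetical; it also assumes the RNN state is a single zero matrix, whereas the real `self.cell.zero_state` may return a tuple (for example, for LSTM cells).

```python
import numpy as np


def init_state_np(batch_size, K, input_shape, hidden_size, pred_init=0.0, seed=0):
    """Hypothetical NumPy sketch of the three initial N-EM state arrays.

    Assumptions (not from the original code): the inner RNN state is one zero
    matrix of width `hidden_size`, and gamma has a trailing channel of size 1.
    """
    rng = np.random.default_rng(seed)
    W, H, C = input_shape

    # Inner RNN state: one row per (batch element, component) pair -> (B*K, hidden).
    h = np.zeros((batch_size * K, hidden_size))

    # Initial prediction: constant pred_init everywhere -> (B, K, W, H, C).
    pred = np.full((batch_size, K, W, H, C), pred_init)

    # Initial gamma: |N(0, 1)| samples normalized over the K axis (axis 1),
    # so each pixel's responsibilities sum to 1 -> (B, K, W, H, 1).
    gamma = np.abs(rng.standard_normal((batch_size, K, W, H, 1)))
    gamma /= gamma.sum(axis=1, keepdims=True)

    # With a single component every responsibility is exactly 1.
    if K == 1:
        gamma = np.ones_like(gamma)

    return h, pred, gamma


h, pred, gamma = init_state_np(batch_size=2, K=3, input_shape=(4, 4, 1), hidden_size=8)
print(h.shape, pred.shape, gamma.shape)  # (6, 8) (2, 3, 4, 4, 1) (2, 3, 4, 4, 1)
print(np.allclose(gamma.sum(axis=1), 1.0))  # True
```

Normalizing over axis 1 rather than over the pixel axes is what makes gamma a per-pixel distribution over components, which is what the E-step of N-EM expects.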