net.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:3D-R2N2 作者: chrischoy 项目源码 文件源码
def __init__(self, random_seed=None, compute_grad=True):
        """Initialise network state, symbolic inputs and bookkeeping lists.

        Args:
            random_seed: Seed passed to ``np.random.RandomState``. When
                ``None`` (the default), the microsecond field of the current
                wall-clock time is used, sampled at call time.
            compute_grad: Stored on the instance as ``self.compute_grad``.

        The constructor finishes by calling ``self.setup()``.
        """
        # A default of ``dt.datetime.now().microsecond`` in the signature is
        # evaluated only once, when the function is defined, so every
        # instance created without an explicit seed would share the same
        # seed. Resolve the time-based seed here, per call, instead.
        if random_seed is None:
            random_seed = dt.datetime.now().microsecond
        self.rng = np.random.RandomState(random_seed)

        # Sizes copied from the global configuration.
        self.batch_size = cfg.CONST.BATCH_SIZE
        self.img_w = cfg.CONST.IMG_W
        self.img_h = cfg.CONST.IMG_H
        self.n_vox = cfg.CONST.N_VOX
        self.compute_grad = compute_grad

        # (self.batch_size, 3, self.img_h, self.img_w),
        # override x and is_x_tensor4 when using multi-view network
        self.x = tensor.tensor4()
        self.is_x_tensor4 = True

        # (self.batch_size, self.n_vox, 2, self.n_vox, self.n_vox),
        self.y = tensor5()

        self.activations = []  # list of all intermediate activations
        self.loss = []  # final loss
        self.output = []  # final output
        self.error = []  # final output error
        self.params = []  # all learnable params
        self.grads = []  # will be filled out automatically

        # Must run last: every attribute above is assigned before setup().
        self.setup()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号