nettrainer.py source code

python

Project: deep-prior, author: moberweger

```python
def loadMacroBatch(self, macro_idx):
    """
    Make sure that the given macro batch is loaded into the shared variable
    :param macro_idx: macro batch index
    :return: None
    """
    # only copy data if a different macro batch is requested
    if macro_idx != self.currentMacroBatch:
        # last macro batch is handled separately, as it is padded
        if self.isLastMacroBatch(macro_idx):
            start_idx = 0
            end_idx = self.getNumSamplesPerMacroBatch()
            print("Loading last macro batch {}, start idx {}, end idx {}".format(macro_idx, start_idx, end_idx))
            self.replaceTrainingData(start_idx, end_idx, last=True)
        else:
            start_idx = macro_idx * self.getNumSamplesPerMacroBatch()
            # clip the end index so it never runs past the number of training samples
            end_idx = min((macro_idx + 1) * self.getNumSamplesPerMacroBatch(), self.train_data_xDB.shape[0])
            print("Loading macro batch {}, start idx {}, end idx {}".format(macro_idx, start_idx, end_idx))
            self.replaceTrainingData(start_idx, end_idx)
        # remember current macro batch index
        self.currentMacroBatch = macro_idx
```
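The caching and index logic of `loadMacroBatch` can be tried out without the rest of the trainer. Below is a minimal, hypothetical sketch: the class `MacroBatchLoader`, its methods `load` and `is_last`, and the zero padding are assumptions for illustration, not deep-prior's API. A plain NumPy array (`shared`) stands in for the shared variable the docstring mentions. The original routes the last macro batch through `replaceTrainingData(..., last=True)` with indices 0 to the macro-batch size, which suggests a separately prepared padded buffer; this sketch instead pads the final chunk with zeros on the fly.

```python
import numpy as np


class MacroBatchLoader:
    """Hypothetical stand-in for the trainer's macro-batch handling.

    `shared` is a plain NumPy array playing the role of the shared variable.
    """

    def __init__(self, data, samples_per_macro_batch):
        self.data = data
        self.n = samples_per_macro_batch
        # ceiling division: a partial final chunk still counts as a macro batch
        self.num_macro_batches = -(-data.shape[0] // self.n)
        self.current = None  # index of the macro batch currently in `shared`
        self.shared = None
        self.loads = 0       # number of actual copies, to show the caching effect

    def is_last(self, idx):
        return idx == self.num_macro_batches - 1

    def load(self, idx):
        if idx == self.current:
            return  # already resident, nothing to copy
        start = idx * self.n
        # clip so the slice never runs past the end of the data
        end = min(start + self.n, self.data.shape[0])
        chunk = self.data[start:end]
        if self.is_last(idx) and chunk.shape[0] < self.n:
            # assumption: pad with zero rows so every macro batch has n samples
            pad = np.zeros((self.n - chunk.shape[0],) + chunk.shape[1:], dtype=chunk.dtype)
            chunk = np.concatenate([chunk, pad])
        self.shared = chunk.copy()
        self.current = idx
        self.loads += 1


# usage: 10 samples, 4 per macro batch -> 3 macro batches, the last one padded
data = np.arange(10, dtype=np.float32).reshape(10, 1)
loader = MacroBatchLoader(data, samples_per_macro_batch=4)
loader.load(0)  # samples 0..3
loader.load(0)  # same index: cached, no copy
loader.load(2)  # samples 8..9 plus two zero rows
print(loader.shared.ravel(), loader.loads)  # → [8. 9. 0. 0.] 2
```

Skipping the copy when the requested index is already resident keeps repeated calls within one macro batch cheap. Padding the final chunk is one plausible way to keep every macro batch the same shape as the buffer it replaces.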