tsp_seqarrayiter.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def __init__(self, data_arrays, time_steps, batch_size, nfeatures,
                 total_iterations=None):
        """Prepare batched input/target arrays for sequence-to-sequence iteration.

        The arrays ``data_arrays['inp_txt']``, ``data_arrays['tgt_txt']`` and
        ``data_arrays['teacher_tgt']`` are truncated to the largest multiple of
        ``batch_size`` examples. They are then reshaped so that axis 0 indexes
        the batch row and axis 1 indexes the minibatch number. Because the
        reshape is in C order, batch row ``i`` holds the consecutive examples
        ``i * nbatches`` .. ``(i + 1) * nbatches - 1``.

        ``teacher_tgt`` is then shifted one step later along the time axis, and
        its first time step is zeroed. At step ``t`` the decoder input is
        therefore the target of step ``t - 1`` (teacher forcing), and at
        ``t == 0`` it is an all-zero start token.

        Args:
            data_arrays: dict with at least the keys ``'inp_txt'``,
                ``'tgt_txt'`` and ``'teacher_tgt'``.
                - Values are presumably numpy arrays; they must support
                  slicing and ``.reshape``.
                - Every example of ``inp_txt`` and ``teacher_tgt`` must hold
                  ``time_steps * nfeatures`` values.
                - Every example of ``tgt_txt`` must hold ``time_steps`` values.
                - The example count is taken from ``inp_txt``, so the other two
                  arrays need at least that many examples.
                - The dict is shallow-copied, and ``np.roll`` returns a new
                  array, so neither the caller's dict nor its arrays are
                  modified.
            time_steps: sequence length of each example.
            batch_size: number of batch rows; must be positive.
            nfeatures: number of features per time step of ``inp_txt`` and
                ``teacher_tgt``.
            total_iterations: stored as ``self.total_iterations``; defaults to
                the number of minibatches (``self.nbatches``).

        Raises:
            ValueError: if ``data_arrays`` is not a dict, if ``batch_size`` is
                not positive, or if there are fewer examples than
                ``batch_size``.
        """
        self.batch_size = batch_size
        self.time_steps = time_steps
        self.nfeatures = nfeatures
        self.index = 0

        # make sure input is in dict format; copy it so that storing the
        # reshaped arrays below does not replace entries in the caller's dict
        if not isinstance(data_arrays, dict):
            raise ValueError("Must provide dict as input")
        self.data_arrays = dict(data_arrays)

        # a non-positive batch size would otherwise surface as a
        # ZeroDivisionError (or nonsense shapes) in the arithmetic below
        if self.batch_size <= 0:
            raise ValueError('Batch size must be a positive integer')

        # number of examples available
        ndata = len(self.data_arrays['inp_txt'])
        if ndata < self.batch_size:
            raise ValueError('Number of examples is smaller than the batch size')

        # keep only full batches: leftover examples that cannot fill every
        # batch row are dropped
        self.nbatches = ndata // self.batch_size
        self.ndata = self.nbatches * self.batch_size

        if total_iterations is None:
            total_iterations = self.nbatches
        self.total_iterations = total_iterations

        # reshape to (batch_size, nbatches, time_steps[, nfeatures])
        feature_shape = (self.batch_size, self.nbatches,
                         self.time_steps, self.nfeatures)
        self.data_arrays['inp_txt'] = \
            self.data_arrays['inp_txt'][:self.ndata].reshape(feature_shape)
        self.data_arrays['tgt_txt'] = \
            self.data_arrays['tgt_txt'][:self.ndata].reshape(feature_shape[:3])
        teacher = self.data_arrays['teacher_tgt'][:self.ndata].reshape(feature_shape)

        # Teacher Forcing: shift targets one step later in time (axis 2) ...
        teacher = np.roll(teacher, shift=1, axis=2)
        # ... and put an all-zero start token as the first decoder input,
        # overwriting the value that wrapped around from the last time step
        teacher[:, :, 0, :] = 0
        self.data_arrays['teacher_tgt'] = teacher
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号