utils.py source file

python

Project: chinese-char-rnn  Author: indiejoseph
def create_batches(self):
    # Number of full (batch_size x seq_length) batches; integer division drops the remainder
    self.num_batches = self.train.size // (self.batch_size * self.seq_length)
    self.num_valid_batches = self.valid.size // (self.batch_size * self.seq_length)

    # When the data (tensor) is too small, fail early with a clear error message
    if self.num_batches == 0:
        raise ValueError("Not enough data. Make seq_length and batch_size smaller.")

    # Trim the tails so both arrays divide evenly into full batches
    self.train = self.train[:self.num_batches * self.batch_size * self.seq_length]
    self.valid = self.valid[:self.num_valid_batches * self.batch_size * self.seq_length]

    # Targets are the inputs shifted left by one position;
    # the last target wraps around to the first input
    xdata = self.train
    ydata = np.copy(self.train)
    ydata[:-1] = xdata[1:]
    ydata[-1] = xdata[0]
    x_valid = self.valid
    y_valid = np.copy(self.valid)
    y_valid[:-1] = x_valid[1:]
    y_valid[-1] = x_valid[0]

    # Reshape to one row per batch element, then split along axis 1
    # into chunks of shape (batch_size, seq_length)
    self.x_valid = np.split(x_valid.reshape(self.batch_size, -1), self.num_valid_batches, 1)
    self.y_valid = np.split(y_valid.reshape(self.batch_size, -1), self.num_valid_batches, 1)
    self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
    self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)
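
To make the reshape-and-split step easier to follow, here is a minimal sketch of the same batching logic on a toy array. The standalone function `make_batches` and the toy input are assumptions for illustration only; they are not part of the original project, which keeps this logic inside a method operating on `self.train` and `self.valid`.

```python
import numpy as np


def make_batches(data, batch_size, seq_length):
    """Hypothetical standalone version of the batching logic (illustration only)."""
    num_batches = data.size // (batch_size * seq_length)
    if num_batches == 0:
        raise ValueError("Not enough data. Make seq_length and batch_size smaller.")

    # Trim the tail so the data divides evenly into full batches.
    xdata = data[:num_batches * batch_size * seq_length]

    # Targets are the inputs shifted left by one; the last target wraps around.
    ydata = np.copy(xdata)
    ydata[:-1] = xdata[1:]
    ydata[-1] = xdata[0]

    # One row per batch element, then cut the columns into seq_length-wide chunks.
    x_batches = np.split(xdata.reshape(batch_size, -1), num_batches, 1)
    y_batches = np.split(ydata.reshape(batch_size, -1), num_batches, 1)
    return x_batches, y_batches


data = np.arange(13)  # toy stand-in for encoded text: 0..12
x_batches, y_batches = make_batches(data, batch_size=2, seq_length=3)

print(x_batches[0])  # first input batch, shape (2, 3)
print(y_batches[0])  # matching targets, each one step ahead of its input
```

With 13 elements, `batch_size=2`, and `seq_length=3`, the last element is trimmed and two batches of shape (2, 3) are produced. Each row of a batch continues where the same row of the previous batch left off, which is the layout a stateful RNN would rely on when carrying hidden state from one batch to the next.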