def forward_one_step(self, X):
    """Run the stacked layers on the tail of ``X`` and return last-step scores.

    Only the final ``self.kernel_size`` positions of ``X`` are embedded and
    pushed through the layers with ``self._forward_layer_one_step``. Each
    layer's output is trimmed back to its last ``kernel_size`` positions before
    it is fed to the next layer. When ``X`` is shorter than ``kernel_size``,
    ``self.reset_state()`` is called and the result of the full
    ``self.__call__(X, return_last=True)`` path is returned instead.

    Every ``self.num_layers_per_block`` layers (a block boundary), dropout is
    optionally applied (``self.using_dropout`` / ``self.dropout``) and the
    block's input is added as a residual. For the first block, the embedding
    is the residual only when ``self.ndim_h == self.ndim_embedding``;
    otherwise ``0`` is added.

    Args:
        X: Batch of token ids, sliced as ``X[:, -kernel_size:]``, so axis 1 is
            the sequence axis. Presumably integer ids accepted by
            ``self.embed`` -- TODO confirm against callers.

    Returns:
        The ``self.dense`` output for the last time step only, reshaped to
        ``(-1, self.vocab_size)``. This is presumably one row per batch
        element, assuming ``self.dense`` keeps the length-1 time axis --
        confirm.
    """
    kernel = self.kernel_size
    if X.shape[1] < kernel:
        # Too little history to fill one receptive window: start from a fresh
        # state and defer to the full forward pass.
        self.reset_state()
        return self.__call__(X, return_last=True)

    # Only the last `kernel` tokens are needed for this incremental step.
    embedding = F.swapaxes(self.embed(X[:, -kernel:]), 1, 2)
    # The embedding can act as the first skip connection only when its
    # channel count equals the hidden size; otherwise no residual (0).
    residual_input = embedding if self.ndim_h == self.ndim_embedding else 0

    out_data = self._forward_layer_one_step(0, embedding)[:, :, -kernel:]
    num_layers = self.num_blocks * self.num_layers_per_block
    # `range` rather than the Python 2-only `xrange`, so this runs on
    # Python 2 and 3 alike without a compatibility shim.
    # NOTE(review): layer 0 is handled above and never hits the block-boundary
    # check, so with num_layers_per_block == 1 its output gets no
    # dropout/residual. Kept as-is because it must match `__call__`.
    for layer_index in range(1, num_layers):
        out_data = self._forward_layer_one_step(
            layer_index, out_data)[:, :, -kernel:]
        if (layer_index + 1) % self.num_layers_per_block == 0:
            # End of a block: optional dropout, then the skip connection.
            # The result becomes the residual for the next block.
            if self.using_dropout:
                out_data = F.dropout(out_data, ratio=self.dropout)
            out_data += residual_input
            residual_input = out_data

    # Keep only the final time step. `None` keeps that axis with length 1 so
    # `self.dense` still sees a 3-D input.
    out_data = out_data[..., -1, None]
    out_data = self.dense(out_data)
    return F.reshape(F.swapaxes(out_data, 1, 2), (-1, self.vocab_size))
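# (scraped web-page residue, not code) 评论列表 -- "Comment list"
# (scraped web-page residue, not code) 文章目录 -- "Table of contents"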