def init_hx(self, xs):
    """Create the zero-filled initial ``hx`` state for ``xs``.

    The array has shape
    ``(self.n_layers * self.direction, xs.data.shape[1], self.out_size)``
    and the dtype of ``xs``. It is allocated through ``self.xp`` while the
    device selected by ``self._device_id`` is the current one.

    Args:
        xs: Input object exposing ``data.shape`` and ``dtype``.

    Returns:
        A ``chainer.Variable`` wrapping the zero array.
    """
    n_states = self.n_layers * self.direction
    with cuda.get_device_from_id(self._device_id):
        # The ``volatile`` keyword is passed for every version except 2.
        # NOTE(review): presumably element 0 is the Chainer major version;
        # confirm whether releases after 2 should also skip ``volatile``.
        variable_kwargs = {}
        if args.chainer_version_check[0] != 2:
            variable_kwargs['volatile'] = 'auto'
        zeros = self.xp.zeros(
            (n_states, xs.data.shape[1], self.out_size),
            dtype=xs.dtype)
        initial_state = chainer.Variable(zeros, **variable_kwargs)
    return initial_state
评论列表
文章目录