def init_hidden_for(self, inp, jitter_std=0.3):
    """Build the initial recurrent state for the batch contained in `inp`.

    The batch size is read from ``inp.size(1)``; every state returned has
    shape ``(self.num_dirs * self.num_layers, batch_size, self.hid_dim)``.

    Parameters
    ----------
    inp :
        Tensor (or legacy ``Variable``) whose dimension 1 is the batch
        dimension. Freshly allocated states are created through
        ``inp.data.new(...)``.
    jitter_std : float, optional
        Standard deviation of the zero-mean Gaussian noise added to ``h_0``
        when ``self.add_init_jitter`` is set. Defaults to 0.3, the value
        that used to be hard-coded.

    Returns
    -------
    ``(h_0, c_0)`` when ``self.cell`` starts with ``'LSTM'``, otherwise
    ``h_0`` alone.
    """
    batch_size = inp.size(1)
    size = (self.num_dirs * self.num_layers, batch_size, self.hid_dim)

    def zero_state():
        # `volatile` is the legacy autograd flag; it is set whenever the
        # module is not in training mode.
        # NOTE(review): recent PyTorch releases reportedly ignore it and
        # emit a warning -- confirm before dropping support for old versions.
        return Variable(inp.data.new(*size).zero_(),
                        volatile=not self.training)

    if self.train_init:
        # presumably self.h_0 is a learned state with a singleton batch
        # dimension that gets tiled across the batch -- TODO confirm
        h_0 = self.h_0.repeat(1, batch_size, 1)
    else:
        h_0 = zero_state()

    if self.add_init_jitter:
        # the noise is added regardless of self.training
        h_0 = h_0 + torch.normal(torch.zeros_like(h_0), jitter_std)

    if not self.cell.startswith('LSTM'):
        return h_0

    # the memory cell always starts at zero, even when h_0 comes from self.h_0
    return h_0, zero_state()
评论列表
文章目录