def init_hidden_for(self, inp):
    """Build the initial recurrent state for the batch carried by ``inp``.

    The state has shape ``(self.num_layers, inp.size(1), self.hid_dim)``;
    ``inp.size(1)`` is used as the batch size (presumably a seq-first
    ``(seq_len, batch, ...)`` input -- confirm against callers).

    - If ``self.train_init`` is set, the stored ``self.h_0`` is repeated
      ``inp.size(1)`` times along dim 1 (assumes ``self.h_0`` has size 1
      on that dim -- TODO confirm).
    - Otherwise a zero tensor of the same tensor type as ``inp`` is built.
    - If ``self.add_init_jitter`` is set, Gaussian noise with std 0.3 is
      added to ``h_0``.

    Returns:
        ``(h_0, c_0)``, where ``c_0`` is an all-zero tensor shaped like
        ``h_0``, when ``self.cell`` starts with ``'LSTM'``; otherwise
        ``h_0`` alone.
    """
    size = (self.num_layers, inp.size(1), self.hid_dim)
    # create h_0
    if self.train_init:
        h_0 = self.h_0.repeat(1, inp.size(1), 1)
    else:
        # ``volatile`` is the legacy (pre-0.4) PyTorch switch for skipping
        # graph construction outside training; newer releases ignore it.
        h_0 = Variable(inp.data.new(*size).zero_(),
                       volatile=not self.training)
    # eventually add jitter
    if self.add_init_jitter:
        h_0 = h_0 + torch.normal(torch.zeros_like(h_0), 0.3)
    if self.cell.startswith('LSTM'):
        # LSTMs also need a cell state. ``Tensor`` has no ``zeros_like``
        # method, so use the ``torch.zeros_like`` function instead.
        return h_0, torch.zeros_like(h_0)
    return h_0
# (Stray web-page residue removed from code: "Comment list" / "Table of contents" headings.)