def init_hidden_for(self, enc_hidden):
    """
    Creates a variable to be fed as init hidden step.

    Parameters:
    -----------
    enc_hidden: the encoder's final state. When ``self.cell`` starts with
        'LSTM' this is a ``(h, c)`` pair and only ``h`` is used; otherwise
        it is the hidden tensor itself. Dimension 1 of the hidden tensor is
        taken as the batch size.

    Behaviour is controlled by instance flags:
        - ``self.train_init``: use the stored ``self.h_0`` tensor, repeated
          along dimension 1 to match the batch size.
        - otherwise, ``self.reuse_hidden`` keeps the encoder hidden state,
          and when it is false the state is replaced by zeros. In this branch
          ``self.add_init_jitter`` additionally adds gaussian noise
          (mean 0, std 0.3).

    Returns:
    --------
    torch.Tensor(num_layers x batch x hid_dim), or for LSTM cells a tuple
    ``(h_0, c_0)`` where ``c_0`` is all zeros with the same shape as ``h_0``.
    """
    # unpack: LSTM states come as (hidden, cell); the cell state is discarded
    if self.cell.startswith('LSTM'):
        h_0, _ = enc_hidden
    else:
        h_0 = enc_hidden

    # compute h_0
    if self.train_init:
        # expand the stored initial state across the batch dimension
        h_0 = self.h_0.repeat(1, h_0.size(1), 1)
    else:
        # NOTE(review): the original indentation was lost; the jitter step is
        # assumed to belong to this branch (the author used `else:` + nested
        # `if` rather than `elif`) -- confirm against the upstream source.
        if not self.reuse_hidden:
            # torch.Tensor has no `zeros_like` method; use the torch function
            h_0 = torch.zeros_like(h_0)
        if self.add_init_jitter:
            h_0 = h_0 + torch.normal(torch.zeros_like(h_0), 0.3)

    # pack: LSTM cells also need an initial cell state, started at zeros
    if self.cell.startswith('LSTM'):
        return h_0, torch.zeros_like(h_0)
    return h_0
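# NOTE(review): page-navigation labels captured along with this snippet (not code):
# 评论列表 (comment list), 文章目录 (table of contents)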