encoder.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def init_hidden_for(self, inp):
        batch_size = inp.size(1)
        size = (self.num_dirs * self.num_layers, batch_size, self.hid_dim)

        if self.train_init:
            h_0 = self.h_0.repeat(1, batch_size, 1)
        else:
            h_0 = inp.data.new(*size).zero_()
            h_0 = Variable(h_0, volatile=not self.training)

        if self.add_init_jitter:
            h_0 = h_0 + torch.normal(torch.zeros_like(h_0), 0.3)

        if self.cell.startswith('LSTM'):
            # compute memory cell
            c_0 = inp.data.new(*size).zero_()
            c_0 = Variable(c_0, volatile=not self.training)
            return h_0, c_0
        else:
            return h_0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号