lm.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def init_hidden_for(self, inp):
        size = (self.num_layers, inp.size(1), self.hid_dim)
        # create h_0
        if self.train_init:
            h_0 = self.h_0.repeat(1, inp.size(1), 1)
        else:
            h_0 = Variable(inp.data.new(*size).zero_(),
                           volatile=not self.training)
        # eventualy add jitter
        if self.add_init_jitter:
            h_0 = h_0 + torch.normal(torch.zeros_like(h_0), 0.3)
        # return
        if self.cell.startswith('LSTM'):
            return h_0, h_0.zeros_like(h_0)
        else:
            return h_0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号