layers.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:hLSTMat 作者: zhaoluffy 项目源码 文件源码
def param_init_lstm_cond(self, options, params,
                         prefix='lstm_cond', nin=None, dim=None, dimctx=None):
    """Add the initial parameters of a conditional (attention) LSTM to ``params``.

    Every entry is stored under the key ``_p(prefix, <name>)``. The entries
    written, in order, are ``W``, ``U``, ``b``, ``Wc_att``, ``Wd_att``,
    ``b_att``, ``U_att``, ``c_tt`` and, when ``options['selector']`` is truthy,
    ``W_sel`` and ``b_sel``.

    Args:
        options: Mapping read for ``'selector'`` and, whenever one of
            ``nin`` / ``dim`` / ``dimctx`` is ``None``, for ``'dim'``.
        params: Mapping that receives the new entries; it is mutated in place.
        prefix: Name prefix handed to ``_p`` for every key.
        nin: Input size passed to ``norm_weight`` for ``W``; defaults to
            ``options['dim']``.
        dim: Hidden size used for ``W``, ``U``, ``b``, ``Wd_att`` and
            ``W_sel``; defaults to ``options['dim']``.
        dimctx: Context size used for the attention parameters; defaults to
            ``options['dim']``.

    Returns:
        The same ``params`` object that was passed in.
    """
    # Use identity tests for the None singleton: `== None` goes through
    # __eq__ and gives surprising results for array-like arguments.
    if nin is None:
        nin = options['dim']
    if dim is None:
        dim = options['dim']
    if dimctx is None:
        dimctx = options['dim']

    # input to LSTM: four norm_weight(nin, dim) blocks joined along axis 1
    # (presumably one block per gate/candidate -- confirm against the layer
    # that consumes 'W').
    params[_p(prefix, 'W')] = numpy.concatenate(
        [norm_weight(nin, dim) for _ in range(4)], axis=1)

    # LSTM to LSTM: four ortho_weight(dim) blocks joined along axis 1
    params[_p(prefix, 'U')] = numpy.concatenate(
        [ortho_weight(dim) for _ in range(4)], axis=1)

    # bias to LSTM
    params[_p(prefix, 'b')] = numpy.zeros((4 * dim,)).astype('float32')

    # NOTE: no direct context -> LSTM projection ('Wc') is created here; the
    # original source had that initialisation commented out.

    # attention: context -> hidden
    # Only one size is passed; presumably norm_weight then builds a square
    # (dimctx x dimctx) matrix -- confirm against norm_weight's definition.
    params[_p(prefix, 'Wc_att')] = norm_weight(dimctx, ortho=False)

    # attention: LSTM -> hidden
    params[_p(prefix, 'Wd_att')] = norm_weight(dim, dimctx)

    # attention: hidden bias
    params[_p(prefix, 'b_att')] = numpy.zeros((dimctx,)).astype('float32')

    # attention: projection of the hidden attention state to a single value,
    # plus its bias
    params[_p(prefix, 'U_att')] = norm_weight(dimctx, 1)
    # NOTE(review): the key is spelled 'c_tt' (not 'c_att'). It is kept as-is
    # because code that reads this parameter and saved checkpoints may rely
    # on that exact name -- confirm before renaming.
    params[_p(prefix, 'c_tt')] = numpy.zeros((1,)).astype('float32')

    if options['selector']:
        # attention: selector
        params[_p(prefix, 'W_sel')] = norm_weight(dim, 1)
        params[_p(prefix, 'b_sel')] = numpy.float32(0.)

    return params
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号