layers.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ip-avsr 作者: lzuwei 项目源码 文件源码
def create_pretrained_lstm(lstm_weights, prefix, l_incoming, l_mask, hidden_units, cell_parameters, gate_parameters,
                           name, use_peepholes=False, backwards=False):
    """Construct an ``LSTMLayer`` and overwrite its parameters with pretrained arrays.

    Parameters
    ----------
    lstm_weights : mapping
        Pretrained arrays looked up under keys of the form
        ``'<prefix>_w_in_to_<gate>'``, ``'<prefix>_w_hid_to_<gate>'`` and
        ``'<prefix>_b_<gate>'``, where ``<gate>`` is one of ``cell``,
        ``forgetgate``, ``ingate`` or ``outgate``. Values must support
        ``.astype``. Bias values must also support ``.reshape``.
        Presumably these are NumPy arrays; confirm with callers.
    prefix : str
        Key prefix that selects this layer's entries in ``lstm_weights``.
    l_incoming :
        Incoming layer. Passed as the first positional argument of ``LSTMLayer``.
    l_mask :
        Layer forwarded as ``mask_input``.
    hidden_units :
        Number of units. Passed as the second positional argument.
    cell_parameters :
        Forwarded as ``cell``.
    gate_parameters :
        Forwarded as ``ingate``, ``forgetgate`` and ``outgate`` alike.
    name :
        Layer name.
    use_peepholes : bool
        Forwarded as ``peepholes``. NOTE(review): no peephole weights are
        read from ``lstm_weights`` here, so any peephole parameters keep
        whatever values the layer constructor gave them. Confirm this is
        intended.
    backwards : bool
        Forwarded as ``backwards``.

    Returns
    -------
    The new layer. Its eight input/hidden weight matrices and four gate
    biases are replaced by ``float32`` versions of the pretrained arrays.
    The biases are additionally flattened to 1-D.

    Raises
    ------
    KeyError
        If ``lstm_weights`` is a dict and an expected key is missing.
    """
    lstm = LSTMLayer(
        l_incoming,
        hidden_units,
        peepholes=use_peepholes,
        # Masks are supplied through a dedicated input layer.
        mask_input=l_mask,
        # All three gates share the same parameter spec; the cell gets its own.
        ingate=gate_parameters,
        forgetgate=gate_parameters,
        cell=cell_parameters,
        outgate=gate_parameters,
        # Initial states are learnable and gradients are clipped at 5.
        learn_init=True,
        grad_clipping=5.,
        name=name,
        backwards=backwards,
    )

    # Weight matrices: each dictionary key suffix is the attribute name lowercased
    # (e.g. 'W_hid_to_cell' -> '<prefix>_w_hid_to_cell').
    weight_attrs = (
        'W_hid_to_cell', 'W_hid_to_forgetgate', 'W_hid_to_ingate', 'W_hid_to_outgate',
        'W_in_to_cell', 'W_in_to_forgetgate', 'W_in_to_ingate', 'W_in_to_outgate',
    )
    for attr in weight_attrs:
        pretrained = lstm_weights['{}_{}'.format(prefix, attr.lower())].astype('float32')
        # Writes straight into the shared variable's storage container.
        getattr(lstm, attr).container.data = pretrained

    # Biases are flattened to 1-D, presumably because the stored arrays may
    # carry an extra singleton axis; confirm against the weight-export code.
    for gate in ('cell', 'forgetgate', 'ingate', 'outgate'):
        pretrained = lstm_weights['{}_b_{}'.format(prefix, gate)].astype('float32')
        getattr(lstm, 'b_' + gate).container.data = pretrained.reshape((-1,))

    return lstm
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号