extract_encoder_from_model.py source code

python

Project: ip-avsr  Author: lzuwei
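```python
import numpy as np
import theano.tensor as T

# parse_options, select_nonlinearity, save_mat and deltanet_majority_vote come
# from elsewhere in the ip-avsr project and are not shown here.


def main():
    options = parse_options()
    print(options)
    # Symbolic Theano inputs used to rebuild the saved network.
    window = T.iscalar('theta')  # integer window parameter passed to the model
    inputs1 = T.tensor3('inputs1', dtype='float32')  # (batch, time, input_dim)
    mask = T.matrix('mask', dtype='uint8')  # (batch, time) sequence mask
    # Encoder layer sizes and activations, given as comma-separated strings.
    shape = [int(i) for i in options['shape'].split(',')]
    nonlinearities = [select_nonlinearity(s) for s in options['nonlinearities'].split(',')]
    # Rebuild the network and load the saved model from options['input'].
    network = deltanet_majority_vote.load_saved_model(options['input'],
                                                      (shape, nonlinearities),
                                                      (None, None, options['input_dim']), inputs1, (None, None), mask,
                                                      options['lstm_size'], window, options['output_classes'],
                                                      use_blstm=options['use_blstm'])
    # Pull the weights (w1..w4) and biases (b1..b4) of the four encoder layers.
    d = deltanet_majority_vote.extract_encoder_weights(network, ['fc1', 'fc2', 'fc3', 'bottleneck'],
                                                       [('w1', 'b1'), ('w2', 'b2'), ('w3', 'b3'), ('w4', 'b4')])
    expected_keys = ['w1', 'w2', 'w3', 'w4', 'b1', 'b2', 'b3', 'b4']
    # Sanity check: every extracted entry is an expected key holding a NumPy array.
    for k in d.keys():
        assert k in expected_keys
        assert isinstance(d[k], np.ndarray)
    if 'output' in options:
        print('save extracted weights to {}'.format(options['output']))
        save_mat(d, options['output'])
```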
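The script only checks that each extracted entry is an expected key holding a NumPy array. A stricter check, and a way to reuse the extracted weights outside Theano, can be sketched in plain NumPy. This is a minimal sketch, not part of ip-avsr: it assumes the encoder layers are Lasagne-style dense layers whose `W` has shape `(num_inputs, num_units)` and whose `b` has shape `(num_units,)`. The helper names (`check_encoder_weights`, `encode`), the layer sizes, and the choice of nonlinearities below are hypothetical.

```python
import numpy as np


def check_encoder_weights(d, n_layers=4):
    """Check that d holds w1..wN / b1..bN as NumPy arrays with chained shapes.

    Assumes the Lasagne DenseLayer convention: W is (num_inputs, num_units)
    and b is (num_units,).
    """
    expected = ({'w%d' % i for i in range(1, n_layers + 1)}
                | {'b%d' % i for i in range(1, n_layers + 1)})
    # Exact set comparison catches both unexpected and missing keys.
    assert set(d) == expected, 'key mismatch: %s' % (set(d) ^ expected)
    for i in range(1, n_layers + 1):
        w, b = d['w%d' % i], d['b%d' % i]
        assert isinstance(w, np.ndarray) and isinstance(b, np.ndarray)
        assert w.ndim == 2 and b.shape == (w.shape[1],)
        if i > 1:
            # Output width of layer i-1 must equal the input width of layer i.
            assert d['w%d' % (i - 1)].shape[1] == w.shape[0]


def encode(x, d, nonlinearities):
    """Run x of shape (frames, input_dim) through the dense layers in order."""
    h = x
    for i, f in enumerate(nonlinearities, start=1):
        h = f(h @ d['w%d' % i] + d['b%d' % i])
    return h


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def linear(z):
    return z


# Hypothetical sizes: input_dim followed by the widths of the four layers.
sizes = [1200, 2000, 1000, 500, 50]
rng = np.random.default_rng(0)
d = {}
for i in range(1, 5):
    d['w%d' % i] = (0.01 * rng.standard_normal((sizes[i - 1], sizes[i]))).astype(np.float32)
    d['b%d' % i] = np.zeros(sizes[i], dtype=np.float32)

check_encoder_weights(d)
frames = rng.standard_normal((10, 1200)).astype(np.float32)
feats = encode(frames, d, [sigmoid, sigmoid, sigmoid, linear])
print(feats.shape)  # (10, 50)
```

Comparing `set(d)` with the expected keys also catches a key that is missing entirely, which the original loop over `d.keys()` cannot detect, and the chained-shape check flags layers that were extracted in the wrong order.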