tools.py source code

python

Project: monogreedy  Author: jinjunqi
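import os.path as osp

import h5py
import numpy as np


def average_models(best, L=6, model_dir='', model_name='ra.h5'):
    """Average the checkpoints saved around epoch `best` into one model.

    Reads <model_dir>/<model_name>.<i> for i in [best - L//2, best + L//2),
    averages every dataset whose name contains '#' (the parameters), copies the
    other datasets and the file attributes as-is (the last file read wins), and
    writes the result to <model_dir>/<model_name>.merge.
    """
    print('... merging')
    start, stop = max(best - L // 2, 0), best + L // 2
    # Divide by the number of checkpoints actually read, not by L: the window
    # is clipped at epoch 0, and an odd L yields only 2 * (L // 2) files.
    n = stop - start
    if n <= 0:
        raise ValueError('empty checkpoint window')
    print('{} {:d}-{:d}'.format(model_dir, start, stop))
    params = {}
    side_info = {}
    attrs = {}
    for i in range(start, stop):
        with h5py.File(osp.join(model_dir, model_name + '.' + str(i)), 'r') as f:
            for k, v in f.attrs.items():
                attrs[k] = v
            for p in f.keys():
                if '#' not in p:
                    side_info[p] = f[p][...]  # non-parameter data, copied verbatim
                elif p in params:
                    params[p] += np.array(f[p]).astype('float32') / n
                else:
                    params[p] = np.array(f[p]).astype('float32') / n
    with h5py.File(osp.join(model_dir, model_name + '.merge'), 'w') as f:
        for p in params.keys():
            f[p] = params[p]
        for s in side_info.keys():
            f[s] = side_info[s]
        for k, v in attrs.items():
            f.attrs[k] = v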
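The window selection and the running mean in average_models can be checked without HDF5 files. The sketch below is a minimal, dependency-free illustration, assuming each checkpoint is an in-memory dict that maps dataset names to lists of floats; checkpoint_window and average_checkpoints are hypothetical helpers written for this note, not part of the monogreedy project.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for one HDF5 checkpoint: dataset name -> flat values.
Checkpoint = Dict[str, List[float]]


def checkpoint_window(best: int, L: int = 6) -> range:
    """Epoch indices averaged around `best`, using the same bounds as average_models."""
    return range(max(best - L // 2, 0), best + L // 2)


def average_checkpoints(ckpts: List[Checkpoint]) -> Tuple[Checkpoint, Checkpoint]:
    """Element-wise mean of '#'-named entries; other entries come from the last checkpoint."""
    if not ckpts:
        raise ValueError("need at least one checkpoint")
    n = len(ckpts)  # divide by the number actually averaged
    params: Checkpoint = {}
    side_info: Checkpoint = {}
    for ckpt in ckpts:
        for name, values in ckpt.items():
            if '#' not in name:
                side_info[name] = list(values)  # last checkpoint wins
            elif name in params:
                params[name] = [acc + v / n for acc, v in zip(params[name], values)]
            else:
                params[name] = [v / n for v in values]
    return params, side_info


if __name__ == "__main__":
    print(list(checkpoint_window(10)))  # [7, 8, 9, 10, 11, 12]
    print(list(checkpoint_window(1)))   # [0, 1, 2, 3]
    ckpts = [
        {'w#0': [1.0, 2.0], 'vocab': [0.0]},
        {'w#0': [3.0, 4.0], 'vocab': [1.0]},
    ]
    print(average_checkpoints(ckpts))   # ({'w#0': [2.0, 3.0]}, {'vocab': [1.0]})
```

The second window call shows why the divisor matters: with best=1 only four checkpoints fall inside the window, so dividing each by L=6 would shrink the averaged weights by a factor of 4/6.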