def load_weights(model, sess, weight_file, max_params=30):
    """
    Load weights from given weight file (used to load pretrain weight of vgg model)

    The arrays stored in the weight file are visited in sorted key order and
    the i-th array is assigned to ``model.parameters_conv[i]`` through
    ``sess.run``. Progress (index, key and array shape) is printed for every
    restored array.

    Args:
        model       : model to restore variable to; must expose an indexable
                      ``parameters_conv`` whose items provide ``assign(value)``
        sess        : tensorflow session used to run each assign op
        weight_file : weight file name; must be a keyed archive that
                      ``np.load`` can open (e.g. an ``.npz`` file)
        max_params  : how many arrays, counted from the start of the sorted
                      key list, are restored. Defaults to 30, the limit that
                      used to be hard-coded; ``None`` restores every array.

    Note:
        There is no bounds check against ``model.parameters_conv``; selecting
        more arrays than it holds fails on the indexing step.
    """
    # np.load keeps the archive open until it is explicitly closed, so use it
    # as a context manager to release the file handle even if an assign fails.
    with np.load(weight_file) as weights:
        # Slice up front instead of enumerating every key and skipping the tail.
        keys = sorted(weights.keys())[:max_params]
        for i, k in enumerate(keys):
            # Archive members are loaded lazily on item access; read each once.
            value = weights[k]
            print('-- %s %s --' % (i,k))
            print(np.shape(value))
            sess.run(model.parameters_conv[i].assign(value))
operations.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录