network.py source code


Project: Automatic_Group_Photography_Enhancement · Author: Yuliang-Zou
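```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def extract(self, data_path, session, saver):
    """Restore a checkpoint and dump each layer's weights and biases to a .npy file."""
    saver.restore(session, data_path)
    scopes = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2',
              'conv3_1', 'conv3_2', 'conv3_3',
              'conv4_1', 'conv4_2', 'conv4_3',
              'conv5_1', 'conv5_2', 'conv5_3',
              'rpn_conv/3x3', 'rpn_cls_score', 'rpn_bbox_pred',
              'fc6', 'fc7', 'cls_score', 'bbox_pred']
    data_dict = {}
    for scope in scopes:
        if scope in ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2']:
            # Frozen layers are not trainable, so they only appear in the global collection
            w, b = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
        else:
            # The global collection would also return the optimizer's momentum slots,
            # so read from the trainable collection to get exactly [weights, biases]
            w, b = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
        w_val, b_val = session.run([w, b])  # works even if `session` is not the default
        data_dict[scope] = {'weights': w_val, 'biases': b_val}
    file_name = data_path[:-5]  # drop the last five characters, e.g. a '.ckpt' suffix
    np.save(file_name, data_dict)  # np.save appends '.npy'
    return file_name + '.npy'
```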
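The file written by `extract` holds a Python dict, which `np.save` wraps in a 0-d object array. The sketch below shows one way to read it back, assuming NumPy 1.16.3 or later (where `np.load` refuses pickled objects unless `allow_pickle=True` is passed). `load_layer_params` is a hypothetical helper name, and the layer shapes in the demo are made up rather than taken from the project.

```python
import os
import tempfile

import numpy as np


def load_layer_params(npy_path):
    """Load a {scope: {'weights': ..., 'biases': ...}} dict saved with np.save."""
    # np.save stored the dict as a 0-d object array, so pickling must be allowed,
    # and .item() unwraps the original dict.
    return np.load(npy_path, allow_pickle=True).item()


if __name__ == '__main__':
    # Hypothetical stand-in for a real dump: one conv layer with made-up shapes.
    fake = {'conv1_1': {'weights': np.zeros((3, 3, 3, 64), dtype=np.float32),
                        'biases': np.zeros(64, dtype=np.float32)}}
    path = os.path.join(tempfile.mkdtemp(), 'model')
    np.save(path, fake)  # writes model.npy
    params = load_layer_params(path + '.npy')
    for scope, tensors in params.items():
        print(scope, tensors['weights'].shape, tensors['biases'].shape)
```

Running the demo prints `conv1_1 (3, 3, 3, 64) (64,)`. Without `allow_pickle=True`, recent NumPy versions raise a `ValueError` when loading this kind of file.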