resnet_imagenet_model_multi_wrapper.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:tfplus 作者: renmengye 项目源码 文件源码
def build(self, inp):
        # Divide input equally.
        self.lazy_init_var()
        inp_list = []
        output = []
        for ii in xrange(self.num_replica):
            with tf.name_scope('%s_%d' % ('replica', ii)) as scope:
                device = '/gpu:{}'.format(ii)
                with tf.device(device):
                    tf.get_variable_scope().reuse_variables()
                    inp_ = {
                        'x': inp['x_{}'.format(ii)],
                        'y_gt': inp['y_gt_{}'.format(ii)],
                        'phase_train': inp['phase_train']
                    }
                    output.append(self.sub_models[ii].build(inp_))
                    inp_list.append(inp_)
        self.output_list = output
        self.input_list = inp_list
        output = tf.concat(0, [oo['y_out'] for oo in output])
        self.register_var('y_out', output)
        output2 = tf.concat(0, [mm.get_var('score_out')
                                for mm in self.sub_models])
        self.register_var('score_out', output2)
        return {'y_out': output}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号