resnet_imagenet_model_multi_wrapper.py source code

python

Project: tfplus  Author: renmengye
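# Method of the multi-GPU ResNet wrapper class. add_input_var, get_option,
# register_var, num_replica and sub_models are defined on the enclosing
# class (not shown); NUM_CLS is defined elsewhere in the original module.
import tensorflow as tf


def build_input(self):
    """Create one set of input placeholders per GPU replica."""
    results = {}
    # Boolean placeholder indicating training vs. inference mode.
    phase_train = self.add_input_var('phase_train', None, 'bool')
    results['phase_train'] = phase_train
    inp_depth = self.get_option('inp_depth')
    orig_x = []
    for ii in range(self.num_replica):
        # Each replica gets its own name scope and its own GPU.
        with tf.name_scope('replica_{}'.format(ii)):
            with tf.device('/gpu:{}'.format(ii)):
                # Image batch for this replica: [batch, height, width, depth].
                x_ = self.add_input_var(
                    'x_{}'.format(ii), [None, None, None, inp_depth], 'float')
                results['x_{}'.format(ii)] = x_
                # Ground-truth label vectors: [batch, NUM_CLS].
                y_gt_ = self.add_input_var(
                    'y_gt_{}'.format(ii), [None, NUM_CLS], 'float')
                results['y_gt_{}'.format(ii)] = y_gt_
                # Add the image mean back and scale by 1/255.
                orig_x.append(
                    (x_ + self.sub_models[0].res_net._img_mean) / 255.0)
    # Join all replicas' images along the batch axis. The original TF 0.x
    # call was tf.concat(0, orig_x); TF >= 1.0 takes (values, axis).
    self.register_var('orig_x', tf.concat(orig_x, axis=0))
    return results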
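`build_input` creates a separate `x_{i}` / `y_gt_{i}` placeholder pair for every replica. Whatever feeds the model therefore has to supply a separate shard of data for each GPU. The pure-Python sketch below shows one way a caller could do that. It assumes a global batch is split into contiguous shards, which is not confirmed by the source. `shard_batch`, `make_feed` and `restore_pixels` are hypothetical helpers, not tfplus APIs. Only the key names `phase_train`, `x_{i}` and `y_gt_{i}` come from the listing.

```python
# Hypothetical helpers (not part of tfplus) sketching how a caller might feed
# the per-replica inputs created by build_input(). Contiguous sharding is an
# assumption; only the key names mirror the original code.
from typing import Dict, List, Sequence


def shard_batch(batch: Sequence, num_replica: int) -> List[list]:
    """Split a batch into num_replica contiguous shards, one per GPU."""
    if num_replica <= 0:
        raise ValueError("num_replica must be positive")
    size, extra = divmod(len(batch), num_replica)
    shards, start = [], 0
    for ii in range(num_replica):
        # The first `extra` replicas each take one additional example.
        end = start + size + (1 if ii < extra else 0)
        shards.append(list(batch[start:end]))
        start = end
    return shards


def make_feed(images: Sequence, labels: Sequence, num_replica: int,
              phase_train: bool) -> Dict[str, object]:
    """Build a mapping keyed like the dict returned by build_input()."""
    feed = {'phase_train': phase_train}
    for ii, (x, y) in enumerate(zip(shard_batch(images, num_replica),
                                    shard_batch(labels, num_replica))):
        feed['x_{}'.format(ii)] = x
        feed['y_gt_{}'.format(ii)] = y
    return feed


def restore_pixels(x: float, img_mean: float) -> float:
    """Scalar version of (x_ + _img_mean) / 255.0 from build_input()."""
    return (x + img_mean) / 255.0


if __name__ == '__main__':
    feed = make_feed(list(range(5)), list('abcde'), num_replica=2,
                     phase_train=True)
    print(feed['x_0'], feed['x_1'])    # [0, 1, 2] [3, 4]
    print(restore_pixels(127.0, 128.0))  # 1.0
```

If you concatenate the shards in replica order, you get the original batch back. This is the same ordering that `tf.concat(..., axis=0)` produces for `orig_x`. With TensorFlow 1.x, the string keys would then be mapped to the placeholder tensors before calling `Session.run`, e.g. `{results[k]: v for k, v in feed.items()}`.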