solver.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:deepwater-nae 作者: h2oai 项目源码 文件源码
def net_def(self, phase):
        """Build the Caffe network definition for the layer stack in ``self.cmd``.

        ``self.cmd.types`` lists one layer kind per entry ('data', 'relu' or
        'loss') and ``self.cmd.sizes`` gives the matching InnerProduct
        ``num_output`` (ignored for 'data' entries). A 'data' entry adds the
        Python ``solver.DataLayer``; every other entry adds an InnerProduct on
        top of the current top, followed by an in-place ReLU ('relu') or by a
        loss / output layer ('loss').

        Args:
            phase: compared against ``caffe.TRAIN``. Only the TRAIN phase gets
                the extra 'label' top and the loss layer. Other phases get a
                Softmax 'output' for classification nets.

        Returns:
            The result of ``NetSpec.to_proto()`` for the assembled net.

        Raises:
            ValueError: if ``sizes`` and ``types`` differ in length, if a
                non-'data' layer appears before any 'data' layer, or if a
                layer type is not supported.
        """
        layer_sizes = self.cmd.sizes
        layer_types = self.cmd.types
        print('sizes', layer_sizes, file = sys.stderr)
        print('types', layer_types, file = sys.stderr)
        if len(layer_sizes) != len(layer_types):
            raise ValueError(
                'sizes and types must have the same length '
                '(got %d sizes, %d types)' % (len(layer_sizes), len(layer_types)))

        # Validate the whole layer list before building anything. A bad config
        # then fails with a clear message instead of a NetSpec KeyError (from
        # looking up the empty top name '') or after partial construction.
        have_input = False
        for i, layer_type in enumerate(layer_types):
            if layer_type == 'data':
                have_input = True
            elif layer_type not in ('relu', 'loss'):
                raise ValueError('TODO unsupported: ' + layer_type)
            elif not have_input:
                raise ValueError(
                    "layer %d (%r) has no input: a 'data' layer must come "
                    'before it' % (i, layer_type))

        is_train = phase == caffe.TRAIN
        n = caffe.NetSpec()
        name = ''  # name of the top that the next layer will consume

        for i, (size, layer_type) in enumerate(zip(layer_sizes, layer_types)):
            if layer_type == 'data':
                name = 'data'
                if is_train:
                    # The training net also exposes a 'label' top for the loss.
                    n[name], n.label = L.Python(
                        module = 'solver',
                        layer = 'DataLayer',
                        ntop = 2,
                    )
                else:
                    n[name] = L.Python(
                        module = 'solver',
                        layer = 'DataLayer',
                    )
                continue

            # Every 'relu' / 'loss' entry starts with a fully connected layer
            # on top of the current top.
            # NOTE(review): Caffe's 'xavier' filler does not appear to read
            # 'std'. It is kept so the generated definition stays unchanged.
            fc = L.InnerProduct(
                n[name],
                inner_product_param = {'num_output': size,
                                       'weight_filler': {'type': 'xavier',
                                                         'std': 0.1},
                                       'bias_filler': {'type': 'constant',
                                                       'value': 0}})
            name = 'fc%d' % i
            n[name] = fc

            if layer_type == 'relu':
                relu = L.ReLU(n[name], in_place = True)
                name = 'relu%d' % i
                n[name] = relu
            elif is_train:
                # layer_type is 'loss' here (validated above): attach the loss
                # against the label top produced by the data layer.
                loss_layer = (L.EuclideanLoss if self.cmd.regression
                              else L.SoftmaxWithLoss)
                n.loss = loss_layer(n[name], n.label)
            elif not self.cmd.regression:
                # 'loss' outside TRAIN: classification nets get a Softmax
                # output. Regression nets get no extra layer here.
                n.output = L.Softmax(n[name])

        return n.to_proto()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号