def _define_vars(self, params, **kwargs):
    """Create this layer's tree variables on its assigned device.

    Two attributes are set on ``self``:

    * ``tree_parameters`` -- variable named
      ``hard_tree_parameters_<layer_num>`` with shape
      ``[params.num_nodes, params.num_features]``.
    * ``tree_thresholds`` -- variable named
      ``hard_tree_thresholds_<layer_num>`` with shape
      ``[params.num_nodes]``.

    Both variables use a truncated-normal initializer configured from
    ``params.weight_init_mean`` and ``params.weight_init_std``.

    Args:
      params: Object providing ``num_nodes``, ``num_features``,
        ``weight_init_mean`` and ``weight_init_std``.
      **kwargs: Accepted but not used by this method.
    """
    layer = self.layer_num
    target_device = self.device_assigner.get_device(layer)

    def _new_initializer():
        # Each variable gets its own initializer instance.
        return variable_scope.truncated_normal_initializer(
            mean=params.weight_init_mean, stddev=params.weight_init_std)

    with ops.device(target_device):
        weights_shape = [params.num_nodes, params.num_features]
        self.tree_parameters = variable_scope.get_variable(
            name='hard_tree_parameters_%d' % layer,
            shape=weights_shape,
            initializer=_new_initializer())

        thresholds_shape = [params.num_nodes]
        self.tree_thresholds = variable_scope.get_variable(
            name='hard_tree_thresholds_%d' % layer,
            shape=thresholds_shape,
            initializer=_new_initializer())
评论列表
文章目录