def _get_train_ops(self, features, targets):
    """Build the training graph and return the train op and its loss.

    Parses `features` and `targets` with `data_ops` helpers and checks both
    parsed results with `_assert_float32`. It then constructs a new graph
    builder from `self.graph_builder_class`. The builder's training graph is
    grouped with an increment of the global step, so running the returned
    operation also advances the global step by one.

    Side effect: the loss tensor is also stored on `self.training_loss`.

    Args:
      features: `Tensor` or `dict` of `Tensor` objects.
      targets: `Tensor` or `dict` of `Tensor` objects.

    Returns:
      Tuple of train `Operation` and loss `Tensor`.
    """
    # `spec` is forwarded to the training graph as `data_spec`; the middle
    # value returned by the parser is not needed here.
    features, _, spec = data_ops.ParseDataTensorOrDict(features)
    labels = data_ops.ParseLabelTensorOrDict(targets)
    # NOTE(review): presumably raises if the tensors are not float32 --
    # confirm against `_assert_float32`'s definition.
    _assert_float32(features)
    _assert_float32(labels)

    graph_builder = self.graph_builder_class(
        self.params, device_assigner=self.device_assigner,
        **self.construction_args)

    # An epoch variable is only created when a data feeder is attached;
    # otherwise the training graph receives `epoch=None`.
    epoch = None
    if self.data_feeder:
        epoch = self.data_feeder.make_epoch_variable()

    train = control_flow_ops.group(
        graph_builder.training_graph(
            features, labels, data_spec=spec, epoch=epoch,
            **self.training_args),
        state_ops.assign_add(contrib_framework.get_global_step(), 1))

    # Pass the parsed `labels` (the same object fed to `training_graph`)
    # rather than the raw `targets`, which may still be a dict.
    self.training_loss = graph_builder.training_loss(features, labels)
    return train, self.training_loss
评论列表
文章目录