def create_training_output(
        self,
        shared_resources: SharedResources,
        training_input_tensors: Mapping[TensorPort, tf.Tensor],
) -> Mapping[TensorPort, tf.Tensor]:
    """Build the training-only part of the module's graph.

    Subclasses override this hook to define which tensors the module
    produces exclusively for training. The inputs are the tensors matching
    the ports declared in `training_input_ports`. These may include tensors
    for ports declared in `output_ports`. The sub-graph built here should be
    created only during training.

    Args:
        shared_resources: resources shared between modules, e.g.
            hyper-parameters or vocabularies.
        training_input_tensors: mapping from each training input tensor
            port to its tensor.

    Returns:
        Mapping from each declared training output port to its tensor.

    Raises:
        NotImplementedError: always in this base version; concrete modules
            must supply their own implementation.
    """
    raise NotImplementedError()