def feed(self, batch):
    """Assemble the feed dict that binds one batch to the model's placeholders.

    The five per-batch arrays are keyed by each placeholder's ``.name``.
    The optional sparse logits mask, the learning rate and the optional
    sub one-hot targets are keyed by the placeholder objects themselves.
    The optional entries are only added when their placeholder is not None.

    :param batch: a Batch object; this method reads its ``inputs``,
        ``targets``, ``masks``, ``dests``, ``seq_lens``, ``adj_indices``
        and ``sub_onehot_target`` attributes
    :return: feed dict of inputs
    """
    feed_dict = {
        self.inputs_.name: batch.inputs,
        self.targets_.name: batch.targets,
        self.mask_.name: batch.masks,
        self.dests_label_.name: batch.dests,
        self.seq_len_.name: batch.seq_lens,
    }

    mask_placeholder = self.logits_mask__
    if mask_placeholder is not None:
        # Each index in batch.adj_indices gets the value 1.0. The dense
        # shape is (total element count of batch.inputs, state_size).
        n_entries = len(batch.adj_indices)
        dense_shape = np.array(
            [np.size(batch.inputs), self.config.state_size], dtype=np.int32
        )
        feed_dict[mask_placeholder] = tf.SparseTensorValue(
            batch.adj_indices, np.ones(n_entries, np.float32), dense_shape
        )

    feed_dict[self.lr_] = self.config.lr

    onehot_placeholder = self.sub_onehot_targets_
    if onehot_placeholder is not None:
        feed_dict[onehot_placeholder] = batch.sub_onehot_target

    return feed_dict
评论列表
文章目录